Skip to content

Commit d8af17f

Browse files
authored
Fix offline agent test (#1025)
1 parent 1d3c7da commit d8af17f

1 file changed

Lines changed: 28 additions & 20 deletions

File tree

  • src/ReinforcementLearningCore/test/policies

src/ReinforcementLearningCore/test/policies/agent.jl

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ import ReinforcementLearningCore.SRT
66
a_1 = Agent(
77
RandomPolicy(),
88
Trajectory(
9-
CircularArraySARTSTraces(; capacity = 1_000),
9+
CircularArraySARTSTraces(; capacity=1_000),
1010
DummySampler(),
1111
),
1212
)
1313
a_2 = Agent(
1414
RandomPolicy(),
1515
Trajectory(
16-
CircularArraySARTSTraces(; capacity = 1_000),
16+
CircularArraySARTSTraces(; capacity=1_000),
1717
BatchSampler(1),
1818
InsertSampleRatioController(),
1919
),
@@ -26,25 +26,25 @@ import ReinforcementLearningCore.SRT
2626
env = RandomWalk1D()
2727
push!(agent, PreEpisodeStage(), env)
2828
action = RLBase.plan!(agent, env)
29-
@test action in (1,2)
30-
@test length(agent.trajectory.container) == 0
29+
@test action in (1, 2)
30+
@test length(agent.trajectory.container) == 0
3131
push!(agent, PostActStage(), env, action)
3232
push!(agent, PreActStage(), env)
33-
@test RLBase.plan!(agent, env) in (1,2)
33+
@test RLBase.plan!(agent, env) in (1, 2)
3434
@test length(agent.trajectory.container) == 1
3535

3636
# The following tests check that args / kwargs passed to the policy cause an error
3737
@test_throws "MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase.plan!(agent, env, 1)
38-
@test_throws "MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase.plan!(agent, env, fake_kwarg = 1)
38+
@test_throws "MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase.plan!(agent, env, fake_kwarg=1)
3939
end
4040
end
4141
end
4242
@testset "OfflineAgent" begin
4343
env = RandomWalk1D()
4444
a_1 = OfflineAgent(
45-
policy = RandomPolicy(),
46-
trajectory = Trajectory(
47-
CircularArraySARTSTraces(; capacity = 1_000),
45+
policy=RandomPolicy(),
46+
trajectory=Trajectory(
47+
CircularArraySARTSTraces(; capacity=1_000),
4848
DummySampler(),
4949
),
5050
)
@@ -53,27 +53,35 @@ import ReinforcementLearningCore.SRT
5353
@test isempty(a_1.trajectory.container)
5454

5555
trajectory = Trajectory(
56-
CircularArraySARTSTraces(; capacity = 1_000),
57-
DummySampler(),
58-
)
59-
56+
CircularArraySARTSTraces(; capacity=1_000),
57+
DummySampler(),
58+
)
59+
6060
a_2 = OfflineAgent(
61-
policy = RandomPolicy(),
62-
trajectory = trajectory,
63-
offline_behavior = OfflineBehavior(
61+
policy=RandomPolicy(),
62+
trajectory=trajectory,
63+
offline_behavior=OfflineBehavior(
6464
Agent(RandomPolicy(), trajectory),
65-
steps = 5,
65+
steps=5,
6666
)
6767
)
6868
push!(a_2, PreExperimentStage(), env)
69-
@test length(a_2.trajectory.container) == 5
69+
# We'll have 1 extra element, where terminal is true,
70+
# if the environment terminated mid-episode and was restarted.
71+
ix = findfirst(x -> x.terminal, map(identity, a_2.trajectory.container))
72+
len = length(a_2.trajectory.container)
73+
max = isnothing(ix) || ix == len ? 5 : 6
74+
@test len == max
7075

7176
for agent in [a_1, a_2]
7277
action = RLBase.plan!(agent, env)
73-
@test action in (1,2)
78+
@test action in (1, 2)
7479
for stage in [PreEpisodeStage(), PreActStage(), PostActStage(), PostEpisodeStage()]
7580
push!(agent, stage, env)
76-
@test length(agent.trajectory.container) in (0,5)
81+
ix = findfirst(x -> x.terminal, map(identity, agent.trajectory.container))
82+
len = length(agent.trajectory.container)
83+
max = isnothing(ix) || ix == len ? 5 : 6
84+
@test len in (0, max)
7785
end
7886
end
7987
end

0 commit comments

Comments
 (0)