@@ -6,14 +6,14 @@ import ReinforcementLearningCore.SRT
66 a_1 = Agent (
77 RandomPolicy (),
88 Trajectory (
9- CircularArraySARTSTraces (; capacity = 1_000 ),
9+ CircularArraySARTSTraces (; capacity= 1_000 ),
1010 DummySampler (),
1111 ),
1212 )
1313 a_2 = Agent (
1414 RandomPolicy (),
1515 Trajectory (
16- CircularArraySARTSTraces (; capacity = 1_000 ),
16+ CircularArraySARTSTraces (; capacity= 1_000 ),
1717 BatchSampler (1 ),
1818 InsertSampleRatioController (),
1919 ),
@@ -26,25 +26,25 @@ import ReinforcementLearningCore.SRT
2626 env = RandomWalk1D ()
2727 push! (agent, PreEpisodeStage (), env)
2828 action = RLBase. plan! (agent, env)
29- @test action in (1 ,2 )
30- @test length (agent. trajectory. container) == 0
29+ @test action in (1 , 2 )
30+ @test length (agent. trajectory. container) == 0
3131 push! (agent, PostActStage (), env, action)
3232 push! (agent, PreActStage (), env)
33- @test RLBase. plan! (agent, env) in (1 ,2 )
33+ @test RLBase. plan! (agent, env) in (1 , 2 )
3434 @test length (agent. trajectory. container) == 1
3535
3636 # The following tests checks args / kwargs passed to policy cause an error
3737 @test_throws " MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase. plan! (agent, env, 1 )
38- @test_throws " MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase. plan! (agent, env, fake_kwarg = 1 )
38+ @test_throws " MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase. plan! (agent, env, fake_kwarg= 1 )
3939 end
4040 end
4141 end
4242 @testset " OfflineAgent" begin
4343 env = RandomWalk1D ()
4444 a_1 = OfflineAgent (
45- policy = RandomPolicy (),
46- trajectory = Trajectory (
47- CircularArraySARTSTraces (; capacity = 1_000 ),
45+ policy= RandomPolicy (),
46+ trajectory= Trajectory (
47+ CircularArraySARTSTraces (; capacity= 1_000 ),
4848 DummySampler (),
4949 ),
5050 )
@@ -53,27 +53,35 @@ import ReinforcementLearningCore.SRT
5353 @test isempty (a_1. trajectory. container)
5454
5555 trajectory = Trajectory (
56- CircularArraySARTSTraces (; capacity = 1_000 ),
57- DummySampler (),
58- )
59-
56+ CircularArraySARTSTraces (; capacity= 1_000 ),
57+ DummySampler (),
58+ )
59+
6060 a_2 = OfflineAgent (
61- policy = RandomPolicy (),
62- trajectory = trajectory,
63- offline_behavior = OfflineBehavior (
61+ policy= RandomPolicy (),
62+ trajectory= trajectory,
63+ offline_behavior= OfflineBehavior (
6464 Agent (RandomPolicy (), trajectory),
65- steps = 5 ,
65+ steps= 5 ,
6666 )
6767 )
6868 push! (a_2, PreExperimentStage (), env)
69- @test length (a_2. trajectory. container) == 5
69+ # We'll have 1 extra element where terminal is true
70+ # if the environment was terminated mid-episode and restarted!
71+ ix = findfirst (x -> x. terminal, map (identity, a_2. trajectory. container))
72+ len = length (a_2. trajectory. container)
73+ max = isnothing (ix) || ix == len ? 5 : 6
74+ @test len == max
7075
7176 for agent in [a_1, a_2]
7277 action = RLBase. plan! (agent, env)
73- @test action in (1 ,2 )
78+ @test action in (1 , 2 )
7479 for stage in [PreEpisodeStage (), PreActStage (), PostActStage (), PostEpisodeStage ()]
7580 push! (agent, stage, env)
76- @test length (agent. trajectory. container) in (0 ,5 )
81+ ix = findfirst (x -> x. terminal, map (identity, agent. trajectory. container))
82+ len = length (agent. trajectory. container)
83+ max = isnothing (ix) || ix == len ? 5 : 6
84+ @test len in (0 , max)
7785 end
7886 end
7987 end
0 commit comments