Skip to content

Commit f8d5eb7

Browse files
Fix abstract_learner for multiplayer games (#1054)
* Fix abstract_learner for multiplayer games
* Fix test
* Drop excess mutability
* Fix type instability (several passes)
* Fix type instability in stop condition
* Add missing default method
* Drop excess code

---------

Co-authored-by: Jeremiah Lewis <!-- email garbled during page extraction; recover from commit metadata -->
1 parent 93a13d3 commit f8d5eb7

6 files changed

Lines changed: 37 additions & 13 deletions

File tree

src/ReinforcementLearningCore/src/core/stop_conditions.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@ abstract type AbstractStopCondition end
1515
1616
The result of `stop_conditions` is reduced by `reducer`. The default `reducer` is the `any` function, which means that the condition is true when any one of the `stop_conditions...` is true. Can be replaced by any function returning a boolean. For example `reducer = x->sum(x) >= 2` will require at least two of the conditions to be true.
1717
"""
18-
struct ComposedStopCondition{S,T} <: AbstractStopCondition
18+
struct ComposedStopCondition{S,reducer} <: AbstractStopCondition
1919
stop_conditions::S
20-
reducer::T
20+
reducer
2121
function ComposedStopCondition(stop_conditions...; reducer = any)
22-
new{typeof(stop_conditions),typeof(reducer)}(stop_conditions, reducer)
22+
new{typeof(stop_conditions),reducer}(stop_conditions, reducer)
2323
end
2424
end
2525

26-
function check!(s::ComposedStopCondition, args...)
27-
s.reducer(check!(sc, args...) for sc in s.stop_conditions)
26+
function check!(s::ComposedStopCondition{S,R}, policy::P, env::E) where {S,R,P<:AbstractPolicy,E<:AbstractEnv}
27+
s.reducer(check!(sc, policy, env) for sc in s.stop_conditions)
2828
end
2929

3030
#####
@@ -58,12 +58,12 @@ function _stop_after_step(s::StopAfterNSteps)
5858
res
5959
end
6060

61-
function check!(s::StopAfterNSteps, args...)
61+
function check!(s::StopAfterNSteps, agent, env)
6262
ProgressMeter.next!(s.progress)
6363
_stop_after_step(s)
6464
end
6565

66-
check!(s::StopAfterNSteps{Nothing}, args...) = _stop_after_step(s)
66+
check!(s::StopAfterNSteps{Nothing}, agent, env) = _stop_after_step(s)
6767

6868
#####
6969
# StopAfterNEpisodes

src/ReinforcementLearningCore/src/policies/learners/abstract_learner.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,21 @@ function forward(learner::L, env::E) where {L <: AbstractLearner, E <: AbstractE
1111
env |> state |> (x -> forward(learner, x))
1212
end
1313

14+
# Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
15+
function forward(learner::L, env::E, player::Symbol) where {L <: AbstractLearner, E <: AbstractEnv}
16+
env |> (x -> state(x, player)) |> (x -> forward(learner, x))
17+
end
18+
1419
function RLBase.optimise!(::AbstractLearner, ::AbstractStage, ::Trajectory) end
1520

21+
function RLBase.optimise!(::AbstractLearner, ::AbstractStage, ::NamedTuple) end
22+
1623
function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv)
1724
legal_action_space_ = RLBase.legal_action_space_mask(env)
1825
RLBase.plan!(explorer, forward(learner, env), legal_action_space_)
1926
end
2027

2128
function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv, player::Symbol)
2229
legal_action_space_ = RLBase.legal_action_space_mask(env, player)
23-
return RLBase.plan!(explorer, forward(learner, env), legal_action_space_)
30+
return RLBase.plan!(explorer, forward(learner, env, player), legal_action_space_)
2431
end

src/ReinforcementLearningCore/src/policies/learners/td_learner.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,3 @@ end
9090

9191
# TDLearner{:SARS} is optimized at the PostActStage
9292
RLBase.optimise!(learner::TDLearner{:SARS}, stage::PostActStage, trace::NamedTuple) = RLBase.optimise!(learner, trace)
93-

src/ReinforcementLearningCore/src/policies/q_based_policy.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,19 @@ action of an environment at its current state. It is typically a table or a neur
1010
QBasedPolicy can be queried for an action with `RLBase.plan!`, the explorer will affect the action selection
1111
accordingly.
1212
"""
13-
Base.@kwdef mutable struct QBasedPolicy{L<:TDLearner,E<:AbstractExplorer} <: AbstractPolicy
13+
struct QBasedPolicy{L<:TDLearner,E<:AbstractExplorer} <: AbstractPolicy
1414
"estimate the Q value"
1515
learner::L
1616
"select the action based on Q values calculated by the learner"
1717
explorer::E
18+
19+
function QBasedPolicy(; learner::L, explorer::E) where {L<:TDLearner, E<:AbstractExplorer}
20+
new{L,E}(learner, explorer)
21+
end
22+
23+
function QBasedPolicy(learner::L, explorer::E) where {L<:TDLearner, E<:AbstractExplorer}
24+
new{L,E}(learner, explorer)
25+
end
1826
end
1927

2028
Flux.@layer QBasedPolicy trainable=(learner,)

src/ReinforcementLearningCore/test/core/stop_conditions.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,24 @@ import ReinforcementLearningCore.check!
22

33
@testset "StopAfterNSteps" begin
44
stop_condition = StopAfterNSteps(10)
5-
@test sum([check!(stop_condition) for i in 1:20]) == 11
5+
env = RandomWalk1D()
6+
policy = RandomPolicy(legal_action_space(env))
7+
8+
@test sum([check!(stop_condition, policy, env) for i in 1:20]) == 11
69

710
stop_condition = StopAfterNSteps(10; is_show_progress=false)
8-
@test sum([check!(stop_condition) for i in 1:20]) == 11
11+
@test sum([check!(stop_condition, policy, env) for i in 1:20]) == 11
912
end
1013

1114
@testset "ComposedStopCondition" begin
1215
stop_10 = StopAfterNSteps(10)
1316
stop_3 = StopAfterNSteps(3)
1417

18+
env = RandomWalk1D()
19+
policy = RandomPolicy(legal_action_space(env))
20+
1521
composed_stop = ComposedStopCondition(stop_10, stop_3)
16-
@test sum([check!(composed_stop) for i in 1:20]) == 18
22+
@test sum([check!(composed_stop, policy, env) for i in 1:20]) == 18
1723
end
1824

1925
@testset "StopAfterNEpisodes" begin

src/ReinforcementLearningCore/test/policies/learners/abstract_learner.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,16 @@ struct MockLearner <: AbstractLearner end
1515
end
1616

1717
RLBase.state(::MockEnv, ::Observation{Any}, ::DefaultPlayer) = 1
18+
RLBase.state(::MockEnv, ::Observation{Any}, ::Symbol) = 1
1819

1920
env = MockEnv()
2021
learner = MockLearner()
2122

2223
output = RLCore.forward(learner, env)
2324
@test output == Float64[1.0, 2.0]
25+
26+
output = RLCore.forward(learner, env, Symbol(1))
27+
@test output == Float64[1.0, 2.0]
2428
end
2529

2630
@testset "Plan" begin

0 commit comments

Comments (0)