Skip to content

Commit 28d3ce5

Browse files
Rework tabular approximator
1 parent 4f96c51 commit 28d3ce5

4 files changed

Lines changed: 44 additions & 43 deletions

File tree

src/ReinforcementLearningCore/src/policies/learners/abstract_learner.jl

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,6 @@ end
1414

1515
function RLBase.optimise!(::AbstractLearner, ::AbstractStage, ::Trajectory) end
1616

17-
18-
"""
19-
Approximator(model, optimiser)
20-
21-
Wraps a Flux trainable model and implements the `RLBase.optimise!(::Approximator, ::Gradient)`
22-
interface. See the RLCore documentation for more information on proper usage.
23-
"""
24-
struct Approximator{M,O} <: AbstractLearner
25-
model::M
26-
optimiser_state::O
27-
end
28-
29-
function Approximator(; model, optimiser)
30-
optimiser_state = Flux.setup(optimiser, model)
31-
Approximator(gpu(model), gpu(optimiser_state)) # Pass model to GPU (if available) upon creation
32-
end
33-
34-
Base.show(io::IO, m::MIME"text/plain", A::Approximator) = show(io, m, convert(AnnotatedStructTree, A))
35-
36-
@functor Approximator (model,)
37-
3817
function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv)
3918
legal_action_space_ = RLBase.legal_action_space_mask(env)
4019
RLBase.plan!(explorer, forward(learner, env), legal_action_space_)
@@ -44,8 +23,3 @@ function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env:
4423
legal_action_space_ = RLBase.legal_action_space_mask(env, player)
4524
return RLBase.plan!(explorer, forward(learner, env), legal_action_space_)
4625
end
47-
48-
forward(A::Approximator, args...; kwargs...) = A.model(args...; kwargs...)
49-
50-
RLBase.optimise!(A::Approximator, grad) =
51-
Flux.Optimise.update!(A.optimiser_state, A.model, grad)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
Approximator(model, optimiser)
3+
4+
Wraps a Flux trainable model and implements the `RLBase.optimise!(::Approximator, ::Gradient)`
5+
interface. See the RLCore documentation for more information on proper usage.
6+
"""
7+
struct Approximator{M,O} <: AbstractLearner
8+
model::M
9+
optimiser_state::O
10+
end
11+
12+
function Approximator(; model, optimiser, gpu=false)
13+
optimiser_state = Flux.setup(optimiser, model)
14+
if gpu # Pass model to GPU (if available) upon creation
15+
return Approximator(gpu(model), gpu(optimiser_state))
16+
else
17+
return Approximator(model, optimiser_state)
18+
end
19+
end
20+
21+
Base.show(io::IO, m::MIME"text/plain", A::Approximator) = show(io, m, convert(AnnotatedStructTree, A))
22+
23+
@functor Approximator (model,)
24+
25+
forward(A::Approximator, args...; kwargs...) = A.model(args...; kwargs...)
26+
27+
RLBase.optimise!(A::Approximator, grad) =
28+
Flux.Optimise.update!(A.optimiser_state, A.model, grad)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
include("abstract_learner.jl")
# `approximator.jl` defines `Approximator`, which the tabular aliases below build on,
# so it must be included before `tabular_approximator.jl`.
include("approximator.jl")
include("tabular_approximator.jl")
include("target_network.jl")
Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
export TabularApproximator, TabularVApproximator, TabularQApproximator
22

3-
using Flux: gpu
3+
# Type aliases: a tabular approximator is simply an `Approximator` whose "model"
# is a plain array used as a lookup table.
const TabularApproximator = Approximator{A,O} where {A<:AbstractArray,O}
# Q-tables are 2-d (action × state); use `AbstractMatrix` so the alias does not
# also match 1-d state-value tables (the previous `AbstractArray` bound made
# `TabularQApproximator` indistinguishable from the generic alias).
const TabularQApproximator = Approximator{A,O} where {A<:AbstractMatrix,O}
const TabularVApproximator = Approximator{A,O} where {A<:AbstractVector,O}
46

57
"""
68
TabularApproximator(table<:AbstractArray, opt)
@@ -11,15 +13,10 @@ For `table` of 2-d, it will serve as a state-action value approximator.
1113
!!! warning
1214
For `table` of 2-d, the first dimension is action and the second dimension is state.
1315
"""
14-
# TODO: add back missing AbstractApproximator
15-
struct TabularApproximator{N,A,O} <: AbstractLearner
16-
table::A
17-
optimizer::O
18-
function TabularApproximator(table::A, opt::O) where {A<:AbstractArray,O}
19-
n = ndims(table)
20-
n <= 2 || throw(ArgumentError("the dimension of table must be <= 2"))
21-
new{n,A,O}(table, opt)
22-
end
16+
function TabularApproximator(table::A, opt::O) where {A<:AbstractArray,O}
17+
n = ndims(table)
18+
n <= 2 || throw(ArgumentError("the dimension of table must be <= 2"))
19+
TabularApproximator{A,O}(table, opt)
2320
end
2421

2522
TabularVApproximator(; n_state, init = 0.0, opt = InvDecay(1.0)) =
@@ -29,21 +26,22 @@ TabularQApproximator(; n_state, n_action, init = 0.0, opt = InvDecay(1.0)) =
2926
TabularApproximator(fill(init, n_action, n_state), opt)
3027

3128
# Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
# NOTE: must use `Approximator{<:AbstractArray}` (a UnionAll), not
# `Approximator{A, Any}` — Julia type parameters are invariant, so the latter
# only matches when the optimiser-state parameter is literally `Any` and the
# method would never fire for concrete optimiser types.
function forward(L::Approximator{<:AbstractArray}, env::E) where {E<:AbstractEnv}
    env |> state |> (x -> forward(L, x))
end

# State-value lookup: 1-d table indexed by state.
# `I<:Integer` must appear in the `where` clause — without it `I` is an
# undefined type variable and the method definition errors.
RLCore.forward(
    app::Approximator{R,O},
    s::I,
) where {R<:AbstractVector,O,I<:Integer} = @views app.model[s]

# Action-value lookup over all actions: column `s` of the 2-d (action × state) table.
RLCore.forward(
    app::Approximator{R,O},
    s::I,
) where {R<:AbstractMatrix,O,I<:Integer} = @views app.model[:, s]

# Single state-action value: first dimension is action, second is state.
RLCore.forward(
    app::Approximator{R,O},
    s::I1,
    a::I2,
) where {R<:AbstractMatrix,O,I1<:Integer,I2<:Integer} = @views app.model[a, s]

0 commit comments

Comments
 (0)