
Commit 55f60b0

Drop Functors and use Flux.@layer (#1048)
1 parent: d10fae7

11 files changed

Lines changed: 30 additions & 38 deletions
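
The change itself is mechanical: the direct Functors.jl dependency is dropped, and every `@functor T (fields...)` declaration (plus one hand-rolled `Flux.trainable` overload) becomes a single `Flux.@layer T [trainable=(fields...)]` line. A minimal sketch of the pattern, assuming Flux 0.14; `ToyPolicy` is a made-up struct for illustration, not part of the package:

using Flux

struct ToyPolicy{M,T}
    model::M       # neural network whose parameters should be optimised
    trajectory::T  # replay buffer; not updated by the optimiser
end

# Before (Functors.jl):
#   using Functors: @functor
#   @functor ToyPolicy (model,)
# After (Flux >= 0.14): @layer registers the struct as a Flux layer and the
# trainable keyword restricts which fields the optimisers will update.
Flux.@layer ToyPolicy trainable=(model,)

p = ToyPolicy(Dense(4 => 2), Float32[])
Flux.trainable(p)   # only the model field is reported as trainable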

src/ReinforcementLearningCore/Project.toml

Lines changed: 0 additions & 2 deletions
@@ -11,7 +11,6 @@ Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
@@ -33,7 +32,6 @@ Crayons = "4"
 Distributions = "0.25"
 FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
 Flux = "0.14"
-Functors = "0.1, 0.2, 0.3, 0.4"
 GPUArrays = "8, 9, 10"
 Metal = "1.0"
 ProgressMeter = "1"

src/ReinforcementLearningCore/src/policies/agent/agent_base.jl

Lines changed: 2 additions & 3 deletions
@@ -1,8 +1,7 @@
 export Agent
 
 using Base.Threads: @spawn
-
-using Functors: @functor
+using Flux
 import Base.push!
 
 abstract type AbstractAgent <: AbstractPolicy end
@@ -41,7 +40,7 @@ RLBase.optimise!(::AsyncTrajectoryStyle, agent::AbstractAgent, stage::S) where {
 #by default, optimise does nothing at all stage
 function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajectory::Trajectory) end
 
-@functor Agent (policy,)
+Flux.@layer Agent trainable=(policy,)
 
 function Base.push!(agent::Agent, ::PreEpisodeStage, env::AbstractEnv)
 push!(agent.trajectory, (state = state(env),))

src/ReinforcementLearningCore/src/policies/agent/offline_agent.jl

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,7 @@
 export OfflineAgent, OfflineBehavior
 
+using Flux
+
 """
 OfflineBehavior(; agent:: Union{<:Agent, Nothing}, steps::Int, reset_condition)
 
@@ -49,7 +51,7 @@ struct OfflineAgent{P<:AbstractPolicy,T<:Trajectory,B<:OfflineBehavior} <: Abstr
 end
 
 OfflineAgent(; policy, trajectory, offline_behavior=OfflineBehavior()) = OfflineAgent(policy, trajectory, offline_behavior)
-@functor OfflineAgent (policy,)
+Flux.@layer OfflineAgent trainable=(policy,)
 
 Base.push!(::OfflineAgent{P,T,<:OfflineBehavior{Nothing}}, ::PreExperimentStage, env::AbstractEnv) where {P,T} = nothing
 #fills the trajectory with interactions generated with the behavior_agent at the PreExperimentStage.

src/ReinforcementLearningCore/src/policies/learners/abstract_learner.jl

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 export AbstractLearner, Approximator
 
 using Flux
-using Functors: @functor
 
 abstract type AbstractLearner end
 
src/ReinforcementLearningCore/src/policies/learners/approximator.jl

Lines changed: 1 addition & 3 deletions
@@ -36,9 +36,7 @@ end
 
 Approximator(model, optimiser::Flux.Optimise.AbstractOptimiser; use_gpu=false) = Approximator(model=model, optimiser=optimiser, use_gpu=use_gpu)
 
-Base.show(io::IO, m::MIME"text/plain", A::Approximator) = show(io, m, convert(AnnotatedStructTree, A))
-
-@functor Approximator (model,)
+Flux.@layer Approximator trainable=(model,)
 
 forward(A::Approximator, args...; kwargs...) = A.model(args...; kwargs...)
 forward(A::Approximator, env::E) where {E <: AbstractEnv} = env |> state |> (x -> forward(A, x))
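
One detail specific to this file: the hand-written `Base.show` method (built on `AnnotatedStructTree`) is deleted together with `@functor`, because `Flux.@layer` also generates `show` methods for the annotated struct, so an `Approximator` prints like any other Flux layer. A rough check of that behaviour with a stand-in type (`Wrapper` is hypothetical, not from the package):

using Flux

struct Wrapper{M}
    model::M
end
Flux.@layer Wrapper trainable=(model,)

w = Wrapper(Dense(2 => 3))
show(stdout, MIME"text/plain"(), w)   # prints via Flux's generated show method, not a raw struct dump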

src/ReinforcementLearningCore/src/policies/learners/target_network.jl

Lines changed: 2 additions & 5 deletions
@@ -1,7 +1,6 @@
 export Approximator, TargetNetwork, target, model
 
-using Flux: gpu
-
+using Flux
 
 target(ap::Approximator) = ap.model #see TargetNetwork
 model(ap::Approximator) = ap.model #see TargetNetwork
@@ -61,9 +60,7 @@ function TargetNetwork(network::Approximator; sync_freq = 1, ρ = 0f0, use_gpu =
 return TargetNetwork(network, target, sync_freq, ρ, 0)
 end
 
-@functor TargetNetwork (network, target)
-
-Flux.trainable(model::TargetNetwork) = (model.network,)
+Flux.@layer TargetNetwork trainable=(network,)
 
 forward(tn::TargetNetwork, args...) = forward(tn.network, args...)
 
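Here the keyword does double duty. Previously the file needed both `@functor TargetNetwork (network, target)` (so `gpu`/`fmap` reach both copies) and a separate `Flux.trainable` overload (so gradients only flow to the online network). `Flux.@layer TargetNetwork trainable=(network,)` expresses both at once: the network and target copies are still children, but only `network` is exposed to the optimiser, leaving `target` to be updated solely by the periodic sync. A rough sketch with a stand-in struct, assuming Flux 0.14 (`ToyTarget` is illustrative only):

using Flux

struct ToyTarget{N}
    network::N   # updated by gradient steps
    target::N    # updated only by copying/Polyak averaging, never by the optimiser
end
Flux.@layer ToyTarget trainable=(network,)

tt = ToyTarget(Dense(4 => 2), Dense(4 => 2))
Flux.trainable(tt)                       # only the network field is listed
opt_state = Flux.setup(Adam(0.001), tt)  # optimiser state is created for network only
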
src/ReinforcementLearningCore/src/policies/q_based_policy.jl

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 export QBasedPolicy
 
-using Functors: @functor
+using Flux
 
 """
 QBasedPolicy(;learner, explorer)
@@ -17,7 +17,7 @@ Base.@kwdef mutable struct QBasedPolicy{L,E} <: AbstractPolicy
 explorer::E
 end
 
-@functor QBasedPolicy (learner,)
+Flux.@layer QBasedPolicy trainable=(learner,)
 
 function RLBase.plan!(p::QBasedPolicy{L,Ex}, env::E) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv}
 RLBase.plan!(p.explorer, p.learner, env)

src/ReinforcementLearningCore/src/utils/networks.jl

Lines changed: 8 additions & 9 deletions
@@ -1,4 +1,3 @@
-using Functors: @functor
 import Flux
 import Flux.onehotbatch
 using ChainRulesCore: ignore_derivatives
@@ -18,7 +17,7 @@ Base.@kwdef struct ActorCritic{A,C,O}
 critic::C
 end
 
-@functor ActorCritic
+Flux.@layer ActorCritic
 
 #####
 # GaussianNetwork
@@ -53,7 +52,7 @@ end
 
 GaussianNetwork(pre, μ, σ; squash = identity) = GaussianNetwork(pre, μ, σ, 0.0f0, Inf32, squash)
 
-@functor GaussianNetwork
+Flux.@layer GaussianNetwork
 
 """
 This function is compatible with a multidimensional action space.
@@ -142,7 +141,7 @@ end
 
 SoftGaussianNetwork(pre, μ, σ) = SoftGaussianNetwork(pre, μ, σ, 0.0f0, Inf32)
 
-@functor SoftGaussianNetwork
+Flux.@layer SoftGaussianNetwork
 
 """
 This function is compatible with a multidimensional action space.
@@ -225,7 +224,7 @@ Base.@kwdef mutable struct CovGaussianNetwork{P,U,S}
 Σ::S
 end
 
-@functor CovGaussianNetwork
+Flux.@layer CovGaussianNetwork
 
 """
 (model::CovGaussianNetwork)(rng::AbstractRNG, state::AbstractArray{<:Any, 3}; is_sampling::Bool=false, is_return_log_prob::Bool=false)
@@ -407,7 +406,7 @@ mutable struct CategoricalNetwork{P}
 model::P
 end
 
-@functor CategoricalNetwork
+Flux.@layer CategoricalNetwork
 
 function (model::CategoricalNetwork)(rng::AbstractRNG, state::AbstractArray; is_sampling::Bool=false, is_return_log_prob::Bool = false)
 logits = model.model(state) #may be 1-3 dimensional
@@ -514,7 +513,7 @@ Base.@kwdef struct DuelingNetwork{B,V,A}
 adv::A
 end
 
-Flux.@functor DuelingNetwork
+Flux.@layer DuelingNetwork
 
 function (m::DuelingNetwork)(state)
 x = m.base(state)
@@ -544,7 +543,7 @@ Base.@kwdef struct PerturbationNetwork{N}
 ϕ::Float32 = 0.05f0
 end
 
-Flux.@functor PerturbationNetwork
+Flux.@layer PerturbationNetwork
 
 """
 This function accepts `state` and `action`, and then outputs actions after disturbance.
@@ -570,7 +569,7 @@ Base.@kwdef struct VAE{E,D}
 latent_dims::Int
 end
 
-Flux.@functor VAE
+Flux.@layer VAE
 
 function (model::VAE)(rng::AbstractRNG, state, action)
 μ, σ = model.encoder(vcat(state, action))
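
For the network types in this file there is no `trainable=` restriction: a bare `Flux.@layer ActorCritic` plays the role of the old bare `@functor ActorCritic`, i.e. every field is a child and all parameters inside are trainable. A compact usage sketch, assuming Flux 0.14 (`TinyActorCritic` and the loss are made up for illustration, not part of the package):

using Flux

struct TinyActorCritic{A,C}
    actor::A
    critic::C
end
Flux.@layer TinyActorCritic   # no trainable= keyword: both heads are trainable

m = TinyActorCritic(Dense(4 => 2), Dense(4 => 1))
x = rand(Float32, 4, 8)                  # batch of 8 four-dimensional states
opt_state = Flux.setup(Adam(0.001), m)   # covers actor and critic parameters
grads = Flux.gradient(m) do model
    sum(model.actor(x)) + sum(model.critic(x))   # toy loss touching both heads
end
Flux.update!(opt_state, m, grads[1])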

src/ReinforcementLearningCore/test/core/hooks.jl

Lines changed: 6 additions & 6 deletions
@@ -32,7 +32,7 @@ end
 
 function test_run!(hook::AbstractHook)
 hook_ = deepcopy(hook)
-run(RandomPolicy(), RandomWalk1D(), StopAfterNEpisodes(10), hook_)
+run(RandomPolicy(), RandomWalk1D(), StopAfterNEpisodes(100), hook_)
 return hook_
 end
 
@@ -49,7 +49,7 @@ end
 
 for h in (h_1, h_2, h_3, h_4, h_5)
 h_ = test_run!(h)
-@test length(h_.rewards) == 10
+@test length(h_.rewards) == 100
 @test sum(h_.rewards .== 1) > 0
 @test sum(h_.rewards .== -1) > 0
 
@@ -77,11 +77,11 @@ end
 h_1 = TimePerStep()
 h_2 = TimePerStep{Float32}()
 
-sleep_vect = [0.01, 0.02, 0.03]
+sleep_vect = [0.05, 0.05, 0.05]
 for h in (h_1, h_2)
 push!(h, PostActStage(), 1, 1)
 [(sleep(i); push!(h, PostActStage(), 1, 1)) for i in sleep_vect]
-@test all(0.1 .> h.times[2:end] .> 0)
+@test all(0.2 .> h.times[2:end] .> 0)
 test_noop!(h, stages=[PreActStage(), PreEpisodeStage(), PostEpisodeStage(), PreExperimentStage(), PostExperimentStage()])
 end
 end
@@ -115,8 +115,8 @@ end
 
 for h in (h_1, h_2, h_3)
 h_ = test_run!(h)
-@test length(h_.rewards) == 10
-@test sum(abs.(sum.(h_.rewards))) == 10
+@test length(h_.rewards) == 100
+@test sum(abs.(sum.(h_.rewards))) == 100
 @test length(unique(length.(h_.rewards))) > 1
 test_noop!(h, stages=[PreActStage(), PostEpisodeStage(), PreExperimentStage(), PostExperimentStage()])
 end

src/ReinforcementLearningCore/test/utils/networks.jl

Lines changed: 5 additions & 5 deletions
@@ -19,11 +19,11 @@ import ReinforcementLearningBase: RLBase
 @testset "NeuralNetworkApproximator" begin
 NN = NeuralNetworkApproximator(; model = Dense(2, 3), optimizer = Descent())
 
-q_values = NN(rand(2))
+q_values = NN(rand(Float32, 2))
 @test size(q_values) == (3,)
 
 gs = gradient(params(NN)) do
-sum(NN(rand(2, 5)))
+sum(NN(rand(Float32, 2, 5)))
 end
 
 old_params = deepcopy(collect(params(NN).params))
@@ -47,15 +47,15 @@ import ReinforcementLearningBase: RLBase
 D = ac.actor.model |> gpu |> device
 @test D === device(ac) === device(ac.actor) == device(ac.critic)
 
-A = send_to_device(D, rand(3))
+A = send_to_device(D, rand(Float32, 3))
 ac.actor(A)
 ac.critic(A)
 end=#
 
 @testset "GaussianNetwork" begin
 @testset "On CPU" begin
 gn = GaussianNetwork(Dense(20,15), Dense(15,10), Dense(15,10, softplus))
-state = rand(Float32,20,3) #batch of 3 states
+state = rand(Float32, 20, 3) #batch of 3 states
 @testset "Correctness of outputs" begin
 m, L = gn(state)
 @test size(m) == size(L) == (10,3)
@@ -115,7 +115,7 @@ import ReinforcementLearningBase: RLBase
 if (@isdefined CUDA) && CUDA.functional()
 CUDA.allowscalar(false)
 gn = GaussianNetwork(Dense(20,15), Dense(15,10), Dense(15,10, softplus)) |> gpu
-state = rand(20,3) |> gpu #batch of 3 states
+state = rand(Float32, 20,3) |> gpu #batch of 3 states
 @testset "Forward pass compatibility" begin
 @test Flux.params(gn) == Flux.Params([gn.pre.weight, gn.pre.bias, gn.μ.weight, gn.μ.bias, gn.σ.weight, gn.σ.bias])
 m, L = gn(state)
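
The remaining edits swap `rand(...)` for `rand(Float32, ...)` in the test fixtures: `rand(n)` yields `Float64`, while Flux layers are initialised with `Float32` parameters, so mixing the two either promotes the computation or triggers Flux's eltype-mismatch warning. Keeping the test inputs in `Float32` matches what the layers expect. A minimal illustration of the underlying behaviour, under Flux's defaults:

using Flux

layer = Dense(2 => 3)               # parameters are Float32 by default
eltype(layer(rand(Float32, 2)))     # Float32: input and weights agree
layer(rand(2))                      # Float64 input mismatches the Float32 weights;
                                    # Flux warns (and converts or promotes, depending on version)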
