1- using Functors: @functor
21import Flux
32import Flux. onehotbatch
43using ChainRulesCore: ignore_derivatives
@@ -18,7 +17,7 @@ Base.@kwdef struct ActorCritic{A,C,O}
1817 critic:: C
1918end
2019
21- @functor ActorCritic
20+ Flux . @layer ActorCritic
2221
2322# ####
2423# GaussianNetwork
5352
5453GaussianNetwork (pre, μ, σ; squash = identity) = GaussianNetwork (pre, μ, σ, 0.0f0 , Inf32 , squash)
5554
56- @functor GaussianNetwork
55+ Flux . @layer GaussianNetwork
5756
5857"""
5958This function is compatible with a multidimensional action space.
142141
143142SoftGaussianNetwork (pre, μ, σ) = SoftGaussianNetwork (pre, μ, σ, 0.0f0 , Inf32 )
144143
145- @functor SoftGaussianNetwork
144+ Flux . @layer SoftGaussianNetwork
146145
147146"""
148147This function is compatible with a multidimensional action space.
@@ -225,7 +224,7 @@ Base.@kwdef mutable struct CovGaussianNetwork{P,U,S}
225224 Σ:: S
226225end
227226
228- @functor CovGaussianNetwork
227+ Flux . @layer CovGaussianNetwork
229228
230229"""
231230 (model::CovGaussianNetwork)(rng::AbstractRNG, state::AbstractArray{<:Any, 3}; is_sampling::Bool=false, is_return_log_prob::Bool=false)
@@ -407,7 +406,7 @@ mutable struct CategoricalNetwork{P}
407406 model:: P
408407end
409408
410- @functor CategoricalNetwork
409+ Flux . @layer CategoricalNetwork
411410
412411function (model:: CategoricalNetwork )(rng:: AbstractRNG , state:: AbstractArray ; is_sampling:: Bool = false , is_return_log_prob:: Bool = false )
413412 logits = model. model (state) # may be 1-3 dimensional
@@ -514,7 +513,7 @@ Base.@kwdef struct DuelingNetwork{B,V,A}
514513 adv:: A
515514end
516515
517- Flux. @functor DuelingNetwork
516+ Flux. @layer DuelingNetwork
518517
519518function (m:: DuelingNetwork )(state)
520519 x = m. base (state)
@@ -544,7 +543,7 @@ Base.@kwdef struct PerturbationNetwork{N}
544543 ϕ:: Float32 = 0.05f0
545544end
546545
547- Flux. @functor PerturbationNetwork
546+ Flux. @layer PerturbationNetwork
548547
549548"""
550549This function accepts `state` and `action`, and then outputs actions after disturbance.
@@ -570,7 +569,7 @@ Base.@kwdef struct VAE{E,D}
570569 latent_dims:: Int
571570end
572571
573- Flux. @functor VAE
572+ Flux. @layer VAE
574573
575574function (model:: VAE )(rng:: AbstractRNG , state, action)
576575 μ, σ = model. encoder (vcat (state, action))
0 commit comments