Skip to content

Commit 2e1af7f

Browse files
Fix PPO per #1007 (#1013)
1 parent 98ebba3 commit 2e1af7f

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

  • src/ReinforcementLearningZoo/src/algorithms/policy_gradient

src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ end
158158

159159
function RLBase.prob(p::PPOPolicy{<:ActorCritic,Categorical}, state::AbstractArray, mask)
160160
logits = p.approximator.actor(send_to_device(device(p.approximator), state))
161+
mask = send_to_device(device(p.approximator), mask)
162+
161163
if !isnothing(mask)
162164
logits .+= ifelse.(mask, 0.0f0, typemin(Float32))
163165
end

0 commit comments

Comments
 (0)