Skip to content

Commit b735781

Browse files
authored
Merge pull request #253 from huangshiyu13/main
add test atari
2 parents 85e7803 + ca32597 commit b735781

1 file changed

Lines changed: 74 additions & 0 deletions

File tree

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
""""""
2+
3+
import os
4+
import sys
5+
6+
import numpy as np
7+
import pytest
8+
9+
from openrl.configs.config import create_config_parser
10+
from openrl.envs.common import make
11+
from openrl.envs.wrappers.atari_wrappers import (
12+
ClipRewardEnv,
13+
FireResetEnv,
14+
NoopResetEnv,
15+
WarpFrame,
16+
)
17+
from openrl.envs.wrappers.image_wrappers import TransposeImage
18+
from openrl.envs.wrappers.monitor import Monitor
19+
from openrl.modules.common import PPONet as Net
20+
from openrl.runners.common import PPOAgent as Agent
21+
22+
env_wrappers = [
23+
Monitor,
24+
NoopResetEnv,
25+
FireResetEnv,
26+
WarpFrame,
27+
ClipRewardEnv,
28+
TransposeImage,
29+
]
30+
31+
32+
@pytest.fixture(
33+
scope="module",
34+
params=[
35+
"--episode_length 5 --use_recurrent_policy false --vec_info_class.id"
36+
" EPS_RewardInfo --use_valuenorm true --use_adv_normalize true"
37+
" --use_share_model True --entropy_coef 0.01"
38+
],
39+
)
40+
def config(request):
41+
cfg_parser = create_config_parser()
42+
cfg = cfg_parser.parse_args(request.param.split())
43+
return cfg
44+
45+
46+
@pytest.mark.unittest
47+
def test_train_atari(config):
48+
env_num = 2
49+
env = make(
50+
"ALE/Pong-v5",
51+
env_num=env_num,
52+
cfg=config,
53+
asynchronous=True,
54+
env_wrappers=env_wrappers,
55+
)
56+
net = Net(env, cfg=config)
57+
agent = Agent(net)
58+
agent.train(total_time_steps=30)
59+
agent.save("./ppo_agent/")
60+
agent.load("./ppo_agent/")
61+
agent.set_env(env)
62+
obs, info = env.reset(seed=0)
63+
step = 0
64+
while step < 5:
65+
action, _ = agent.act(obs, deterministic=True)
66+
obs, r, done, info = env.step(action)
67+
if np.any(done):
68+
break
69+
step += 1
70+
env.close()
71+
72+
73+
if __name__ == "__main__":
74+
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))

0 commit comments

Comments
 (0)