Skip to content

Commit 537f822

Browse files
committed
arena add test more envs
1 parent 0228e51 commit 537f822

8 files changed

Lines changed: 264 additions & 11 deletions

File tree

examples/arena/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
```bash
55
pip install "openrl[selfplay]"
6+
pip install "pettingzoo[mpe]","pettingzoo[butterfly]"
67
```
78

89
### Usage
@@ -15,3 +16,11 @@ python run_arena.py
1516
### Evaluate Google Research Football submissions for JiDi locally
1617

1718
If you want to evaluate your Google Research Football submissions for JiDi locally, please try to use tizero as illustrated [here](foothttps://github.com/OpenRL-Lab/TiZero#evaluate-jidi-submissions-locally).
19+
20+
### Evaluate more environments
21+
22+
We also provide a script to evaluate more environments, including MPE, Go, Texas Holdem, Butterfly. You can run the script as follows:
23+
24+
```shell
25+
python evaluate_more_envs.py
26+
```
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
#!/usr/bin/env python
20+
# -*- coding: utf-8 -*-
21+
# Copyright 2023 The OpenRL Authors.
22+
#
23+
# Licensed under the Apache License, Version 2.0 (the "License");
24+
# you may not use this file except in compliance with the License.
25+
# You may obtain a copy of the License at
26+
#
27+
# https://www.apache.org/licenses/LICENSE-2.0
28+
#
29+
# Unless required by applicable law or agreed to in writing, software
30+
# distributed under the License is distributed on an "AS IS" BASIS,
31+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32+
# See the License for the specific language governing permissions and
33+
# limitations under the License.
34+
35+
""""""
36+
37+
from pettingzoo.butterfly import cooperative_pong_v5
38+
from pettingzoo.classic import connect_four_v3, go_v5, texas_holdem_no_limit_v6
39+
from pettingzoo.mpe import simple_push_v3
40+
41+
from examples.custom_env.rock_paper_scissors import RockPaperScissors
42+
from openrl.arena import make_arena
43+
from openrl.arena.agents.local_agent import LocalAgent
44+
from openrl.envs.PettingZoo.registration import register
45+
from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner
46+
47+
48+
def ConnectFourEnv(render_mode, **kwargs):
49+
return connect_four_v3.env(render_mode)
50+
51+
52+
def RockPaperScissorsEnv(render_mode, **kwargs):
53+
return RockPaperScissors(render_mode)
54+
55+
56+
def GoEnv(render_mode, **kwargs):
57+
return go_v5.env(render_mode=render_mode, board_size=5, komi=7.5)
58+
59+
60+
def TexasHoldemEnv(render_mode, **kwargs):
61+
return texas_holdem_no_limit_v6.env(render_mode=render_mode)
62+
63+
64+
# MPE
65+
def SimplePushEnv(render_mode, **kwargs):
66+
return simple_push_v3.env(render_mode=render_mode)
67+
68+
69+
def CooperativePongEnv(render_mode, **kwargs):
70+
return cooperative_pong_v5.env(render_mode=render_mode)
71+
72+
73+
def register_new_envs():
74+
new_env_dict = {
75+
"connect_four_v3": ConnectFourEnv,
76+
"RockPaperScissors": RockPaperScissorsEnv,
77+
"go_v5": GoEnv,
78+
"texas_holdem_no_limit_v6": TexasHoldemEnv,
79+
"simple_push_v3": SimplePushEnv,
80+
"cooperative_pong_v5": CooperativePongEnv,
81+
}
82+
83+
for env_id, env in new_env_dict.items():
84+
register(env_id, env)
85+
return new_env_dict.keys()
86+
87+
88+
def run_arena(
89+
env_id: str,
90+
parallel: bool = True,
91+
seed=0,
92+
total_games: int = 10,
93+
max_game_onetime: int = 5,
94+
):
95+
env_wrappers = [RecordWinner]
96+
97+
arena = make_arena(env_id, env_wrappers=env_wrappers, use_tqdm=False)
98+
99+
agent1 = LocalAgent("../selfplay/opponent_templates/random_opponent")
100+
agent2 = LocalAgent("../selfplay/opponent_templates/random_opponent")
101+
102+
arena.reset(
103+
agents={"agent1": agent1, "agent2": agent2},
104+
total_games=total_games,
105+
max_game_onetime=max_game_onetime,
106+
seed=seed,
107+
)
108+
result = arena.run(parallel=parallel)
109+
arena.close()
110+
print(result)
111+
return result
112+
113+
114+
def test_new_envs():
115+
env_ids = register_new_envs()
116+
seed = 0
117+
for env_id in env_ids:
118+
run_arena(env_id=env_id, seed=seed, parallel=False, total_games=1)
119+
120+
121+
if __name__ == "__main__":
122+
test_new_envs()

examples/custom_env/rock_paper_scissors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def step(self, action):
182182
# handles stepping an agent which is already dead
183183
# accepts a None action for the one agent, and moves the agent_selection to
184184
# the next dead agent, or if there are no more dead agents, to the next live agent
185+
action = None
185186
self._was_dead_step(action)
186187
return
187188

openrl/arena/games/two_player_game.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ def default_dispatch_func(
3131
players: List[str],
3232
agent_names: List[str],
3333
) -> Dict[str, str]:
34-
assert len(players) == len(
35-
agent_names
36-
), "The number of players must be equal to the number of agents."
34+
assert len(players) == len(agent_names), (
35+
f"The number of players {len(players)} must be equal to the number of"
36+
f" agents: {len(agent_names)}."
37+
)
3738
assert len(players) == 2, "The number of players must be equal to 2."
3839
np_random.shuffle(agent_names)
3940
return dict(zip(players, agent_names))
@@ -49,20 +50,21 @@ def _run(self, env_fn: Callable, agents: List[BaseAgent]):
4950
for player, agent in player2agent.items():
5051
agent.reset(env, player)
5152
result = {}
53+
truncation_dict = {}
5254
while True:
5355
termination = False
5456
info = {}
5557
for player_name in env.agent_iter():
5658
observation, reward, termination, truncation, info = env.last()
57-
58-
if termination:
59+
truncation_dict[player_name] = truncation
60+
if termination or all(truncation_dict.values()):
5961
break
6062
action = player2agent[player_name].act(
6163
player_name, observation, reward, termination, truncation, info
6264
)
6365
env.step(action)
6466

65-
if termination:
67+
if termination or all(truncation_dict.values()):
6668
assert "winners" in info, "The game is terminated but no winners."
6769
assert "losers" in info, "The game is terminated but no losers."
6870

openrl/envs/wrappers/pettingzoo_wrappers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,9 @@ def last(self, observe: bool = True):
9696

9797
winners = None
9898
losers = None
99+
99100
for agent in self.terminations:
100-
if self.terminations[agent]:
101+
if self.terminations[agent] or all(self.truncations):
101102
if winners is None:
102103
winners = self.get_winners()
103104
losers = [player for player in self.agents if player not in winners]

openrl/selfplay/opponents/random_opponent.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,20 @@ def _sample_random_action(
4747
action = []
4848

4949
for obs, space in zip(observation, action_space):
50-
mask = obs.get("action_mask", None)
51-
action.append(space.sample(mask))
50+
if termination or truncation:
51+
action.append(None)
52+
else:
53+
if isinstance(obs, dict):
54+
mask = obs.get("action_mask", None)
55+
else:
56+
mask = None
57+
action.append(space.sample(mask))
5258
else:
53-
mask = observation.get("action_mask", None)
54-
action = action_space.sample(mask)
59+
if termination or truncation:
60+
action = None
61+
else:
62+
mask = observation.get("action_mask", None)
63+
action = action_space.sample(mask)
5564
return action
5665

5766
def _load(self, opponent_path: Union[str, Path]):

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,13 @@ def get_extra_requires() -> dict:
7171
"evaluate",
7272
],
7373
"selfplay": ["ray[default]", "ray[serve]", "pettingzoo[classic]", "trueskill"],
74+
"selfplay_test": ["pettingzoo[mpe]", "pettingzoo[butterfly]"],
7475
"retro": ["gym-retro"],
7576
"super_mario": ["gym-super-mario-bros"],
7677
"atari": ["gymnasium[atari]", "gymnasium[accept-rom-license]"],
7778
}
7879
req["test"].extend(req["selfplay"])
80+
req["test"].extend(req["selfplay_test"])
7981
req["test"].extend(req["atari"])
8082
req["test"].extend(req["nlp_test"])
8183
return req

tests/test_arena/test_new_envs.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
import os
19+
import sys
20+
21+
import pytest
22+
from pettingzoo.butterfly import cooperative_pong_v5
23+
from pettingzoo.classic import connect_four_v3, go_v5, texas_holdem_no_limit_v6
24+
from pettingzoo.mpe import simple_push_v3
25+
26+
from examples.custom_env.rock_paper_scissors import RockPaperScissors
27+
from openrl.arena import make_arena
28+
from openrl.arena.agents.local_agent import LocalAgent
29+
from openrl.envs.PettingZoo.registration import register
30+
from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner
31+
32+
33+
def ConnectFourEnv(render_mode, **kwargs):
34+
return connect_four_v3.env(render_mode)
35+
36+
37+
def RockPaperScissorsEnv(render_mode, **kwargs):
38+
return RockPaperScissors(render_mode)
39+
40+
41+
def GoEnv(render_mode, **kwargs):
42+
return go_v5.env(render_mode=render_mode, board_size=5, komi=7.5)
43+
44+
45+
def TexasHoldemEnv(render_mode, **kwargs):
46+
return texas_holdem_no_limit_v6.env(render_mode=render_mode)
47+
48+
49+
# MPE
50+
def SimplePushEnv(render_mode, **kwargs):
51+
return simple_push_v3.env(render_mode=render_mode)
52+
53+
54+
def CooperativePongEnv(render_mode, **kwargs):
55+
return cooperative_pong_v5.env(render_mode=render_mode)
56+
57+
58+
def register_new_envs():
59+
new_env_dict = {
60+
"connect_four_v3": ConnectFourEnv,
61+
"RockPaperScissors": RockPaperScissorsEnv,
62+
"go_v5": GoEnv,
63+
"texas_holdem_no_limit_v6": TexasHoldemEnv,
64+
"simple_push_v3": SimplePushEnv,
65+
"cooperative_pong_v5": CooperativePongEnv,
66+
}
67+
68+
for env_id, env in new_env_dict.items():
69+
register(env_id, env)
70+
return new_env_dict.keys()
71+
72+
73+
def run_arena(
74+
env_id: str,
75+
parallel: bool = True,
76+
seed=0,
77+
total_games: int = 10,
78+
max_game_onetime: int = 5,
79+
):
80+
env_wrappers = [RecordWinner]
81+
82+
arena = make_arena(env_id, env_wrappers=env_wrappers, use_tqdm=False)
83+
84+
agent1 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent")
85+
agent2 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent")
86+
87+
arena.reset(
88+
agents={"agent1": agent1, "agent2": agent2},
89+
total_games=total_games,
90+
max_game_onetime=max_game_onetime,
91+
seed=seed,
92+
)
93+
result = arena.run(parallel=parallel)
94+
arena.close()
95+
return result
96+
97+
98+
@pytest.mark.unittest
99+
def test_new_envs():
100+
env_ids = register_new_envs()
101+
seed = 0
102+
for env_id in env_ids:
103+
run_arena(env_id=env_id, seed=seed, parallel=False, total_games=1)
104+
105+
106+
if __name__ == "__main__":
107+
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))

0 commit comments

Comments
 (0)