Skip to content

Commit 7df4dc1

Browse files
authored
add JiDi evaluation support
add JiDi evaluation support
2 parents f52859d + 06917dc commit 7df4dc1

18 files changed

Lines changed: 1159 additions & 43 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,4 @@ opponent_pool
159159
!/examples/selfplay/opponent_templates/tictactoe_opponent/info.json
160160
wandb_run
161161
examples/dmc/new.gif
162+
/examples/snake/submissions/rl/actor_2000.pth

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ test:
1212
lint:
1313
$(call check_install, ruff)
1414
ruff ${PYTHON_FILES} --select=E9,F63,F7,F82 --show-source
15-
ruff ${PYTHON_FILES} --exit-zero | grep -v '501\|405\|401\|402\|403'
15+
ruff ${PYTHON_FILES} --exit-zero | grep -v '501\|405\|401\|402\|403\|722'
1616

1717
format:
1818
$(call check_install, isort)

examples/snake/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ This is the example for the snake game.
77
python train_selfplay.py
88
```
99

10+
## Evaluate JiDi submissions locally
11+
12+
```bash
13+
python jidi_eval.py
14+
```
1015

1116
## Submit to JiDi
1217

examples/snake/jidi_eval.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
from openrl.arena import make_arena
20+
from openrl.arena.agents.jidi_agent import JiDiAgent
21+
from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner
22+
23+
24+
def run_arena(
    render: bool = False,
    parallel: bool = True,
    seed=0,
    total_games: int = 10,
    max_game_onetime: int = 5,
):
    """Play the rule-based JiDi submission against the RL submission.

    A 3v3 snakes arena is created with the ``RecordWinner`` wrapper, the two
    submissions under ``./submissions`` are loaded as ``JiDiAgent`` instances,
    and the arena is run once.

    Args:
        render: Forwarded to ``make_arena``.
        parallel: Forwarded to ``arena.run``.
        seed: Forwarded to ``arena.reset``.
        total_games: Forwarded to ``arena.reset``.
        max_game_onetime: Forwarded to ``arena.reset``.

    Returns:
        Whatever ``arena.run`` returns; it is also printed.
    """
    player_num = 3
    arena = make_arena(
        f"snakes_{player_num}v{player_num}",
        env_wrappers=[RecordWinner],
        render=render,
    )

    # Everything after arena creation can fail (e.g. a submission whose
    # weights are missing), so make sure the arena is always closed.
    try:
        agent1 = JiDiAgent("./submissions/rule_v1", player_num=player_num)
        agent2 = JiDiAgent("./submissions/rl", player_num=player_num)

        arena.reset(
            agents={"agent1": agent1, "agent2": agent2},
            total_games=total_games,
            max_game_onetime=max_game_onetime,
            seed=seed,
        )
        result = arena.run(parallel=parallel)
    finally:
        arena.close()

    print(result)
    return result
51+
52+
53+
# Script entry point: evaluate the two bundled submissions over 100 games.
if __name__ == "__main__":
    run_arena(
        render=False,
        parallel=True,
        seed=0,
        total_games=100,
        max_game_onetime=5,
    )

examples/snake/submissions/random_agent/submission.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ def my_controller(observation, action_space, is_act_continuous):
2626
for i in range(len(action_space)):
2727
player = sample_single_dim(action_space[i], is_act_continuous)
2828
joint_action.append(player)
29+
2930
return joint_action
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Download actor weight
2+
3+
Please download [actor_2000.pth](https://github.com/CarlossShi/Competition_3v3snakes/tree/master/agent/rl) before using this code.
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import os
2+
import sys
3+
from pathlib import Path
4+
5+
import numpy as np
6+
import torch
7+
from torch import nn
8+
from torch.distributions import Categorical
9+
10+
# Width of the hidden layers used by the Actor network.
HIDDEN_SIZE = 256
# All tensors and modules in this submission live on the CPU.
device = torch.device("cpu")

from typing import Union

# An activation is given either by name (see the table below) or as a module.
Activation = Union[str, nn.Module]

# Lookup table from activation name to a ready-made module instance.
_str_to_activation = dict(
    relu=nn.ReLU(),
    tanh=nn.Tanh(),
    identity=nn.Identity(),
    softmax=nn.Softmax(dim=-1),
)
23+
24+
25+
def mlp(
    sizes, activation: Activation = "relu", output_activation: Activation = "identity"
):
    """Build a fully connected network as an ``nn.Sequential``.

    Each consecutive pair in ``sizes`` becomes one ``nn.Linear`` followed by
    an activation module. Every layer but the last uses ``activation``; the
    last uses ``output_activation``. String activations are resolved through
    ``_str_to_activation``.
    """
    hidden_act = (
        _str_to_activation[activation] if isinstance(activation, str) else activation
    )
    final_act = (
        _str_to_activation[output_activation]
        if isinstance(output_activation, str)
        else output_activation
    )

    last_layer = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        modules.append(final_act if idx >= last_layer else hidden_act)
    return nn.Sequential(*modules)
38+
39+
40+
def get_surrounding(state, width, height, x, y):
    """Return the cells above, below, left and right of ``(x, y)``.

    ``state`` is indexed as ``state[row][column]``; neighbours that fall off
    an edge wrap around using modulo ``height`` / ``width``. The result is the
    list ``[up, down, left, right]``.
    """
    neighbours = []
    # Vertical neighbours first (up, then down) ...
    for dy in (-1, 1):
        neighbours.append(state[(y + dy) % height][x])
    # ... then horizontal neighbours (left, then right).
    for dx in (-1, 1):
        neighbours.append(state[y][(x + dx) % width])
    return neighbours
49+
50+
51+
def make_grid_map(
    board_width, board_height, beans_positions: list, snakes_positions: dict
):
    """Rasterise snakes and beans onto a ``board_height`` x ``board_width`` grid.

    Every cell is a one-element list. Cells start at ``0``, cells covered by
    a snake hold that snake's key from ``snakes_positions``, and cells holding
    a bean are set to ``1`` (beans are written last, so they win on overlap).
    Positions are indexed as ``grid[pos[0]][pos[1]]``.
    """
    grid = []
    for _row in range(board_height):
        grid.append([[0] for _col in range(board_width)])

    for snake_id in snakes_positions:
        for segment in snakes_positions[snake_id]:
            grid[segment[0]][segment[1]][0] = snake_id

    for bean in beans_positions:
        grid[bean[0]][bean[1]][0] = 1

    return grid
63+
64+
65+
# Self position: 0:head_x; 1:head_y
66+
# Head surroundings: 2:head_up; 3:head_down; 4:head_left; 5:head_right
67+
# Beans positions: (6, 7) (8, 9) (10, 11) (12, 13) (14, 15)
68+
# Other snake positions: (16, 17) (18, 19) (20, 21) (22, 23) (24, 25) -- absolute head coordinates (not offsets from self)
69+
def get_observations(state, agents_index, obs_dim, height, width):
    """Build one feature row per controlled snake.

    Args:
        state: Mapping that provides ``"board_width"``, ``"board_height"``,
            the bean positions under key ``1`` and one entry per snake under
            the integer keys ``2`` .. ``7``.
        agents_index: Positions (0-based, in ascending snake-key order) of the
            snakes for which a row is produced.
        obs_dim: Length of each row. The fixed slices below fill columns
            0-15 and then ``16:``, so the bean and head data must fit them
            (presumably 5 beans and 6 snakes for ``obs_dim == 26`` -- confirm
            against the environment).
        height: Board height used to wrap the head's vertical neighbours.
        width: Board width used to wrap the head's horizontal neighbours.

    Returns:
        A float ``numpy`` array of shape ``(len(agents_index), obs_dim)``.
        Per row: ``[0:2]`` first segment of the snake as stored, ``[2:6]``
        grid values up/down/left/right of it, ``[6:16]`` flattened bean
        positions, ``[16:]`` flattened first segments of the remaining snakes.
    """
    board_width = state["board_width"]
    board_height = state["board_height"]
    beans_positions = state[1]

    # Walk the snake keys in explicit ascending order. The indices in
    # ``agents_index`` are positional, so the order must not depend on how a
    # set happens to iterate.
    snakes_positions = {key: state[key] for key in range(2, 8) if key in state}
    snakes_positions_list = list(snakes_positions.values())

    snake_map = make_grid_map(
        board_width, board_height, beans_positions, snakes_positions
    )
    # Drop the trailing singleton axis so the grid is indexed [row][column].
    grid = np.squeeze(np.array(snake_map), axis=2)

    # One row per requested agent instead of a hard-coded three.
    observations = np.zeros((len(agents_index), obs_dim))
    beans_flat = np.array(beans_positions, dtype=object).flatten()
    # First segment of every snake; identical for all rows, so build it once.
    all_heads = np.array([snake[0] for snake in snakes_positions_list])

    for i, element in enumerate(agents_index):
        head = snakes_positions_list[element][0]

        # self head position
        observations[i][:2] = head[:]

        # head surroundings (head[1] is the column, head[0] the row)
        observations[i][2:6] = get_surrounding(grid, width, height, head[1], head[0])

        # beans positions
        observations[i][6:16] = beans_flat[:]

        # other snake positions
        # NOTE(review): row ``i`` is removed, not row ``element``. The two
        # coincide only when ``agents_index`` starts at 0 -- confirm whether
        # that is intended for the second team before changing it.
        observations[i][16:] = np.delete(all_heads, i, 0).flatten()[:]
    return observations
108+
109+
110+
class Actor(nn.Module):
    """Policy network made of two stacked MLPs.

    ``prev_dense`` is a single linear layer ``obs_dim -> HIDDEN_SIZE`` (being
    the only layer of its ``mlp`` call, it is followed by that call's default
    output activation rather than the hidden one). ``post_dense`` maps
    ``HIDDEN_SIZE -> HIDDEN_SIZE -> act_dim`` and ends in
    ``output_activation``. ``args`` and ``num_agents`` are only stored.
    """

    def __init__(self, obs_dim, act_dim, num_agents, args, output_activation="softmax"):
        super().__init__()

        self.obs_dim, self.act_dim, self.num_agents = obs_dim, act_dim, num_agents
        self.args = args

        # Attribute names are part of the checkpoint's state_dict keys.
        self.prev_dense = mlp([obs_dim, HIDDEN_SIZE])
        self.post_dense = mlp(
            [HIDDEN_SIZE, HIDDEN_SIZE, act_dim], output_activation=output_activation
        )

    def forward(self, obs_batch):
        """Run ``obs_batch`` through both MLPs and return the result."""
        return self.post_dense(self.prev_dense(obs_batch))
130+
131+
132+
class RLAgent(object):
    """Inference wrapper around an ``Actor`` network.

    The agent owns one ``Actor`` placed on the module-level ``device`` and
    turns observation batches into the joint-action format produced by
    ``to_joint_action``.
    """

    def __init__(self, obs_dim, act_dim, num_agent):
        """Create the actor network.

        Args:
            obs_dim: Size of one observation row.
            act_dim: Number of outputs of the actor.
            num_agent: Stored and forwarded to ``Actor`` as ``num_agents``.
        """
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.num_agent = num_agent
        self.device = device
        self.output_activation = "softmax"
        # Actor's fourth positional parameter is ``args``; the activation has
        # to be passed by keyword or it silently lands in the wrong slot.
        self.actor = Actor(
            obs_dim,
            act_dim,
            num_agent,
            None,
            output_activation=self.output_activation,
        ).to(self.device)

    def choose_action(self, obs):
        """Run the actor on ``obs`` and return its output as a numpy array.

        ``obs`` is converted to a float32 tensor with a leading batch axis of
        size one; that axis is stripped again from the result.
        """
        # Go through one ndarray instead of a Python list of arrays, which
        # torch converts element by element.
        batch = torch.as_tensor(np.asarray(obs), dtype=torch.float32).unsqueeze(0)
        batch = batch.to(self.device)
        with torch.no_grad():
            logits = self.actor(batch).cpu().detach().numpy()[0]
        return logits

    def select_action_to_env(self, obs, ctrl_index):
        """Sample actions for ``obs`` and encode the one at ``ctrl_index``.

        Returns:
            The value of ``to_joint_action`` for the sampled actions: a list
            holding a single one-hot list.
        """
        logits = self.choose_action(obs)
        actions = logits2action(logits)
        action_to_env = to_joint_action(actions, ctrl_index)
        return action_to_env

    def load_model(self, filename):
        """Load actor weights from ``filename`` into ``self.actor``.

        Tensors are mapped onto ``self.device`` so that a checkpoint saved on
        a GPU can still be loaded on a CPU-only machine.
        """
        self.actor.load_state_dict(torch.load(filename, map_location=self.device))
156+
157+
158+
def to_joint_action(action, ctrl_index):
    """One-hot encode ``action[ctrl_index]`` over four slots.

    Returns a list containing that single one-hot list.
    """
    chosen = action[ctrl_index]
    one_hot = [0, 0, 0, 0]
    one_hot[chosen] = 1
    return [one_hot]
165+
166+
167+
def logits2action(logits):
    """Sample one action index per row of ``logits``.

    Each row is handed to ``Categorical`` as its first positional argument
    and sampled once; the samples are returned as a numpy array.
    """
    rows = torch.Tensor(logits).to(device)
    sampled = [Categorical(row).sample().item() for row in rows]
    return np.array(sampled)
171+
172+
173+
# Module-level agent shared by every my_controller call:
# obs_dim=26, act_dim=4, num_agent=3.
agent = RLAgent(26, 4, 3)
# The weights file is expected to sit next to this source file.
actor_net = os.path.dirname(os.path.abspath(__file__)) + "/actor_2000.pth"
# Fail at import time with a download hint if the weights are missing.
assert Path(actor_net).exists(), (
    "actor_2000.pth not exists, please download from:"
    " https://github.com/CarlossShi/Competition_3v3snakes/tree/master/agent/rl"
)
agent.load_model(actor_net)
180+
181+
182+
def my_controller(observation_list, action_space_list, is_act_continuous):
    """Pick the action for the snake named in the observation.

    ``observation_list`` is a mapping holding the board size and
    ``"controlled_snake_index"`` (2..7). The controlled snake's team is the
    three consecutive 0-based indices starting at 3 when the index is above
    4 and at 0 otherwise. Features are built for the whole team, and the
    action of the controlled snake is returned in joint-action form.
    ``action_space_list`` and ``is_act_continuous`` are not used.
    """
    obs_dim = 26
    obs = observation_list.copy()
    board_width = obs["board_width"]
    board_height = obs["board_height"]
    # 2, 3, 4, 5, 6, 7 -> indexs = [0,1,2,3,4,5]
    controlled = obs["controlled_snake_index"]

    team_start = 3 if controlled > 4 else 0
    team = [team_start + offset for offset in range(3)]

    features = get_observations(
        obs, team, obs_dim, height=board_height, width=board_width
    )
    return agent.select_action_to_env(features, team.index(controlled - 2))

0 commit comments

Comments
 (0)