
Commit aa3a017

fixed some bugs

1 parent 22020a1 commit aa3a017

23 files changed: 203 additions & 63 deletions


examples/selfplay/human_vs_agent.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def get_human_env(env_num):
     env = make(
         "tictactoe_v3",
         env_num=env_num,
-        asynchronous=True,
+        asynchronous=False,
         opponent_wrappers=[TictactoeRender, HumanOpponentWrapper],
         env_wrappers=[FlattenObservation],
         auto_reset=False,
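
Note on this change: with asynchronous=True, a vectorized environment typically steps each copy in a separate worker process, where a wrapper that waits on human input (HumanOpponentWrapper) cannot interact with the user's console or window; asynchronous=False keeps everything in the calling process. A rough illustration using Gymnasium's vector envs, which follow the same sync/async split (analogous machinery, not OpenRL's own implementation):

import gymnasium as gym

# AsyncVectorEnv steps each sub-env in a worker process, so anything that
# blocks on stdin or a local window would block inside the worker.
async_env = gym.vector.AsyncVectorEnv([lambda: gym.make("CartPole-v1")] * 2)

# SyncVectorEnv steps sub-envs in the calling process, so interactive
# wrappers can prompt the human player directly.
sync_env = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")] * 2)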

examples/selfplay/opponent_templates/tictactoe_opponent/opponent.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def process_obs(self, observation, termination, truncation, info):
         return new_obs, termination, truncation, new_info
 
     def process_action(self, action):
-        return action[0][0][0]
+        return action[0][0]
 
 
 class Opponent(NetworkOpponent):
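
The indexing change above suggests the agent now hands process_action an action nested one level less than before. A small shape sketch of what the old and new indexing return (the exact batch layout is an assumption inferred from this diff):

import numpy as np

# Hypothetical action batch of shape (env_num, agent_num, action_dim) = (1, 1, 1)
action = np.array([[[4]]])

action[0][0]     # array([4])  -- new code: the per-agent action vector
action[0][0][0]  # 4           -- old code: the bare scalar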

examples/selfplay/tictactoe_utils/tictactoe_render.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
 
     def step(self, action: ActionType) -> None:
         result = super().step(action)
-        self.last_action = action
+        self.last_action = action[0]
         return result
 
     def observe(self, agent: str) -> Optional[ObsType]:

examples/snake/jidi_eval.py

Lines changed: 5 additions & 2 deletions
@@ -36,7 +36,10 @@ def run_arena(
     )
 
     agent1 = JiDiAgent("./submissions/rule_v1", player_num=player_num)
-    agent2 = JiDiAgent("./submissions/rl", player_num=player_num)
+    if player_num == 3:
+        agent2 = JiDiAgent("./submissions/rl", player_num=player_num)
+    else:
+        agent2 = JiDiAgent("./submissions/rule_v1", player_num=player_num)
 
     arena.reset(
         agents={"agent1": agent1, "agent2": agent2},
@@ -51,4 +54,4 @@ def run_arena(
 
 
 if __name__ == "__main__":
-    run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=5)
+    run_arena(render=False, parallel=True, seed=0, total_games=10, max_game_onetime=5)

examples/snake/train_selfplay.py

Lines changed: 5 additions & 0 deletions
@@ -15,7 +15,10 @@ def train():
 
     # Create environment
     env_num = 10
+
     render_model = None
+
+    # ConvertObs can only be used for snakes_1v1; to train snakes_3v3, you need to write your own wrapper
     env = make(
         "snakes_1v1",
         render_mode=render_model,
@@ -32,6 +35,7 @@ def train():
     agent = Agent(net)
     # Begin training
     agent.train(total_time_steps=100000)
+
     env.close()
     agent.save("./selfplay_agent/")
     return agent
@@ -71,6 +75,7 @@ def evaluation():
     while not np.any(done):
         # predict next action based on the observation
         action, _ = agent.act(obs, info, deterministic=True)
+
         obs, r, done, info = env.step(action)
         step += 1
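
The new comment notes that ConvertObs only supports snakes_1v1. For orientation, a hypothetical sketch of what a snakes_3v3 encoder could look like, assuming the same raw-observation layout as the 1v1 raw2vec in examples/snake/wrappers.py (beans under key 1, the six snakes under keys 2-7; key layout and coordinate convention are unverified assumptions, so check against the actual env):

import numpy as np


def raw2vec_3v3(raw_obs):
    """Hypothetical 3v3 analogue of raw2vec; verify key layout against the env."""
    width = raw_obs["board_width"]
    height = raw_obs["board_height"]

    board = np.zeros(width * height, dtype=int)
    for snake_id in range(2, 8):  # six snakes under keys 2..7 (assumed)
        for x, y in raw_obs[snake_id]:
            board[y * width + x] = snake_id - 1
    for x, y in raw_obs[1]:  # beans
        board[y * width + x] = 7

    # one-hot each cell into 8 classes: empty, six snakes, bean
    return np.eye(8)[board].reshape(-1)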

examples/snake/wrappers.py

Lines changed: 12 additions & 8 deletions
@@ -23,14 +23,14 @@
 
 
 def raw2vec(raw_obs, n_player=2):
-    control_index = raw_obs["controlled_snake_index"][0]
+    control_index = raw_obs["controlled_snake_index"]
 
-    width = raw_obs["board_width"][0]
-    height = raw_obs["board_height"][0]
-    beans = raw_obs[1][0]
+    width = raw_obs["board_width"]
+    height = raw_obs["board_height"]
+    beans = raw_obs[1]
 
-    ally_pos = raw_obs[control_index][0]
-    enemy_pos = raw_obs[5 - control_index][0]
+    ally_pos = raw_obs[control_index]
+    enemy_pos = raw_obs[5 - control_index]
 
     obs = np.zeros(width * height * n_player, dtype=int)
 
@@ -59,7 +59,7 @@ def raw2vec(raw_obs, n_player=2):
     obs_ = np.array([])
     for i in obs:
         obs_ = np.concatenate([obs_, np.eye(6)[i]])
-    obs_ = obs_.reshape(-1, width * height * n_player * 6)
+    obs_ = obs_.reshape(width * height * n_player * 6)
 
     return obs_
 
@@ -87,4 +87,8 @@ def observation(self, observation):
             The flattened observation
         """
 
-        return raw2vec(observation)
+        new_obs = []
+        for obs in observation:
+            new_obs.append(raw2vec(obs))
+
+        return new_obs
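
Side note on the np.eye(6)[i] idiom in raw2vec above: indexing an identity matrix with an integer (or an integer array) is the standard NumPy one-hot trick, and indexing with the whole array at once is equivalent to the per-element loop while avoiding repeated concatenation:

import numpy as np

obs = np.array([0, 2, 5, 1])  # cell values in 0..5

# Vectorized one-hot: row i of the result is the one-hot encoding of obs[i]
flat = np.eye(6)[obs].reshape(-1)

# Equivalent to the loop in raw2vec
loop = np.concatenate([np.eye(6)[i] for i in obs])
assert np.array_equal(flat, loop)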

openrl/algorithms/dqn.py

Lines changed: 1 addition & 3 deletions
@@ -167,9 +167,7 @@ def prepare_loss(
         )
 
         q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
-        q_loss = torch.mean(
-            F.mse_loss(q_values, q_targets.detach())
-        )  # mean squared error loss function
+        q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach()))  # mean squared error loss function
 
         loss_list.append(q_loss)
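
For context, the reformatted line above is the standard DQN semi-gradient update: the TD target r + γ · max_a' Q(s', a'), masked to zero at terminal states, is detached so gradients flow only through the online Q-values. A self-contained sketch with toy tensors (generic PyTorch, not OpenRL's exact batch layout):

import torch
import torch.nn.functional as F

gamma = 0.99
q_values = torch.tensor([[1.0], [2.0]], requires_grad=True)  # Q(s, a) of taken actions
max_next_q_values = torch.tensor([[1.5], [0.5]])             # max_a' Q(s', a')
rewards_batch = torch.tensor([[1.0], [0.0]])
next_masks_batch = torch.tensor([[1.0], [0.0]])              # 0 where s' is terminal

q_targets = rewards_batch + gamma * max_next_q_values * next_masks_batch
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach()))
q_loss.backward()  # gradients reach q_values only, not the target

(F.mse_loss already reduces to a mean by default, so the outer torch.mean is a no-op kept from the original code.)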

openrl/algorithms/vdn.py

Lines changed: 1 addition & 3 deletions
@@ -211,9 +211,7 @@ def prepare_loss(
         rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1)
         rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
         q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
-        q_loss = torch.mean(
-            F.mse_loss(q_values, q_targets.detach())
-        )  # mean squared error loss function
+        q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach()))  # mean squared error loss function
 
         loss_list.append(q_loss)
         return loss_list
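
The reshape/sum pair above builds VDN's joint reward: per-agent rewards are summed into one team reward per transition, matching the summed joint Q-value that VDN optimizes. A quick shape walkthrough with toy values:

import torch

n_agent = 2
# Flattened per-agent rewards: 2 transitions x 2 agents
rewards_batch = torch.tensor([[1.0], [0.0], [0.5], [0.5]])

rewards_batch = rewards_batch.reshape(-1, n_agent, 1)                      # (2, 2, 1)
team_rewards = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)   # (2, 1)
print(team_rewards)  # tensor([[1.], [1.]]) -- one joint reward per transition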

openrl/arena/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@ def make_arena(
     render: Optional[bool] = False,
     **kwargs,
 ):
-    print(openrl.envs.PettingZoo.registration.pettingzoo_env_dict.keys())
    if custom_build_env is None:
         if (
             env_id in pettingzoo_all_envs

openrl/envs/PettingZoo/__init__.py

Lines changed: 6 additions & 6 deletions
@@ -20,7 +20,7 @@
 
 from openrl.envs.common import build_envs
 from openrl.envs.PettingZoo.registration import pettingzoo_env_dict, register
-from openrl.envs.wrappers.pettingzoo_wrappers import SeedEnv
+from openrl.envs.wrappers.pettingzoo_wrappers import CheckAgentNumber, SeedEnv
 
 
 def PettingZoo_make(id, render_mode, disable_env_checker, **kwargs):
@@ -46,8 +46,9 @@ def make_PettingZoo_env(
     **kwargs,
 ):
     env_num = 1
-    env_wrappers = [SeedEnv]
+    env_wrappers = [CheckAgentNumber, SeedEnv]
     env_wrappers += copy.copy(kwargs.pop("env_wrappers", []))
+
     env_fns = build_envs(
         make=PettingZoo_make,
         id=id,
@@ -65,16 +66,15 @@ def make_PettingZoo_envs(
     render_mode: Optional[Union[str, List[str]]] = None,
     **kwargs,
 ):
-    from openrl.envs.wrappers import (  # AutoReset,; DictWrapper,
+    from openrl.envs.wrappers import (  # AutoReset,; DictWrapper,; Single2MultiAgentWrapper,
         MoveActionMask2InfoWrapper,
         RemoveTruncated,
-        Single2MultiAgentWrapper,
     )
 
-    env_wrappers = [SeedEnv]
+    env_wrappers = [CheckAgentNumber, SeedEnv]
     env_wrappers += copy.copy(kwargs.pop("opponent_wrappers", []))
     env_wrappers += [
-        Single2MultiAgentWrapper,
+        # Single2MultiAgentWrapper,
         RemoveTruncated,
         MoveActionMask2InfoWrapper,
     ]
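
CheckAgentNumber now runs first in both wrapper chains. Its implementation isn't part of this diff; from the name it presumably validates the environment's agent count up front. A hypothetical PettingZoo-style sketch of such a check (not the actual OpenRL wrapper):

from pettingzoo.utils.wrappers import BaseWrapper


class AgentNumberCheck(BaseWrapper):
    """Hypothetical: fail fast if the env exposes an unexpected agent count."""

    def __init__(self, env, expected: int = 2):
        super().__init__(env)
        found = len(env.possible_agents)
        if found != expected:
            raise ValueError(f"expected {expected} agents, env has {found}")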
