Skip to content

Commit 0228e51

Browse files
committed
fix arena petting zoo import error
format
1 parent 8889302 commit 0228e51

8 files changed

Lines changed: 32 additions & 42 deletions

File tree

openrl/algorithms/dqn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ def prepare_loss(
167167
)
168168

169169
q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
170-
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # 均方误差损失函数
170+
q_loss = torch.mean(
171+
F.mse_loss(q_values, q_targets.detach())
172+
) # 均方误差损失函数
171173

172174
loss_list.append(q_loss)
173175

openrl/algorithms/vdn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ def prepare_loss(
211211
rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1)
212212
rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
213213
q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
214-
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # 均方误差损失函数
214+
q_loss = torch.mean(
215+
F.mse_loss(q_values, q_targets.detach())
216+
) # 均方误差损失函数
215217

216218
loss_list.append(q_loss)
217219
return loss_list

openrl/arena/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ def make_arena(
3030
**kwargs,
3131
):
3232
if custom_build_env is None:
33+
from openrl.envs import PettingZoo
34+
3335
if (
3436
env_id in pettingzoo_all_envs
35-
or env_id in openrl.envs.PettingZoo.registration.pettingzoo_env_dict.keys()
37+
or env_id in PettingZoo.registration.pettingzoo_env_dict.keys()
3638
):
3739
from openrl.envs.PettingZoo import make_PettingZoo_env
3840

openrl/envs/mpe/rendering.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,10 @@
3131
except ImportError:
3232
print(
3333
"Error occured while running `from pyglet.gl import *`",
34-
(
35-
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
36-
" install python-opengl'. If you're running on a server, you may need a"
37-
" virtual frame buffer; something like this should work: 'xvfb-run -s"
38-
' "-screen 0 1400x900x24" python <your_script.py>\''
39-
),
34+
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
35+
" install python-opengl'. If you're running on a server, you may need a"
36+
" virtual frame buffer; something like this should work: 'xvfb-run -s"
37+
' "-screen 0 1400x900x24" python <your_script.py>\'',
4038
)
4139

4240
import math

openrl/envs/snake/snake.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,9 @@ class Snake:
674674
def __init__(self, player_id, board_width, board_height, init_len):
675675
self.actions = [-2, 2, -1, 1]
676676
self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"}
677-
self.direction = random.choice(self.actions) # 方向[-2,2,-1,1]分别表示[上,下,左,右]
677+
self.direction = random.choice(
678+
self.actions
679+
) # 方向[-2,2,-1,1]分别表示[上,下,左,右]
678680
self.board_width = board_width
679681
self.board_height = board_height
680682
x = random.randrange(0, board_height)

openrl/envs/vec_env/async_venv.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,8 @@ def reset_send(
234234

235235
if self._state != AsyncState.DEFAULT:
236236
raise AlreadyPendingCallError(
237-
(
238-
"Calling `reset_send` while waiting for a pending call to"
239-
f" `{self._state.value}` to complete"
240-
),
237+
"Calling `reset_send` while waiting for a pending call to"
238+
f" `{self._state.value}` to complete",
241239
self._state.value,
242240
)
243241

@@ -329,10 +327,8 @@ def step_send(self, actions: np.ndarray):
329327
self._assert_is_running()
330328
if self._state != AsyncState.DEFAULT:
331329
raise AlreadyPendingCallError(
332-
(
333-
"Calling `step_send` while waiting for a pending call to"
334-
f" `{self._state.value}` to complete."
335-
),
330+
"Calling `step_send` while waiting for a pending call to"
331+
f" `{self._state.value}` to complete.",
336332
self._state.value,
337333
)
338334

@@ -342,9 +338,7 @@ def step_send(self, actions: np.ndarray):
342338
pipe.send(("step", action))
343339
self._state = AsyncState.WAITING_STEP
344340

345-
def step_fetch(
346-
self, timeout: Optional[Union[int, float]] = None
347-
) -> Union[
341+
def step_fetch(self, timeout: Optional[Union[int, float]] = None) -> Union[
348342
Tuple[Any, NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
349343
Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
350344
]:
@@ -576,10 +570,8 @@ def call_send(self, name: str, *args, **kwargs):
576570
self._assert_is_running()
577571
if self._state != AsyncState.DEFAULT:
578572
raise AlreadyPendingCallError(
579-
(
580-
"Calling `call_send` while waiting "
581-
f"for a pending call to `{self._state.value}` to complete."
582-
),
573+
"Calling `call_send` while waiting "
574+
f"for a pending call to `{self._state.value}` to complete.",
583575
str(self._state.value),
584576
)
585577

@@ -636,10 +628,8 @@ def exec_func_send(self, func: Callable, indices, *args, **kwargs):
636628
self._assert_is_running()
637629
if self._state != AsyncState.DEFAULT:
638630
raise AlreadyPendingCallError(
639-
(
640-
"Calling `exec_func_send` while waiting "
641-
f"for a pending call to `{self._state.value}` to complete."
642-
),
631+
"Calling `exec_func_send` while waiting "
632+
f"for a pending call to `{self._state.value}` to complete.",
643633
str(self._state.value),
644634
)
645635

@@ -717,10 +707,8 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]):
717707

718708
if self._state != AsyncState.DEFAULT:
719709
raise AlreadyPendingCallError(
720-
(
721-
"Calling `set_attr` while waiting "
722-
f"for a pending call to `{self._state.value}` to complete."
723-
),
710+
"Calling `set_attr` while waiting "
711+
f"for a pending call to `{self._state.value}` to complete.",
724712
str(self._state.value),
725713
)
726714

openrl/utils/callbacks/checkpoint_callback.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,7 @@ def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> st
7272
"""
7373
return os.path.join(
7474
self.save_path,
75-
(
76-
f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}"
77-
),
75+
f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}",
7876
)
7977

8078
def _on_step(self) -> bool:

openrl/utils/evaluation.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,10 @@ def evaluate_policy(
6868

6969
if not is_monitor_wrapped and warn:
7070
warnings.warn(
71-
(
72-
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
73-
" may result in reporting modified episode lengths and rewards, if"
74-
" other wrappers happen to modify these. Consider wrapping environment"
75-
" first with ``Monitor`` wrapper."
76-
),
71+
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
72+
" may result in reporting modified episode lengths and rewards, if"
73+
" other wrappers happen to modify these. Consider wrapping environment"
74+
" first with ``Monitor`` wrapper.",
7775
UserWarning,
7876
)
7977

0 commit comments

Comments (0)