Skip to content

Commit d2ff6b3

Browse files
committed
update _worker test
1 parent 5aa7d9e commit d2ff6b3

2 files changed

Lines changed: 14 additions & 17 deletions

File tree

openrl/envs/vec_env/async_venv.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -757,15 +757,14 @@ def prepare_obs(observation):
757757
)
758758
observation = None
759759
return observation
760+
760761
if parent_pipe is not None:
761762
parent_pipe.close()
762763
try:
763764
while True:
764765
command, data = pipe.recv()
765766
print(command)
766767

767-
768-
769768
if command == "reset":
770769
result = env.reset(**data)
771770

tests/test_env/test_vec_env/test_async_env.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@
1616

1717
""""""
1818

19+
import multiprocessing as mp
1920
import os
2021
import sys
2122

2223
import pytest
2324
from gymnasium.wrappers import EnvCompatibility
2425

2526
from openrl.envs.toy_envs import make_toy_envs
26-
from openrl.envs.vec_env.async_venv import AsyncVectorEnv
27-
from openrl.envs.vec_env.async_venv import _worker
28-
import multiprocessing as mp
27+
from openrl.envs.vec_env.async_venv import AsyncVectorEnv, _worker
28+
2929

3030
class CustomEnvCompatibility(EnvCompatibility):
3131
def reset(self, **kwargs):
@@ -62,37 +62,40 @@ def assert_env_name(env, env_name):
6262
# env.set_attr("metadata", {"name": env_name_new})
6363
# env.exec_func(assert_env_name, indices=None, env_name=env_name_new)
6464

65-
def main_control(parent_pipe,child_pipe):
65+
66+
def main_control(parent_pipe, child_pipe):
6667
child_pipe.close()
6768

68-
parent_pipe.send(("reset", {"seed":0}))
69+
parent_pipe.send(("reset", {"seed": 0}))
6970
result, success = parent_pipe.recv()
7071
assert success, result
7172

7273
parent_pipe.send(("step", [0]))
7374
result, success = parent_pipe.recv()
7475
assert success, result
7576

76-
parent_pipe.send(("_call", ("render",[],{})))
77+
parent_pipe.send(("_call", ("render", [], {})))
7778
result, success = parent_pipe.recv()
7879
assert success, result
7980

8081
parent_pipe.send(("_setattr", ("metadata", {"name": "IdentityEnvNew"})))
8182
result, success = parent_pipe.recv()
8283
assert success, result
8384

84-
parent_pipe.send(("_func_exec",(assert_env_name,None,[],{"env_name":"IdentityEnvNew"})))
85+
parent_pipe.send(
86+
("_func_exec", (assert_env_name, None, [], {"env_name": "IdentityEnvNew"}))
87+
)
8588
result, success = parent_pipe.recv()
8689
assert success, result
8790

88-
parent_pipe.send(("close",None))
91+
parent_pipe.send(("close", None))
8992
result, success = parent_pipe.recv()
9093
assert success, result
9194

9295

9396
@pytest.mark.unittest
9497
def test_worker():
95-
for auto_reset in [True,False]:
98+
for auto_reset in [True, False]:
9699
ctx = mp.get_context(None)
97100
parent_pipe, child_pipe = ctx.Pipe()
98101

@@ -101,17 +104,12 @@ def test_worker():
101104
process = ctx.Process(
102105
target=main_control,
103106
name="test",
104-
args=(
105-
parent_pipe,
106-
child_pipe
107-
),
107+
args=(parent_pipe, child_pipe),
108108
)
109109
process.daemon = True
110110
process.start()
111111
_worker(0, init_envs()[0], child_pipe, None, False, error_queue, auto_reset)
112112

113113

114-
115-
116114
if __name__ == "__main__":
117115
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))

0 commit comments

Comments
 (0)