Skip to content

Commit 5aa7d9e

Browse files
committed
update _worker test
1 parent c7f8f3a commit 5aa7d9e

2 files changed

Lines changed: 69 additions & 15 deletions

File tree

openrl/envs/vec_env/async_venv.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ def _worker(
734734
index: int,
735735
env_fn: callable,
736736
pipe: Connection,
737-
parent_pipe: Connection,
737+
parent_pipe: Optional[Connection],
738738
shared_memory: bool,
739739
error_queue: Queue,
740740
auto_reset: bool = True,
@@ -757,11 +757,14 @@ def prepare_obs(observation):
757757
)
758758
observation = None
759759
return observation
760-
761-
parent_pipe.close()
760+
if parent_pipe is not None:
761+
parent_pipe.close()
762762
try:
763763
while True:
764764
command, data = pipe.recv()
765+
print(command)
766+
767+
765768

766769
if command == "reset":
767770
result = env.reset(**data)

tests/test_env/test_vec_env/test_async_env.py

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424

2525
from openrl.envs.toy_envs import make_toy_envs
2626
from openrl.envs.vec_env.async_venv import AsyncVectorEnv
27-
27+
from openrl.envs.vec_env.async_venv import _worker
28+
import multiprocessing as mp
2829

2930
class CustomEnvCompatibility(EnvCompatibility):
3031
def reset(self, **kwargs):
@@ -48,18 +49,68 @@ def assert_env_name(env, env_name):
4849
assert env.metadata["name"].__name__ == env_name
4950

5051

52+
# @pytest.mark.unittest
53+
# def test_async_env():
54+
# env_name = "IdentityEnv"
55+
# env = AsyncVectorEnv(init_envs(), shared_memory=True)
56+
# assert (
57+
# env._env_name == env_name
58+
# ), "AsyncVectorEnv should have the same metadata as the wrapped env"
59+
# env.exec_func(assert_env_name, indices=None, env_name=env_name)
60+
# env.call("render")
61+
# env_name_new = "IdentityEnvNew"
62+
# env.set_attr("metadata", {"name": env_name_new})
63+
# env.exec_func(assert_env_name, indices=None, env_name=env_name_new)
64+
65+
def main_control(parent_pipe,child_pipe):
66+
child_pipe.close()
67+
68+
parent_pipe.send(("reset", {"seed":0}))
69+
result, success = parent_pipe.recv()
70+
assert success, result
71+
72+
parent_pipe.send(("step", [0]))
73+
result, success = parent_pipe.recv()
74+
assert success, result
75+
76+
parent_pipe.send(("_call", ("render",[],{})))
77+
result, success = parent_pipe.recv()
78+
assert success, result
79+
80+
parent_pipe.send(("_setattr", ("metadata", {"name": "IdentityEnvNew"})))
81+
result, success = parent_pipe.recv()
82+
assert success, result
83+
84+
parent_pipe.send(("_func_exec",(assert_env_name,None,[],{"env_name":"IdentityEnvNew"})))
85+
result, success = parent_pipe.recv()
86+
assert success, result
87+
88+
parent_pipe.send(("close",None))
89+
result, success = parent_pipe.recv()
90+
assert success, result
91+
92+
5193
@pytest.mark.unittest
52-
def test_async_env():
53-
env_name = "IdentityEnv"
54-
env = AsyncVectorEnv(init_envs(), shared_memory=True)
55-
assert (
56-
env._env_name == env_name
57-
), "AsyncVectorEnv should have the same metadata as the wrapped env"
58-
env.exec_func(assert_env_name, indices=None, env_name=env_name)
59-
env.call("render")
60-
env_name_new = "IdentityEnvNew"
61-
env.set_attr("metadata", {"name": env_name_new})
62-
env.exec_func(assert_env_name, indices=None, env_name=env_name_new)
94+
def test_worker():
95+
for auto_reset in [True,False]:
96+
ctx = mp.get_context(None)
97+
parent_pipe, child_pipe = ctx.Pipe()
98+
99+
error_queue = ctx.Queue()
100+
101+
process = ctx.Process(
102+
target=main_control,
103+
name="test",
104+
args=(
105+
parent_pipe,
106+
child_pipe
107+
),
108+
)
109+
process.daemon = True
110+
process.start()
111+
_worker(0, init_envs()[0], child_pipe, None, False, error_queue, auto_reset)
112+
113+
63114

64115

65116
if __name__ == "__main__":

0 commit comments

Comments
 (0)