Skip to content

Commit e1e1a79

Browse files
committed
add gym_pybullet_drones env
1 parent d7fdb4e commit e1e1a79

4 files changed

Lines changed: 40 additions & 34 deletions

File tree

examples/gym_pybullet_drones/test_env.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,47 +17,54 @@
1717
""""""
1818
import time
1919

20-
import numpy as np
2120
import gym_pybullet_drones
2221
import gymnasium as gym
22+
import numpy as np
2323

2424
from openrl.envs.common import make
25+
26+
2527
def test_env():
26-
env = gym.make("hover-aviary-v0",gui=False,record=False)
27-
print("obs space:",env.observation_space)
28-
print("action space:",env.action_space)
28+
env = gym.make("hover-aviary-v0", gui=False, record=False)
29+
print("obs space:", env.observation_space)
30+
print("action space:", env.action_space)
2931
obs, info = env.reset(seed=42, options={})
30-
totoal_step =0
31-
totol_reward = 0.
32+
totoal_step = 0
33+
totol_reward = 0.0
3234
while True:
3335
obs, reward, done, truncated, info = env.step(env.action_space.sample())
34-
totoal_step+=1
35-
totol_reward+=reward
36+
totoal_step += 1
37+
totol_reward += reward
3638
# env.render()
3739
# time.sleep(1)
3840
if done:
3941
break
40-
print("total step:",totoal_step)
41-
print("total reward:",totol_reward)
42+
print("total step:", totoal_step)
43+
print("total reward:", totol_reward)
44+
4245

4346
def test_vec_env():
44-
env = make("pybullet_drones/hover-aviary-v0",env_num=2,gui=False,record=False,asynchronous=True)
45-
info,obs = env.reset(seed=0)
47+
env = make(
48+
"pybullet_drones/hover-aviary-v0",
49+
env_num=2,
50+
gui=False,
51+
record=False,
52+
asynchronous=True,
53+
)
54+
info, obs = env.reset(seed=0)
4655
totoal_step = 0
47-
totol_reward = 0.
56+
totol_reward = 0.0
4857
while True:
49-
obs, reward, done, info = env.step(env.random_action())
50-
totoal_step+=1
51-
totol_reward+=np.mean(reward)
52-
if np.any(done) or totoal_step>100:
58+
obs, reward, done, info = env.step(env.random_action())
59+
totoal_step += 1
60+
totol_reward += np.mean(reward)
61+
if np.any(done) or totoal_step > 100:
5362
break
5463
env.close()
5564
print("total step:", totoal_step)
5665
print("total reward:", totol_reward)
5766

5867

59-
60-
61-
if __name__ == '__main__':
68+
if __name__ == "__main__":
6269
test_env()
63-
# test_vec_env()
70+
# test_vec_env()

examples/gym_pybullet_drones/train_ppo.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def evaluation():
5656
record=False,
5757
)
5858

59-
6059
net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
6160
# initialize the trainer
6261
agent = Agent(
@@ -74,7 +73,7 @@ def evaluation():
7473
while not np.any(done):
7574
# Based on environmental observation input, predict next action.
7675
action, _ = agent.act(obs, deterministic=True)
77-
print("action:",action)
76+
print("action:", action)
7877
obs, r, done, info = env.step(action)
7978
step += 1
8079
total_reward += np.mean(r)

openrl/envs/common/registration.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,11 @@ def make(
6666
)
6767
else:
6868
if id.startswith("pybullet_drones/"):
69-
7069
from openrl.envs.gym_pybullet_drones import make_single_agent_drone_envs
71-
env_fns = make_single_agent_drone_envs(id=id,env_num=env_num,render_mode=convert_render_mode,**kwargs)
70+
71+
env_fns = make_single_agent_drone_envs(
72+
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
73+
)
7274

7375
elif id.startswith("snakes_"):
7476
from openrl.envs.snake import make_snake_envs

openrl/envs/gym_pybullet_drones/__init__.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,22 @@
2424
from openrl.envs.common import build_envs
2525

2626

27-
def make_single_agent_drone_env(id:str,
28-
render_mode,
29-
disable_env_checker,
30-
**kwargs):
27+
def make_single_agent_drone_env(id: str, render_mode, disable_env_checker, **kwargs):
3128
import gym_pybullet_drones
29+
3230
prefix = "pybullet_drones/"
3331
assert id.startswith(prefix), "id must start with pybullet_drones/"
3432
kwargs.pop("cfg")
3533

36-
env = gym.envs.registration.make(id[len(prefix):],**kwargs)
34+
env = gym.envs.registration.make(id[len(prefix) :], **kwargs)
3735
return env
3836

3937

4038
def make_single_agent_drone_envs(
41-
id: str,
42-
env_num: int = 1,
43-
render_mode: Optional[Union[str, List[str]]] = None,
44-
**kwargs,
39+
id: str,
40+
env_num: int = 1,
41+
render_mode: Optional[Union[str, List[str]]] = None,
42+
**kwargs,
4543
) -> List[Callable[[], Env]]:
4644
from openrl.envs.wrappers import ( # AutoReset,; DictWrapper,
4745
RemoveTruncated,

0 commit comments

Comments
 (0)