Skip to content

Commit 220919a

Browse files
authored
Merge pull request #225 from huangshiyu13/main
add gym_pybullet_drones env
2 parents 0978360 + e1e1a79 commit 220919a

10 files changed

Lines changed: 284 additions & 33 deletions

File tree

Gallery.md

Lines changed: 34 additions & 32 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ Environments currently supported by OpenRL (for more details, please refer to [G
117117
- [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
118118
- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
119119
- [Snake](http://www.jidiai.cn/env_detail?envid=1)
120+
- [gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones)
120121
- [GridWorld](./examples/gridworld/)
121122
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
122123
- [Gym Retro](https://github.com/openai/retro)

README_zh.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ OpenRL目前支持的环境(更多详情请参考 [Gallery](Gallery.md)):
9393
- [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
9494
- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
9595
- [Snake](http://www.jidiai.cn/env_detail?envid=1)
96+
- [gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones)
9697
- [GridWorld](./examples/gridworld/)
9798
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
9899
- [Gym Retro](https://github.com/openai/retro)

docs/images/drone.png

18.5 KB
Loading
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
### Installation
3+
4+
- Python >= 3.10
5+
- Follow the installation instructions for [gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones#installation).
6+
7+
### Train PPO
8+
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
episode_length: 500
2+
lr: 1e-3
3+
critic_lr: 1e-3
4+
gamma: 0.1
5+
ppo_epoch: 5
6+
use_valuenorm: true
7+
entropy_coef: 0.0
8+
hidden_size: 128
9+
layer_N: 4
10+
use_recurrent_policy: true
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
import time
19+
20+
import gym_pybullet_drones
21+
import gymnasium as gym
22+
import numpy as np
23+
24+
from openrl.envs.common import make
25+
26+
27+
def test_env():
    """Smoke-test a single hover-aviary env with random actions for one episode.

    Prints the observation/action spaces, then steps with random actions
    until the episode ends, reporting the step count and cumulative reward.
    """
    env = gym.make("hover-aviary-v0", gui=False, record=False)
    print("obs space:", env.observation_space)
    print("action space:", env.action_space)
    obs, info = env.reset(seed=42, options={})
    total_step = 0
    total_reward = 0.0
    while True:
        obs, reward, done, truncated, info = env.step(env.action_space.sample())
        total_step += 1
        total_reward += reward
        # env.render()
        # time.sleep(1)
        # Gymnasium episodes can end by termination OR truncation (time
        # limit); breaking only on `done` could loop forever on truncation.
        if done or truncated:
            break
    env.close()
    print("total step:", total_step)
    print("total reward:", total_reward)
44+
45+
46+
def test_vec_env():
    """Smoke-test the OpenRL vectorized wrapper around hover-aviary.

    Runs 2 asynchronous env copies with random actions until any copy
    finishes or 100 steps elapse, then prints step count and mean reward.
    """
    env = make(
        "pybullet_drones/hover-aviary-v0",
        env_num=2,
        gui=False,
        record=False,
        asynchronous=True,
    )
    # reset() returns (obs, info) — same order as in test_env/evaluation.
    obs, info = env.reset(seed=0)
    total_step = 0
    total_reward = 0.0
    while True:
        obs, reward, done, info = env.step(env.random_action())
        total_step += 1
        total_reward += np.mean(reward)
        if np.any(done) or total_step > 100:
            break
    env.close()
    print("total step:", total_step)
    print("total reward:", total_reward)
66+
67+
68+
# Manual smoke-test entry point: run the single-env check by default;
# uncomment to exercise the vectorized wrapper instead.
if __name__ == "__main__":
    test_env()
    # test_vec_env()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import numpy as np
2+
import torch
3+
4+
from openrl.configs.config import create_config_parser
5+
from openrl.envs.common import make
6+
from openrl.modules.common import PPONet as Net
7+
from openrl.runners.common import PPOAgent as Agent
8+
9+
env_name = "pybullet_drones/hover-aviary-v0"
10+
11+
12+
def train():
    """Train a PPO agent on the hover-aviary drone env and save it.

    Reads hyperparameters from ppo.yaml, trains for 1,000,000 timesteps
    across 20 parallel envs, saves the agent to ./ppo_agent, and returns it.
    """
    # load training config from ppo.yaml
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    # create environment, set environment parallelism to 20
    env_num = 20
    # env_num = 1

    env = make(
        env_name,
        env_num=env_num,
        cfg=cfg,
        asynchronous=True,
        env_wrappers=[],
        gui=False,
    )

    # create the neural network (GPU if available)
    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
    # initialize the trainer
    agent = Agent(
        net,
    )
    # start training, set total number of training steps to 1,000,000
    agent.train(total_time_steps=1000000)

    agent.save("./ppo_agent")
    env.close()
    return agent
41+
42+
43+
def evaluation():
    """Load the trained PPO agent from ./ppo_agent and run one episode.

    Acts deterministically in a single (non-parallel) env until the episode
    terminates, printing each action and the final step count / total reward.
    """
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
    # begin to test
    # Create a single evaluation environment (env_num=1), no GUI, no recording.

    env = make(
        env_name,
        env_num=1,
        asynchronous=False,
        env_wrappers=[],
        cfg=cfg,
        gui=False,
        record=False,
    )

    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
    # initialize the trainer
    agent = Agent(
        net,
    )
    agent.load("./ppo_agent")

    # The trained agent sets up the interactive environment it needs.
    agent.set_env(env)
    # Initialize the environment and get initial observations and environmental information.
    obs, info = env.reset()
    done = False
    step = 0
    total_reward = 0.0
    while not np.any(done):
        # Based on environmental observation input, predict next action.
        action, _ = agent.act(obs, deterministic=True)
        print("action:", action)
        obs, r, done, info = env.step(action)
        step += 1
        # r is per-env; average it into the running episode total.
        total_reward += np.mean(r)
        # if step % 50 == 0:
        #     print(f"{step}: reward:{np.mean(r)}")
    print("total step:", step)
    print("total reward:", total_reward)
    env.close()
85+
86+
87+
# Entry point: evaluates the saved agent by default; uncomment train()
# to (re)train before evaluating.
if __name__ == "__main__":
    # train()
    evaluation()

openrl/envs/common/registration.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,14 @@ def make(
6565
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
6666
)
6767
else:
68-
if id.startswith("snakes_"):
68+
if id.startswith("pybullet_drones/"):
69+
from openrl.envs.gym_pybullet_drones import make_single_agent_drone_envs
70+
71+
env_fns = make_single_agent_drone_envs(
72+
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
73+
)
74+
75+
elif id.startswith("snakes_"):
6976
from openrl.envs.snake import make_snake_envs
7077

7178
env_fns = make_snake_envs(
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
import copy
19+
from typing import Callable, List, Optional, Union
20+
21+
import gymnasium as gym
22+
from gymnasium import Env
23+
24+
from openrl.envs.common import build_envs
25+
26+
27+
def make_single_agent_drone_env(id: str, render_mode, disable_env_checker, **kwargs):
    """Create one gym-pybullet-drones env from an OpenRL-prefixed id.

    Strips the "pybullet_drones/" prefix and delegates to the gymnasium
    registry; remaining kwargs are forwarded to the env constructor.
    """
    # import registers the drone env ids with gymnasium as a side effect
    import gym_pybullet_drones

    prefix = "pybullet_drones/"
    assert id.startswith(prefix), "id must start with pybullet_drones/"
    # `cfg` is an OpenRL-only kwarg the underlying env doesn't accept;
    # use a default so callers that pass no cfg (e.g. make() without a
    # config) don't crash with KeyError.
    kwargs.pop("cfg", None)

    env = gym.envs.registration.make(id[len(prefix) :], **kwargs)
    return env
36+
37+
38+
def make_single_agent_drone_envs(
    id: str,
    env_num: int = 1,
    render_mode: Optional[Union[str, List[str]]] = None,
    **kwargs,
) -> List[Callable[[], Env]]:
    """Build `env_num` factory callables for a gym-pybullet-drones env.

    Caller-supplied wrappers (via the `env_wrappers` kwarg) are applied
    first, then the adapters OpenRL needs: Single2MultiAgentWrapper and
    RemoveTruncated.
    """
    from openrl.envs.wrappers import (  # AutoReset,; DictWrapper,
        RemoveTruncated,
        Single2MultiAgentWrapper,
    )

    # Copy so we never mutate a wrapper list owned by the caller.
    wrappers = copy.copy(kwargs.pop("env_wrappers", []))
    wrappers = wrappers + [Single2MultiAgentWrapper, RemoveTruncated]

    return build_envs(
        make=make_single_agent_drone_env,
        id=id,
        env_num=env_num,
        render_mode=render_mode,
        wrappers=wrappers,
        **kwargs,
    )

0 commit comments

Comments
 (0)