Skip to content

Commit d7fdb4e

Browse files
committed
add gym_pybullet_drones env
1 parent 7c5b9c8 commit d7fdb4e

10 files changed

Lines changed: 278 additions & 33 deletions

File tree

Gallery.md

Lines changed: 34 additions & 32 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ Environments currently supported by OpenRL (for more details, please refer to [G
117117
- [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
118118
- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
119119
- [Snake](http://www.jidiai.cn/env_detail?envid=1)
120+
- [gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones)
120121
- [GridWorld](./examples/gridworld/)
121122
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
122123
- [Gym Retro](https://github.com/openai/retro)

README_zh.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ OpenRL目前支持的环境(更多详情请参考 [Gallery](Gallery.md)):
9393
- [Omniverse Isaac Gym](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs)
9494
- [DeepMind Control](https://shimmy.farama.org/environments/dm_control/)
9595
- [Snake](http://www.jidiai.cn/env_detail?envid=1)
96+
- [gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones)
9697
- [GridWorld](./examples/gridworld/)
9798
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
9899
- [Gym Retro](https://github.com/openai/retro)

docs/images/drone.png

18.5 KB
Loading
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
### Installation
3+
4+
- Python >= 3.10
5+
- Follow the installation instructions of [gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones#installation).
6+
7+
### Train PPO
8+
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# PPO hyperparameters for the gym-pybullet-drones hover task.
episode_length: 500        # max steps per episode buffer
lr: 1e-3                   # actor learning rate
critic_lr: 1e-3            # critic learning rate
gamma: 0.1                 # discount factor — NOTE(review): unusually low (typical ~0.99); confirm intentional
ppo_epoch: 5               # PPO update epochs per rollout
use_valuenorm: true        # normalize value targets
entropy_coef: 0.0          # no entropy bonus
hidden_size: 128           # MLP hidden layer width
layer_N: 4                 # number of hidden layers
use_recurrent_policy: true # recurrent (RNN) policy
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
import time
19+
20+
import numpy as np
21+
import gym_pybullet_drones
22+
import gymnasium as gym
23+
24+
from openrl.envs.common import make
25+
def test_env():
    """Smoke-test a single hover-aviary env by rolling out random actions.

    Creates the raw gymnasium environment (headless, no recording), steps
    it with random actions until the episode ends, and prints the episode
    length and accumulated reward.
    """
    env = gym.make("hover-aviary-v0", gui=False, record=False)
    print("obs space:", env.observation_space)
    print("action space:", env.action_space)
    obs, info = env.reset(seed=42, options={})
    total_step = 0
    total_reward = 0.0
    while True:
        obs, reward, done, truncated, info = env.step(env.action_space.sample())
        total_step += 1
        total_reward += reward
        # env.render()
        # time.sleep(1)
        # Gymnasium splits episode end into `done` (terminated) and
        # `truncated` (time limit). The original only checked `done`,
        # which can loop forever on a truncation-only episode.
        if done or truncated:
            break
    env.close()
    print("total step:", total_step)
    print("total reward:", total_reward)
def test_vec_env():
    """Smoke-test the OpenRL vectorized wrapper around hover-aviary.

    Builds two asynchronous copies of the env via OpenRL's ``make``,
    steps with random actions for at most 100 steps, then prints the
    step count and mean accumulated reward.
    """
    env = make(
        "pybullet_drones/hover-aviary-v0",
        env_num=2,
        gui=False,
        record=False,
        asynchronous=True,
    )
    # OpenRL's reset follows the (obs, info) convention used in test_env
    # above; the original unpacked the pair in the wrong order.
    obs, info = env.reset(seed=0)
    total_step = 0
    total_reward = 0.0
    while True:
        obs, reward, done, info = env.step(env.random_action())
        total_step += 1
        total_reward += np.mean(reward)
        # Cap the rollout so the smoke test always terminates.
        if np.any(done) or total_step > 100:
            break
    env.close()
    print("total step:", total_step)
    print("total reward:", total_reward)
if __name__ == '__main__':
    # Run the plain gymnasium smoke test; the vectorized variant is
    # kept available but disabled by default.
    test_env()
    # test_vec_env()
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import numpy as np
2+
import torch
3+
4+
from openrl.configs.config import create_config_parser
5+
from openrl.envs.common import make
6+
from openrl.modules.common import PPONet as Net
7+
from openrl.runners.common import PPOAgent as Agent
8+
9+
env_name = "pybullet_drones/hover-aviary-v0"
10+
11+
12+
def train():
    """Train a PPO agent on the hover-aviary drone task and save it.

    Reads hyperparameters from ppo.yaml, trains for 1,000,000 environment
    steps across 20 parallel envs, saves the agent to ./ppo_agent, and
    returns the trained agent.
    """
    # Parse hyperparameters from the local ppo.yaml.
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    # Create the training environment with 20 parallel copies (headless).
    # The original comment claimed 64, contradicting the code.
    env_num = 20
    env = make(
        env_name,
        env_num=env_num,
        cfg=cfg,
        asynchronous=True,
        env_wrappers=[],
        gui=False,
    )

    # Build the PPO network on GPU when available.
    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
    # Initialize the trainer.
    agent = Agent(
        net,
    )
    # Train for 1,000,000 environment steps (the original comment
    # said 100,000, contradicting the code).
    agent.train(total_time_steps=1000000)

    agent.save("./ppo_agent")
    env.close()
    return agent
def evaluation():
    """Load the PPO agent saved by ``train`` and roll out one episode.

    Creates a single headless evaluation env (the original comment
    claimed 4 envs with group_rgb_array rendering, contradicting the
    code), runs the deterministic policy until the episode ends, and
    prints the step count and mean accumulated reward.
    """
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    # One synchronous env, no GUI, no video recording.
    env = make(
        env_name,
        env_num=1,
        asynchronous=False,
        env_wrappers=[],
        cfg=cfg,
        gui=False,
        record=False,
    )

    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
    # Initialize the trainer and restore the saved weights.
    agent = Agent(
        net,
    )
    agent.load("./ppo_agent")

    # The trained agent sets up the interactive environment it needs.
    agent.set_env(env)
    # Initialize the environment and get initial observations and info.
    obs, info = env.reset()
    done = False
    step = 0
    total_reward = 0.0
    while not np.any(done):
        # Deterministic action (no exploration noise) for evaluation.
        action, _ = agent.act(obs, deterministic=True)
        print("action:", action)
        obs, r, done, info = env.step(action)
        step += 1
        total_reward += np.mean(r)
    print("total step:", step)
    print("total reward:", total_reward)
    env.close()
if __name__ == "__main__":
    # Training is disabled by default; evaluate a previously saved agent.
    # train()
    evaluation()

openrl/envs/common/registration.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,12 @@ def make(
6565
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
6666
)
6767
else:
68-
if id.startswith("snakes_"):
68+
if id.startswith("pybullet_drones/"):
69+
70+
from openrl.envs.gym_pybullet_drones import make_single_agent_drone_envs
71+
env_fns = make_single_agent_drone_envs(id=id,env_num=env_num,render_mode=convert_render_mode,**kwargs)
72+
73+
elif id.startswith("snakes_"):
6974
from openrl.envs.snake import make_snake_envs
7075

7176
env_fns = make_snake_envs(
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
import copy
19+
from typing import Callable, List, Optional, Union
20+
21+
import gymnasium as gym
22+
from gymnasium import Env
23+
24+
from openrl.envs.common import build_envs
25+
26+
27+
def make_single_agent_drone_env(id: str,
                                render_mode,
                                disable_env_checker,
                                **kwargs):
    """Create one single-agent gym-pybullet-drones env from a prefixed id.

    Args:
        id: Environment id; must start with ``pybullet_drones/``. The
            remainder is the gym-pybullet-drones registration id
            (e.g. ``hover-aviary-v0``).
        render_mode: Accepted for interface compatibility with
            ``build_envs``; not forwarded (drone envs use ``gui`` instead).
        disable_env_checker: Accepted for interface compatibility; unused.
        **kwargs: Forwarded to ``gym.make`` (e.g. ``gui``, ``record``).

    Returns:
        The created gymnasium environment.
    """
    # Importing registers the drone environments with gymnasium.
    import gym_pybullet_drones
    prefix = "pybullet_drones/"
    assert id.startswith(prefix), "id must start with pybullet_drones/"
    # OpenRL's make() passes its config through kwargs, which the
    # underlying gym env does not accept. Pop with a default so the
    # call also works when no cfg was supplied (the original raised
    # KeyError in that case).
    kwargs.pop("cfg", None)

    env = gym.envs.registration.make(id[len(prefix):], **kwargs)
    return env
def make_single_agent_drone_envs(
    id: str,
    env_num: int = 1,
    render_mode: Optional[Union[str, List[str]]] = None,
    **kwargs,
) -> List[Callable[[], Env]]:
    """Build a list of factory callables for single-agent drone envs.

    Each factory creates one wrapped gym-pybullet-drones environment, so
    the caller can assemble a vectorized environment of ``env_num`` copies.
    """
    from openrl.envs.wrappers import (  # AutoReset,; DictWrapper,
        RemoveTruncated,
        Single2MultiAgentWrapper,
    )

    # Caller-supplied wrappers run first, followed by the standard pair
    # that adapts a single-agent gymnasium env to OpenRL's multi-agent API.
    wrappers = list(kwargs.pop("env_wrappers", []))
    wrappers.append(Single2MultiAgentWrapper)
    wrappers.append(RemoveTruncated)

    return build_envs(
        make=make_single_agent_drone_env,
        id=id,
        env_num=env_num,
        render_mode=render_mode,
        wrappers=wrappers,
        **kwargs,
    )

0 commit comments

Comments
 (0)