Skip to content

Commit a8d1db1

Browse files
authored
simplify agent proxy setup (#165)
This was getting a tiny bit complex and I really want to stop hardcoding input/output names here. I've tried cloud training with MLP agents -- I'd assume the features+images path is equally affected, as we're just reusing more code.
1 parent 9f2cc4c commit a8d1db1

3 files changed

Lines changed: 170 additions & 93 deletions

File tree

emote/proxies.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ def policy(self) -> nn.Module:
2525
pass
2626

2727
@property
28-
def input_names(self) -> tuple[str]:
28+
def input_names(self) -> tuple[str, ...]:
29+
...
30+
31+
@property
32+
def output_names(self) -> tuple[str, ...]:
2933
...
3034

3135

emote/sac.py

Lines changed: 111 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from emote.mixins.logging import LoggingMixin
1515
from emote.proxies import AgentProxy
1616
from emote.typing import AgentId, DictObservation, DictResponse, EpisodeState
17+
from emote.utils.deprecated import deprecated
1718
from emote.utils.gamma_matrix import discount, make_gamma_matrix, split_rollouts
1819
from emote.utils.spaces import MDPSpace
1920

@@ -348,110 +349,45 @@ def __call__(self, *args, **kwargs):
348349
def input_names(self):
349350
return self._inner.input_names
350351

351-
352-
class FeatureAgentProxy:
353-
"""An agent proxy for basic MLPs.
354-
355-
This AgentProxy assumes that the observations will contain a single flat array of features.
356-
"""
357-
358-
def __init__(self, policy: nn.Module, device: torch.device, input_key: str = "obs"):
359-
"""Create a new proxy.
360-
361-
:param policy: The policy to execute for actions.
362-
:param device: The device to run on.
363-
:param input_key: The name of the features. (default: "obs")
364-
"""
365-
self.policy = policy
366-
self._end_states = [EpisodeState.TERMINAL, EpisodeState.INTERRUPTED]
367-
self.device = device
368-
369-
self._input_key = input_key
370-
371-
def __call__(
372-
self,
373-
observations: Dict[AgentId, DictObservation],
374-
) -> Dict[AgentId, DictResponse]:
375-
"""Runs the policy and returns the actions."""
376-
# The network takes observations of size batch x obs for each observation space.
377-
assert len(observations) > 0, "Observations must not be empty."
378-
active_agents = [
379-
agent_id
380-
for agent_id, obs in observations.items()
381-
if obs.episode_state not in self._end_states
382-
]
383-
tensor_obs = torch.tensor(
384-
np.array(
385-
[
386-
observations[agent_id].array_data[self._input_key]
387-
for agent_id in active_agents
388-
]
389-
)
390-
).to(self.device)
391-
392-
actions = self.policy(tensor_obs)[0].detach().cpu().numpy()
393-
394-
return {
395-
agent_id: DictResponse(list_data={"actions": actions[i]}, scalar_data={})
396-
for i, agent_id in enumerate(active_agents)
397-
}
398-
399352
@property
400-
def input_names(self):
401-
return (self._input_key,)
402-
403-
404-
class VisionAgentProxy:
405-
"""This AgentProxy assumes that the observations will contain image observations 'obs'"""
353+
def output_names(self):
354+
return self._inner.output_names
406355

407-
def __init__(self, policy: nn.Module, device: torch.device):
408-
self.policy = policy
409-
self._end_states = [EpisodeState.TERMINAL, EpisodeState.INTERRUPTED]
410-
self.device = device
356+
@property
357+
def policy(self):
358+
return self._inner.policy
411359

412-
def __call__(
413-
self, observations: Dict[AgentId, DictObservation]
414-
) -> Dict[AgentId, DictResponse]:
415-
"""Runs the policy and returns the actions."""
416-
# The network takes observations of size batch x obs for each observation space.
417-
assert len(observations) > 0, "Observations must not be empty."
418-
active_agents = [
419-
agent_id
420-
for agent_id, obs in observations.items()
421-
if obs.episode_state not in self._end_states
422-
]
423-
np_obs = np.array(
424-
[observations[agent_id].array_data["obs"] for agent_id in active_agents]
425-
)
426-
tensor_obs = torch.tensor(np_obs).to(self.device)
427-
actions = self.policy(tensor_obs)[0].detach().cpu().numpy()
428-
return {
429-
agent_id: DictResponse(list_data={"actions": actions[i]}, scalar_data={})
430-
for i, agent_id in enumerate(active_agents)
431-
}
432360

361+
class GenericAgentProxy(AgentProxy):
362+
"""Observations are dicts that contain multiple input and output keys.
433363
434-
class MultiKeyAgentProxy:
435-
"""Observations are dicts that contain multiple input keys (e.g. both "features" and "images")"""
364+
For example, we might have a policy that takes in both "obs" and
365+
"goal" and outputs "actions". In order to be able to properly
366+
invoke the network it is the responsibility of this proxy to
367+
collate the inputs and decollate the outputs per agent.
368+
"""
436369

437370
def __init__(
438371
self,
439372
policy: nn.Module,
440373
device: torch.device,
441374
input_keys: tuple,
442-
spaces: MDPSpace = None,
375+
output_keys: tuple,
376+
spaces: MDPSpace | None = None,
443377
):
444378
"""Create a new proxy.
445379
446-
Args:
447-
policy (nn.Module): The policy to execute for actions.
448-
device (torch.device): The device to run on.
449-
input_keys (tuple): The names of the input.
380+
:param policy (nn.Module): The policy to invoke
381+
:param device (torch.device): The device to run on
382+
:param input_keys (tuple): The names of the inputs to the policy
383+
:param output_keys (tuple): The names of the outputs of the policy
384+
:param spaces (MDPSpace): The spaces of the inputs and outputs
450385
"""
451-
self.policy = policy
386+
self._policy = policy
452387
self._end_states = [EpisodeState.TERMINAL, EpisodeState.INTERRUPTED]
453388
self.device = device
454389
self.input_keys = input_keys
390+
self.output_keys = output_keys
455391
self._spaces = spaces
456392

457393
def __call__(
@@ -467,7 +403,7 @@ def __call__(
467403
if obs.episode_state not in self._end_states
468404
]
469405

470-
dict_tensor_obs = {}
406+
tensor_obs_list = [None] * len(self.input_keys)
471407
for input_key in self.input_keys:
472408
np_obs = np.array(
473409
[
@@ -482,15 +418,98 @@ def __call__(
482418
np_obs = np.reshape(np_obs, shape)
483419

484420
tensor_obs = torch.tensor(np_obs).to(self.device)
485-
dict_tensor_obs[input_key] = tensor_obs
421+
index = self.input_keys.index(input_key)
422+
tensor_obs_list[index] = tensor_obs
486423

487-
actions = self.policy(**dict_tensor_obs)[0].detach().cpu().numpy()
424+
outputs: tuple[any, ...] = self._policy(*tensor_obs_list)
425+
# we remove element 1 as we don't need the logprobs here
426+
outputs = outputs[0:1] + outputs[2:]
488427

489-
return {
490-
agent_id: DictResponse(list_data={"actions": actions[i]}, scalar_data={})
491-
for i, agent_id in enumerate(active_agents)
428+
outputs = {
429+
key: outputs[i].detach().cpu().numpy()
430+
for i, key in enumerate(self.output_keys)
492431
}
493432

433+
agent_data = [
434+
(agent_id, DictResponse(list_data={}, scalar_data={}))
435+
for agent_id in active_agents
436+
]
437+
438+
for i, (_, response) in enumerate(agent_data):
439+
for k, data in outputs.items():
440+
response.list_data[k] = data[i]
441+
442+
return dict(agent_data)
443+
494444
@property
495445
def input_names(self):
496446
return self.input_keys
447+
448+
@property
449+
def output_names(self):
450+
return self.output_keys
451+
452+
@property
453+
def policy(self):
454+
return self._policy
455+
456+
457+
class FeatureAgentProxy(GenericAgentProxy):
    """Agent proxy for plain MLP policies.

    Assumes each observation carries a single flat feature array under
    ``input_key``.
    """

    @deprecated(reason="Use GenericAgentProxy instead", version="23.1.0")
    def __init__(self, policy: nn.Module, device: torch.device, input_key: str = "obs"):
        """Create a new proxy.

        :param policy: The policy to execute for actions.
        :param device: The device to run on.
        :param input_key: The name of the features. (default: "obs")
        """
        # Single feature input, single "actions" output — the generic
        # proxy handles collation/decollation.
        feature_inputs = (input_key,)
        super().__init__(
            policy=policy,
            device=device,
            input_keys=feature_inputs,
            output_keys=("actions",),
        )
478+
479+
480+
class VisionAgentProxy(FeatureAgentProxy):
    """Agent proxy for image policies.

    Assumes the observations carry image data under the key 'obs'.
    """

    @deprecated(reason="Use GenericAgentProxy instead", version="23.1.0")
    def __init__(self, policy: nn.Module, device: torch.device):
        # Identical to a FeatureAgentProxy reading the "obs" key.
        super().__init__(policy=policy, device=device, input_key="obs")
486+
487+
488+
class MultiKeyAgentProxy(GenericAgentProxy):
    """Handles multiple input keys.

    Observations are dicts that contain multiple input keys (e.g. both "features" and "images").
    """

    @deprecated(reason="Use GenericAgentProxy instead", version="23.1.0")
    def __init__(
        self,
        policy: nn.Module,
        device: torch.device,
        input_keys: tuple,
        # `| None` matches GenericAgentProxy's signature; plain `MDPSpace = None`
        # was a lying annotation since None is the default.
        spaces: MDPSpace | None = None,
    ):
        """Create a new proxy.

        :param policy: The policy to execute for actions.
        :param device: The device to run on.
        :param input_keys: The names of the inputs to the policy.
        :param spaces: The spaces of the inputs and outputs. (default: None)
        """
        super().__init__(
            policy=policy,
            device=device,
            input_keys=input_keys,
            output_keys=("actions",),
            spaces=spaces,
        )

emote/utils/deprecated.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""
2+
3+
"""
4+
5+
import functools
6+
import warnings
7+
8+
from typing import Callable
9+
10+
11+
def deprecated(
12+
original_function: Callable = None,
13+
*,
14+
reason: str = None,
15+
max_warn_count: int = 10,
16+
version: str = None,
17+
) -> Callable:
18+
"""Function decorator to deprecate an annotated function. Can be used both as a
19+
bare decorator, or with parameters to customize the display of the
20+
message. Writes to logging.warn.
21+
22+
:param original_function: Function to decorate. Automatically passed.
23+
:param message: Message to show. Function name is automatically added.
24+
:param max_warn_count: How many times we will warn for the same function
25+
:returns: the wrapped function
26+
"""
27+
reason = f": {reason}" if reason else ""
28+
version = f" -- deprecated since version {version}" if version else ""
29+
30+
def _decorate(function):
31+
warn_count = 0
32+
33+
name = getattr(function, "__qualname__", function.__name__)
34+
message = f"Call to deprecated function '{name}'{reason}{version}."
35+
36+
@functools.wraps(function)
37+
def _wrapper(*args, **kwargs):
38+
nonlocal warn_count
39+
if warn_count < max_warn_count:
40+
warnings.warn(
41+
message,
42+
DeprecationWarning,
43+
stacklevel=2,
44+
)
45+
warn_count += 0
46+
47+
return function(*args, **kwargs)
48+
49+
return _wrapper
50+
51+
if original_function:
52+
return _decorate(original_function)
53+
54+
return _decorate

0 commit comments

Comments
 (0)