1414from emote .mixins .logging import LoggingMixin
1515from emote .proxies import AgentProxy
1616from emote .typing import AgentId , DictObservation , DictResponse , EpisodeState
17+ from emote .utils .deprecated import deprecated
1718from emote .utils .gamma_matrix import discount , make_gamma_matrix , split_rollouts
1819from emote .utils .spaces import MDPSpace
1920
@@ -348,110 +349,45 @@ def __call__(self, *args, **kwargs):
348349 def input_names (self ):
349350 return self ._inner .input_names
350351
351-
352- class FeatureAgentProxy :
353- """An agent proxy for basic MLPs.
354-
355- This AgentProxy assumes that the observations will contain a single flat array of features.
356- """
357-
358- def __init__ (self , policy : nn .Module , device : torch .device , input_key : str = "obs" ):
359- """Create a new proxy.
360-
361- :param policy: The policy to execute for actions.
362- :param device: The device to run on.
363- :param input_key: The name of the features. (default: "obs")
364- """
365- self .policy = policy
366- self ._end_states = [EpisodeState .TERMINAL , EpisodeState .INTERRUPTED ]
367- self .device = device
368-
369- self ._input_key = input_key
370-
371- def __call__ (
372- self ,
373- observations : Dict [AgentId , DictObservation ],
374- ) -> Dict [AgentId , DictResponse ]:
375- """Runs the policy and returns the actions."""
376- # The network takes observations of size batch x obs for each observation space.
377- assert len (observations ) > 0 , "Observations must not be empty."
378- active_agents = [
379- agent_id
380- for agent_id , obs in observations .items ()
381- if obs .episode_state not in self ._end_states
382- ]
383- tensor_obs = torch .tensor (
384- np .array (
385- [
386- observations [agent_id ].array_data [self ._input_key ]
387- for agent_id in active_agents
388- ]
389- )
390- ).to (self .device )
391-
392- actions = self .policy (tensor_obs )[0 ].detach ().cpu ().numpy ()
393-
394- return {
395- agent_id : DictResponse (list_data = {"actions" : actions [i ]}, scalar_data = {})
396- for i , agent_id in enumerate (active_agents )
397- }
398-
399352 @property
400- def input_names (self ):
401- return (self ._input_key ,)
402-
403-
404- class VisionAgentProxy :
405- """This AgentProxy assumes that the observations will contain image observations 'obs'"""
353+ def output_names (self ):
354+ return self ._inner .output_names
406355
407- def __init__ (self , policy : nn .Module , device : torch .device ):
408- self .policy = policy
409- self ._end_states = [EpisodeState .TERMINAL , EpisodeState .INTERRUPTED ]
410- self .device = device
356+ @property
357+ def policy (self ):
358+ return self ._inner .policy
411359
412- def __call__ (
413- self , observations : Dict [AgentId , DictObservation ]
414- ) -> Dict [AgentId , DictResponse ]:
415- """Runs the policy and returns the actions."""
416- # The network takes observations of size batch x obs for each observation space.
417- assert len (observations ) > 0 , "Observations must not be empty."
418- active_agents = [
419- agent_id
420- for agent_id , obs in observations .items ()
421- if obs .episode_state not in self ._end_states
422- ]
423- np_obs = np .array (
424- [observations [agent_id ].array_data ["obs" ] for agent_id in active_agents ]
425- )
426- tensor_obs = torch .tensor (np_obs ).to (self .device )
427- actions = self .policy (tensor_obs )[0 ].detach ().cpu ().numpy ()
428- return {
429- agent_id : DictResponse (list_data = {"actions" : actions [i ]}, scalar_data = {})
430- for i , agent_id in enumerate (active_agents )
431- }
432360
361+ class GenericAgentProxy (AgentProxy ):
362+ """Observations are dicts that contain multiple input and output keys.
433363
434- class MultiKeyAgentProxy :
435- """Observations are dicts that contain multiple input keys (e.g. both "features" and "images")"""
364+ For example, we might have a policy that takes in both "obs" and
365+ "goal" and outputs "actions". In order to be able to properly
366+ invoke the network it is the responsibility of this proxy to
367+ collate the inputs and decollate the outputs per agent.
368+ """
436369
def __init__(
    self,
    policy: nn.Module,
    device: torch.device,
    input_keys: tuple,
    output_keys: tuple,
    spaces: MDPSpace | None = None,
):
    """Set up the proxy around ``policy``.

    :param policy: The policy network to invoke.
    :param device: The torch device to run on; stored as ``device``.
    :param input_keys: Names of the policy inputs; stored as ``input_keys``.
    :param output_keys: Names of the policy outputs; stored as ``output_keys``.
    :param spaces: Optional spaces of the inputs and outputs. (default: None)
    """
    # Episode states used by __call__ to decide which agents are skipped.
    self._end_states = [EpisodeState.TERMINAL, EpisodeState.INTERRUPTED]

    # Private state, exposed read-only through properties where needed.
    self._policy = policy
    self._spaces = spaces

    # Public configuration.
    self.device = device
    self.input_keys = input_keys
    self.output_keys = output_keys
456392
457393 def __call__ (
@@ -467,7 +403,7 @@ def __call__(
467403 if obs .episode_state not in self ._end_states
468404 ]
469405
470- dict_tensor_obs = {}
406+ tensor_obs_list = [ None ] * len ( self . input_keys )
471407 for input_key in self .input_keys :
472408 np_obs = np .array (
473409 [
@@ -482,15 +418,98 @@ def __call__(
482418 np_obs = np .reshape (np_obs , shape )
483419
484420 tensor_obs = torch .tensor (np_obs ).to (self .device )
485- dict_tensor_obs [input_key ] = tensor_obs
421+ index = self .input_keys .index (input_key )
422+ tensor_obs_list [index ] = tensor_obs
486423
487- actions = self .policy (** dict_tensor_obs )[0 ].detach ().cpu ().numpy ()
424+ outputs : tuple [any , ...] = self ._policy (* tensor_obs_list )
425+ # we remove element 1 as we don't need the logprobs here
426+ outputs = outputs [0 :1 ] + outputs [2 :]
488427
489- return {
490- agent_id : DictResponse ( list_data = { "actions" : actions [i ]}, scalar_data = {} )
491- for i , agent_id in enumerate (active_agents )
428+ outputs = {
429+ key : outputs [i ]. detach (). cpu (). numpy ( )
430+ for i , key in enumerate (self . output_keys )
492431 }
493432
433+ agent_data = [
434+ (agent_id , DictResponse (list_data = {}, scalar_data = {}))
435+ for agent_id in active_agents
436+ ]
437+
438+ for i , (_ , response ) in enumerate (agent_data ):
439+ for k , data in outputs .items ():
440+ response .list_data [k ] = data [i ]
441+
442+ return dict (agent_data )
443+
@property
def input_names(self):
    """Return the input key names this proxy was configured with."""
    return self.input_keys
447+
@property
def output_names(self):
    """Return the output key names this proxy was configured with."""
    return self.output_keys
451+
@property
def policy(self):
    """Return the wrapped policy object held in ``_policy``."""
    return self._policy
455+
456+
class FeatureAgentProxy(GenericAgentProxy):
    """Deprecated single-input agent proxy for basic MLPs.

    This proxy assumes that each observation contains a single flat array
    of features under one key, and names the policy output "actions".
    """

    @deprecated(reason="Use GenericAgentProxy instead", version="23.1.0")
    def __init__(self, policy: nn.Module, device: torch.device, input_key: str = "obs"):
        """Create a new proxy.

        :param policy: The policy to execute for actions.
        :param device: The device to run on.
        :param input_key: The name of the features. (default: "obs")
        """
        # The generic proxy works on tuples of keys, so wrap the single
        # input key and the fixed output name accordingly.
        feature_keys = (input_key,)
        action_keys = ("actions",)
        super().__init__(policy, device, feature_keys, action_keys)
478+
479+
class VisionAgentProxy(FeatureAgentProxy):
    """Deprecated proxy that assumes image observations under the key 'obs'.

    The policy output is exposed under the key "actions".
    """

    @deprecated(reason="Use GenericAgentProxy instead", version="23.1.0")
    def __init__(self, policy: nn.Module, device: torch.device):
        """Create a new proxy.

        :param policy: The policy to execute for actions.
        :param device: The device to run on.
        """
        # FeatureAgentProxy.__init__ is itself wrapped in @deprecated, so
        # delegating through super() would run the deprecation wrapper a
        # second time (and on behalf of the wrong class) for one
        # construction. Initialise through GenericAgentProxy directly with
        # the same arguments FeatureAgentProxy would have forwarded.
        GenericAgentProxy.__init__(
            self,
            policy=policy,
            device=device,
            input_keys=("obs",),
            output_keys=("actions",),
        )
486+
487+
class MultiKeyAgentProxy(GenericAgentProxy):
    """Handles multiple input keys.

    Observations are dicts that contain multiple input keys (e.g. both
    "features" and "images"). The policy output is named "actions".
    """

    @deprecated(reason="Use GenericAgentProxy instead", version="23.1.0")
    def __init__(
        self,
        policy: nn.Module,
        device: torch.device,
        input_keys: tuple,
        spaces: MDPSpace | None = None,
    ):
        """Create a new proxy.

        :param policy: The policy to execute for actions.
        :param device: The device to run on.
        :param input_keys: The names of the inputs.
        :param spaces: Optional spaces, forwarded unchanged to
            GenericAgentProxy. (default: None)
        """
        super().__init__(
            policy=policy,
            device=device,
            input_keys=input_keys,
            # This legacy proxy always exposes exactly one output.
            output_keys=("actions",),
            spaces=spaces,
        )
0 commit comments