@@ -77,18 +77,6 @@ def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None):
7777 self .total_rewards = defaultdict (float )
7878 return super ().reset (seed , options )
7979
def step(self, action: ActionType) -> None:
    """Advance the underlying environment by one step, then annotate infos.

    After delegating to the parent ``step``, every agent that has
    terminated gets ``"winners"`` and ``"losers"`` entries added to its
    info dict so downstream code can read the match outcome.
    """
    super().step(action)

    winners = None
    losers = None
    for agent, terminated in self.terminations.items():
        if not terminated:
            continue
        if winners is None:
            # Outcome is identical for everyone — compute it only once,
            # on the first terminated agent encountered.
            winners = self.get_winners()
            losers = [p for p in self.agents if p not in winners]
        self.infos[agent]["winners"] = winners
        self.infos[agent]["losers"] = losers
9280 def get_winners (self ):
9381 max_reward = max (self .total_rewards .values ())
9482
@@ -101,11 +89,21 @@ def get_winners(self):
10189
def last(self, observe: bool = True):
    """Return observation, cumulative reward, terminated, truncated, info
    for the current agent (specified by ``self.agent_selection``).

    Side effects: folds the current agent's ``_cumulative_rewards`` into
    ``total_rewards``, and — once agents terminate — writes ``"winners"``
    and ``"losers"`` into each terminated agent's info dict.
    """
    # Accumulate only for the currently selected agent. NOTE(review): this
    # may miss the final reward credited to another agent — confirm against
    # the stepping logic.
    current = self.agent_selection
    self.total_rewards[current] += self._cumulative_rewards[current]

    winners = None
    losers = None
    # Use a distinct loop variable so we do not shadow the selected-agent
    # binding above (the original rebound ``agent`` here).
    for agent in self.terminations:
        if self.terminations[agent]:
            if winners is None:
                # Outcome is the same for all agents — compute once.
                winners = self.get_winners()
                losers = [p for p in self.agents if p not in winners]
            self.infos[agent]["winners"] = winners
            self.infos[agent]["losers"] = losers

    return super().last(observe)
110108
111109
0 commit comments