@@ -646,12 +646,6 @@ async def rewind_async(
646646 session_id = session_id ,
647647 get_session_config = run_config .get_session_config ,
648648 )
649- if not rewind_before_invocation_id :
650- # Guard against matching the synthetic initial-state event that is
651- # appended by `create_session`; that event has an empty invocation_id by
652- # design and is not a valid rewind target.
653- raise ValueError ('rewind_before_invocation_id must be non-empty.' )
654-
655649 rewind_event_index = - 1
656650 for i , event in enumerate (session .events ):
657651 if event .invocation_id == rewind_before_invocation_id :
@@ -692,34 +686,16 @@ async def _compute_state_delta_for_rewind(
692686 self , session : Session , rewind_event_index : int
693687 ) -> dict [str , Any ]:
694688 """Computes the state delta to reverse changes."""
695- # State at the rewind point is reconstructed entirely from the event
696- # stream. Session-scoped initial state from `create_session` is captured
697- # as a synthetic event by `BaseSessionService._record_initial_state_event`,
698- # so walking events naturally restores initial values even when a later
699- # event overwrote them.
700689 state_at_rewind_point : dict [str , Any ] = {}
701- all_event_keys : set [str ] = set ()
702-
703- for event in session .events [:rewind_event_index ]:
704- if not event .actions .state_delta :
705- continue
706- for k , v in event .actions .state_delta .items ():
707- if k .startswith ('app:' ) or k .startswith ('user:' ):
708- continue
709- all_event_keys .add (k )
710- if v is None :
711- state_at_rewind_point .pop (k , None )
712- else :
713- state_at_rewind_point [k ] = v
714-
715- # Collect any other keys touched by events after the rewind point so we
716- # know which keys were ever event-sourced.
717- for event in session .events [rewind_event_index :]:
718- if not event .actions .state_delta :
719- continue
720- for k in event .actions .state_delta :
721- if not k .startswith ('app:' ) and not k .startswith ('user:' ):
722- all_event_keys .add (k )
690+ for i in range (rewind_event_index ):
691+ if session .events [i ].actions .state_delta :
692+ for k , v in session .events [i ].actions .state_delta .items ():
693+ if k .startswith ('app:' ) or k .startswith ('user:' ):
694+ continue
695+ if v is None :
696+ state_at_rewind_point .pop (k , None )
697+ else :
698+ state_at_rewind_point [k ] = v
723699
724700 current_state = session .state
725701 rewind_state_delta = {}
@@ -730,13 +706,12 @@ async def _compute_state_delta_for_rewind(
730706 rewind_state_delta [key ] = value_at_rewind
731707
732708 # 2. Set keys to None in rewind_state_delta if they are in current_state
733- # but not in state_at_rewind_point. Only nullify keys that were
734- # introduced or modified through events; keys set outside the event
735- # stream are preserved.
709+ # but not in state_at_rewind_point. These keys were added after the
710+ # rewind point and need to be removed.
736711 for key in current_state :
737712 if key .startswith ('app:' ) or key .startswith ('user:' ):
738713 continue
739- if key not in state_at_rewind_point and key in all_event_keys :
714+ if key not in state_at_rewind_point :
740715 rewind_state_delta [key ] = None
741716
742717 return rewind_state_delta
0 commit comments