Skip to content

Commit af1b00a

Browse files
google-genai-bot authored and
copybara-github committed
fix: fix rewind to preserve initial session state
The rewind logic is updated to ensure that state keys set during session creation are not nullified when rewinding. Previously, any key not present in the state at the rewind point was removed. Now, only keys that have appeared in any event's state delta are considered for nullification during a rewind, preventing the removal of initial session state. Close #4933. PiperOrigin-RevId: 905322038
1 parent c65dd55 commit af1b00a

8 files changed

Lines changed: 65 additions & 345 deletions

File tree

src/google/adk/runners.py

Lines changed: 12 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -646,12 +646,6 @@ async def rewind_async(
646646
session_id=session_id,
647647
get_session_config=run_config.get_session_config,
648648
)
649-
if not rewind_before_invocation_id:
650-
# Guard against matching the synthetic initial-state event that is
651-
# appended by `create_session`; that event has an empty invocation_id by
652-
# design and is not a valid rewind target.
653-
raise ValueError('rewind_before_invocation_id must be non-empty.')
654-
655649
rewind_event_index = -1
656650
for i, event in enumerate(session.events):
657651
if event.invocation_id == rewind_before_invocation_id:
@@ -692,34 +686,16 @@ async def _compute_state_delta_for_rewind(
692686
self, session: Session, rewind_event_index: int
693687
) -> dict[str, Any]:
694688
"""Computes the state delta to reverse changes."""
695-
# State at the rewind point is reconstructed entirely from the event
696-
# stream. Session-scoped initial state from `create_session` is captured
697-
# as a synthetic event by `BaseSessionService._record_initial_state_event`,
698-
# so walking events naturally restores initial values even when a later
699-
# event overwrote them.
700689
state_at_rewind_point: dict[str, Any] = {}
701-
all_event_keys: set[str] = set()
702-
703-
for event in session.events[:rewind_event_index]:
704-
if not event.actions.state_delta:
705-
continue
706-
for k, v in event.actions.state_delta.items():
707-
if k.startswith('app:') or k.startswith('user:'):
708-
continue
709-
all_event_keys.add(k)
710-
if v is None:
711-
state_at_rewind_point.pop(k, None)
712-
else:
713-
state_at_rewind_point[k] = v
714-
715-
# Collect any other keys touched by events after the rewind point so we
716-
# know which keys were ever event-sourced.
717-
for event in session.events[rewind_event_index:]:
718-
if not event.actions.state_delta:
719-
continue
720-
for k in event.actions.state_delta:
721-
if not k.startswith('app:') and not k.startswith('user:'):
722-
all_event_keys.add(k)
690+
for i in range(rewind_event_index):
691+
if session.events[i].actions.state_delta:
692+
for k, v in session.events[i].actions.state_delta.items():
693+
if k.startswith('app:') or k.startswith('user:'):
694+
continue
695+
if v is None:
696+
state_at_rewind_point.pop(k, None)
697+
else:
698+
state_at_rewind_point[k] = v
723699

724700
current_state = session.state
725701
rewind_state_delta = {}
@@ -730,13 +706,12 @@ async def _compute_state_delta_for_rewind(
730706
rewind_state_delta[key] = value_at_rewind
731707

732708
# 2. Set keys to None in rewind_state_delta if they are in current_state
733-
# but not in state_at_rewind_point. Only nullify keys that were
734-
# introduced or modified through events; keys set outside the event
735-
# stream are preserved.
709+
# but not in state_at_rewind_point. These keys were added after the
710+
# rewind point and need to be removed.
736711
for key in current_state:
737712
if key.startswith('app:') or key.startswith('user:'):
738713
continue
739-
if key not in state_at_rewind_point and key in all_event_keys:
714+
if key not in state_at_rewind_point:
740715
rewind_state_delta[key] = None
741716

742717
return rewind_state_delta

src/google/adk/sessions/base_session_service.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,10 @@
1818
from typing import Any
1919
from typing import Optional
2020

21-
from google.adk.platform import time as platform_time
2221
from pydantic import BaseModel
2322
from pydantic import Field
2423

2524
from ..events.event import Event
26-
from ..events.event_actions import EventActions
2725
from .session import Session
2826
from .state import State
2927

@@ -162,36 +160,3 @@ def _update_session_state(self, session: Session, event: Event) -> None:
162160
return
163161
for key, value in event.actions.state_delta.items():
164162
session.state.update({key: value})
165-
166-
async def _record_initial_state_event(
167-
self, session: Session, state: Optional[dict[str, Any]]
168-
) -> None:
169-
"""Appends a synthetic event carrying the initial non-temp session state.
170-
171-
Subclasses call this from `create_session` so that initial state flows
172-
through `append_event` (the single state-merging path) and so that
173-
`rewind_async` can restore session-scoped initial values for keys later
174-
overwritten or introduced by subsequent events.
175-
176-
Args:
177-
session: The newly created session to attach the event to.
178-
state: The initial state dict supplied to `create_session`. Temp-prefixed
179-
keys are dropped because temp state is ephemeral and never persisted.
180-
"""
181-
if not state:
182-
return
183-
state_delta = {
184-
k: v for k, v in state.items() if not k.startswith(State.TEMP_PREFIX)
185-
}
186-
if not state_delta:
187-
return
188-
# Round to microseconds so the timestamp roundtrips exactly through
189-
# storage backends that persist timestamps as datetime (microsecond
190-
# precision) — keeps in-memory and reloaded events comparable.
191-
timestamp = round(platform_time.get_time(), 6)
192-
initial_event = Event(
193-
author='user',
194-
timestamp=timestamp,
195-
actions=EventActions(state_delta=dict(state_delta)),
196-
)
197-
await self.append_event(session=session, event=initial_event)

src/google/adk/sessions/database_session_service.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -417,13 +417,11 @@ async def create_session(
417417
state: Optional[dict[str, Any]] = None,
418418
session_id: Optional[str] = None,
419419
) -> Session:
420-
# 1. Ensure app/user state rows exist (append_event requires them) and
421-
# insert an empty session row.
422-
# 2. Build the in-memory session reflecting any pre-existing app/user
423-
# state.
424-
# 3. Apply the caller-supplied initial state through the synthetic event
425-
# in `_record_initial_state_event` so all state writes share a single
426-
# code path.
420+
# 1. Populate states.
421+
# 2. Build storage session object
422+
# 3. Add the object to the table
423+
# 4. Build the session object with generated id
424+
# 5. Return the session
427425
await self._prepare_tables()
428426
schema = self._get_schema_classes()
429427
async with self._rollback_on_exception_session() as sql_session:
@@ -434,7 +432,6 @@ async def create_session(
434432
f"Session with id {session_id} already exists."
435433
)
436434
# Get or create state rows, handling concurrent insert races.
437-
# `append_event` requires the app/user state rows to exist.
438435
storage_app_state = await _get_or_create_state(
439436
sql_session=sql_session,
440437
state_model=schema.StorageAppState,
@@ -448,6 +445,19 @@ async def create_session(
448445
defaults={"app_name": app_name, "user_id": user_id, "state": {}},
449446
)
450447

448+
# Extract state deltas
449+
state_deltas = _session_util.extract_state_delta(state)
450+
app_state_delta = state_deltas["app"]
451+
user_state_delta = state_deltas["user"]
452+
session_state = state_deltas["session"]
453+
454+
# Apply state delta
455+
if app_state_delta:
456+
storage_app_state.state = storage_app_state.state | app_state_delta
457+
if user_state_delta:
458+
storage_user_state.state = storage_user_state.state | user_state_delta
459+
460+
# Store the session
451461
now = datetime.fromtimestamp(platform_time.get_time(), tz=timezone.utc)
452462
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
453463
is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT
@@ -458,21 +468,20 @@ async def create_session(
458468
app_name=app_name,
459469
user_id=user_id,
460470
id=session_id,
461-
state={},
471+
state=session_state,
462472
create_time=now,
463473
update_time=now,
464474
)
465475
sql_session.add(storage_session)
466476
await sql_session.commit()
467477

478+
# Merge states for response
468479
merged_state = _merge_state(
469-
storage_app_state.state, storage_user_state.state, {}
480+
storage_app_state.state, storage_user_state.state, session_state
470481
)
471482
session = storage_session.to_session(
472483
state=merged_state, is_sqlite=is_sqlite
473484
)
474-
475-
await self._record_initial_state_event(session, state)
476485
return session
477486

478487
@override

src/google/adk/sessions/in_memory_session_service.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,12 @@ async def create_session(
8383
state: Optional[dict[str, Any]] = None,
8484
session_id: Optional[str] = None,
8585
) -> Session:
86-
# Initial state flows through `_record_initial_state_event` ->
87-
# `append_event` so the in-memory dicts and the event stream are written
88-
# exactly once. The deprecated `create_session_sync` keeps the legacy
89-
# direct-write path because it cannot await `append_event`.
90-
session = self._create_session_impl(
86+
return self._create_session_impl(
9187
app_name=app_name,
9288
user_id=user_id,
93-
state=None,
89+
state=state,
9490
session_id=session_id,
9591
)
96-
await self._record_initial_state_event(session, state)
97-
return session
9892

9993
def create_session_sync(
10094
self,

src/google/adk/sessions/sqlite_session_service.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,25 @@ async def create_session(
179179
f"Session with id {session_id} already exists."
180180
)
181181

182-
# Insert the session row with empty per-session state. Initial state
183-
# (including app:/user:-prefixed keys) is applied through the synthetic
184-
# event below so that all state writes go through `append_event`.
182+
# Extract state deltas
183+
state_deltas = _session_util.extract_state_delta(state)
184+
app_state_delta = state_deltas["app"]
185+
user_state_delta = state_deltas["user"]
186+
session_state = state_deltas["session"]
187+
188+
# Apply state delta and update/insert states atomically
189+
if app_state_delta:
190+
await self._upsert_app_state(db, app_name, app_state_delta, now)
191+
if user_state_delta:
192+
await self._upsert_user_state(
193+
db, app_name, user_id, user_state_delta, now
194+
)
195+
196+
# Fetch current state after upserts
197+
storage_app_state = await self._get_app_state(db, app_name)
198+
storage_user_state = await self._get_user_state(db, app_name, user_id)
199+
200+
# Store the session
185201
await db.execute(
186202
"""
187203
INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time)
@@ -191,19 +207,18 @@ async def create_session(
191207
app_name,
192208
user_id,
193209
session_id,
194-
json.dumps({}),
210+
json.dumps(session_state),
195211
now,
196212
now,
197213
),
198214
)
199215
await db.commit()
200216

201-
# Reflect already-persisted app/user state so subsequent appends start
202-
# from the correct merged view.
203-
storage_app_state = await self._get_app_state(db, app_name)
204-
storage_user_state = await self._get_user_state(db, app_name, user_id)
205-
merged_state = _merge_state(storage_app_state, storage_user_state, {})
206-
session = Session(
217+
# Merge states for response
218+
merged_state = _merge_state(
219+
storage_app_state, storage_user_state, session_state
220+
)
221+
return Session(
207222
app_name=app_name,
208223
user_id=user_id,
209224
id=session_id,
@@ -212,9 +227,6 @@ async def create_session(
212227
last_update_time=now,
213228
)
214229

215-
await self._record_initial_state_event(session, state)
216-
return session
217-
218230
@override
219231
async def get_session(
220232
self,

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,10 @@ async def create_session(
125125
"""
126126
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
127127

128-
# Initial state is persisted exclusively through the synthetic event
129-
# below (which is sent via `events.append`); avoid passing it as
130-
# `session_state` here so the same data is not written to the backend
131-
# twice.
132-
config = dict(kwargs)
128+
config = {'session_state': state} if state else {}
133129
if session_id:
134130
config['session_id'] = session_id
131+
config.update(kwargs)
135132
async with self._get_api_client() as api_client:
136133
api_response = await api_client.agent_engines.sessions.create(
137134
name=f'reasoningEngines/{reasoning_engine_id}',
@@ -146,11 +143,9 @@ async def create_session(
146143
app_name=app_name,
147144
user_id=user_id,
148145
id=session_id,
149-
state={},
146+
state=getattr(get_session_response, 'session_state', None) or {},
150147
last_update_time=get_session_response.update_time.timestamp(),
151148
)
152-
153-
await self._record_initial_state_event(session, state)
154149
return session
155150

156151
@override
@@ -218,21 +213,9 @@ async def get_session(
218213
# to discard events written milliseconds after the session resource was
219214
# updated. Clock skew between those writes can otherwise drop tool_result
220215
# events and permanently break the replayed conversation.
221-
#
222-
# Apply each event's state_delta as we go so callers see the same state
223-
# whether or not the backend mirrors it onto the session_state field
224-
# (e.g. Vertex stores initial state via the synthetic create_session
225-
# event rather than the session_state field).
226216
if events_iterator is not None:
227217
async for event in events_iterator:
228-
adk_event = _from_api_event(event)
229-
session.events.append(adk_event)
230-
if adk_event.actions and adk_event.actions.state_delta:
231-
for key, value in adk_event.actions.state_delta.items():
232-
if value is None:
233-
session.state.pop(key, None)
234-
else:
235-
session.state[key] = value
218+
session.events.append(_from_api_event(event))
236219

237220
if config:
238221
# Filter events based on num_recent_events.

0 commit comments

Comments
 (0)