Skip to content

Commit 784350d

Browse files
petrmarinec and copybara-github
authored and committed
fix: Scope Vertex RAG memory display names
Merge #5295

## Summary

Fixes #5294.

This PR tightens `VertexAiRagMemoryService` memory scoping by replacing the ambiguous dot-delimited RAG `display_name` format with an encoded v1 format for new uploads. Search results are now parsed into exact `app_name`, `user_id`, and `session_id` components before being accepted.

The change also keeps compatibility for old unambiguous legacy display names in the exact `app.user.session` form, while ignoring ambiguous legacy names that contain extra dot-delimited components.

## Why

The previous client-side filter used:

```python
context.source_display_name.startswith(f"{app_name}.{user_id}.")
```

That can collide for IDs such as `alice` and `alice.smith`, allowing memory stored under `demo.alice.smith.*` to pass a lookup for `demo` / `alice`.

## Tests

```text
PYTHONPATH=src python -m pytest tests/unittests/memory -q
42 passed, 2 warnings

python -m isort --check-only src/google/adk/memory/vertex_ai_rag_memory_service.py tests/unittests/memory/test_vertex_ai_rag_memory_service.py
passed

python -m pyink --check --config pyproject.toml src/google/adk/memory/vertex_ai_rag_memory_service.py tests/unittests/memory/test_vertex_ai_rag_memory_service.py
passed

git diff --check -- src/google/adk/memory/vertex_ai_rag_memory_service.py tests/unittests/memory/test_vertex_ai_rag_memory_service.py
passed
```

## Checklist

- [x] I have read the CONTRIBUTING.md document.
- [x] I have performed a self-review of my own code.
- [x] I have added tests that prove my fix is effective.
- [x] New and existing unit tests pass locally with my changes.

COPYBARA_INTEGRATE_REVIEW=#5295 from petrmarinec:fix/vertex-rag-memory-scope 64a8f5a
PiperOrigin-RevId: 905365796
1 parent f1532f9 commit 784350d

2 files changed

Lines changed: 184 additions & 5 deletions

File tree

src/google/adk/memory/vertex_ai_rag_memory_service.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
from __future__ import annotations
1717

18+
import base64
19+
import binascii
1820
from collections import OrderedDict
1921
import json
2022
import os
@@ -35,6 +37,58 @@
3537
from ..sessions.session import Session
3638

3739

40+
# Marker that starts every v1 display name. `_build_source_display_name`
# prepends it and `_parse_source_display_name` uses it to tell the encoded v1
# format apart from legacy dot-delimited names.
_SOURCE_DISPLAY_NAME_PREFIX = "adk-memory-v1."
41+
42+
43+
def _encode_source_display_name_part(value: str) -> str:
  """Encodes one display-name component as unpadded URL-safe base64.

  The URL-safe alphabet (letters, digits, "-" and "_") contains no ".", so an
  encoded component can never be confused with the "." separator used between
  components.

  Args:
    value: The raw component text (app name, user ID or session ID).

  Returns:
    The UTF-8 bytes of `value` in URL-safe base64 with trailing "=" padding
    removed. An empty `value` yields an empty string.
  """
  utf8_bytes = value.encode("utf-8")
  encoded_text = base64.b64encode(utf8_bytes, altchars=b"-_").decode("ascii")
  return encoded_text.rstrip("=")
49+
50+
51+
def _decode_source_display_name_part(value: str) -> str:
  """Decodes one unpadded URL-safe base64 display-name component.

  Only the canonical spelling of a component is accepted, i.e. exactly the
  text that unpadded URL-safe base64 encoding would produce for the decoded
  bytes. Inputs that use the standard-alphabet characters "+" or "/", carry
  their own "=" padding, or have stray trailing bits are rejected, so that
  each decoded value corresponds to exactly one encoded spelling.

  Args:
    value: The encoded component, without "=" padding.

  Returns:
    The decoded component text.

  Raises:
    binascii.Error: If `value` is not valid base64 or is not in canonical
      unpadded URL-safe form.
    UnicodeEncodeError: If `value` contains non-ASCII characters.
    UnicodeDecodeError: If the decoded bytes are not valid UTF-8.
  """
  # Restore the "=" padding that was stripped when the component was encoded.
  padded_value = value + "=" * (-len(value) % 4)
  raw_bytes = base64.b64decode(
      padded_value.encode("ascii"), altchars=b"-_", validate=True
  )
  # b64decode with altchars still tolerates "+" and "/", and some versions
  # ignore non-zero trailing bits; re-encoding and comparing rejects every
  # non-canonical spelling regardless of the interpreter's leniency.
  canonical = base64.urlsafe_b64encode(raw_bytes).decode("ascii").rstrip("=")
  if canonical != value:
    raise binascii.Error("Non-canonical base64 display name component.")
  return raw_bytes.decode("utf-8")
56+
57+
58+
def _build_source_display_name(
    app_name: str, user_id: str, session_id: str
) -> str:
  """Builds the v1 display name that identifies a session's memory file.

  Args:
    app_name: The application name of the session.
    user_id: The user ID of the session.
    session_id: The session ID.

  Returns:
    `_SOURCE_DISPLAY_NAME_PREFIX` followed by the three components, each
    passed through `_encode_source_display_name_part`, joined with ".".
  """
  encoded_parts = [
      _encode_source_display_name_part(component)
      for component in (app_name, user_id, session_id)
  ]
  return f"{_SOURCE_DISPLAY_NAME_PREFIX}{'.'.join(encoded_parts)}"
66+
67+
68+
def _parse_source_display_name(
    source_display_name: str,
) -> tuple[str, str, str] | None:
  """Splits a display name into its app name, user ID and session ID.

  Two formats are recognized:

  * v1: `_SOURCE_DISPLAY_NAME_PREFIX` followed by three "."-separated
    components, each decoded with `_decode_source_display_name_part`.
  * Legacy: exactly three raw "."-separated components.

  Args:
    source_display_name: The display name to parse.

  Returns:
    An `(app_name, user_id, session_id)` tuple, or None when the name does not
    have exactly three components or a v1 component fails to decode.
  """
  is_v1 = source_display_name.startswith(_SOURCE_DISPLAY_NAME_PREFIX)
  if is_v1:
    payload = source_display_name[len(_SOURCE_DISPLAY_NAME_PREFIX) :]
  else:
    payload = source_display_name

  fields = payload.split(".")
  # Legacy names were dot-delimited, so only the exact three-part form is
  # unambiguous; names built from dotted app/user/session IDs are
  # intentionally ignored. v1 names always have exactly three fields too.
  if len(fields) != 3:
    return None

  if not is_v1:
    return fields[0], fields[1], fields[2]

  try:
    app_name, user_id, session_id = (
        _decode_source_display_name_part(field) for field in fields
    )
  except (binascii.Error, UnicodeDecodeError, UnicodeEncodeError):
    return None
  return app_name, user_id, session_id
90+
91+
3892
class VertexAiRagMemoryService(BaseMemoryService):
3993
"""A memory service that uses Vertex AI RAG for storage and retrieval."""
4094

@@ -63,7 +117,7 @@ def __init__(
63117
)
64118

65119
@override
66-
async def add_session_to_memory(self, session: Session):
120+
async def add_session_to_memory(self, session: Session) -> None:
67121
with tempfile.NamedTemporaryFile(
68122
mode="w", delete=False, suffix=".txt"
69123
) as temp_file:
@@ -100,7 +154,9 @@ async def add_session_to_memory(self, session: Session):
100154
path=temp_file_path,
101155
# this is the temp workaround as upload file does not support
102156
# adding metadata, thus use display_name to store the session info.
103-
display_name=f"{session.app_name}.{session.user_id}.{session.id}",
157+
display_name=_build_source_display_name(
158+
session.app_name, session.user_id, session.id
159+
),
104160
)
105161

106162
os.remove(temp_file_path)
@@ -122,13 +178,19 @@ async def search_memory(
122178
)
123179

124180
memory_results = []
125-
session_events_map = OrderedDict()
181+
session_events_map: OrderedDict[str, list[list[Event]]] = OrderedDict()
126182
for context in response.contexts.contexts:
127183
# filter out context that is not related
128184
# TODO: Add server side filtering by app_name and user_id.
129-
if not context.source_display_name.startswith(f"{app_name}.{user_id}."):
185+
source_display_name = getattr(context, "source_display_name", "")
186+
if not isinstance(source_display_name, str):
187+
continue
188+
session_info = _parse_source_display_name(source_display_name)
189+
if not session_info:
190+
continue
191+
source_app_name, source_user_id, session_id = session_info
192+
if source_app_name != app_name or source_user_id != user_id:
130193
continue
131-
session_id = context.source_display_name.split(".")[-1]
132194
events = []
133195
if context.text:
134196
lines = context.text.split("\n")
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
from types import SimpleNamespace
17+
18+
from google.adk.events.event import Event
19+
from google.adk.memory.vertex_ai_rag_memory_service import _build_source_display_name
20+
from google.adk.memory.vertex_ai_rag_memory_service import _SOURCE_DISPLAY_NAME_PREFIX
21+
from google.adk.memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService
22+
from google.adk.sessions.session import Session
23+
from google.genai import types
24+
import pytest
25+
26+
27+
def _rag_context(source_display_name: str, text: str) -> SimpleNamespace:
  """Creates a stand-in for one RAG retrieval context.

  Args:
    source_display_name: Value exposed as the context's `source_display_name`.
    text: Message text placed in the single JSON-encoded event line.

  Returns:
    A namespace with `source_display_name` and a `text` attribute holding the
    JSON dump of a user-authored event with timestamp 1.
  """
  event_line = json.dumps({"author": "user", "timestamp": 1, "text": text})
  return SimpleNamespace(
      source_display_name=source_display_name, text=event_line
  )
32+
33+
34+
@pytest.mark.asyncio
async def test_search_memory_rejects_ambiguous_legacy_display_names(mocker):
  """Ensures dotted user IDs cannot match another user's legacy memory."""
  memory_service = VertexAiRagMemoryService(rag_corpus="unused")

  retrieved_contexts = [
      # Four-part legacy name: could belong to user "alice.smith".
      _rag_context(
          "demo.alice.smith.session_secret", "SECRET_FROM_ALICE_SMITH"
      ),
      # Encoded v1 name for demo / alice.
      _rag_context(
          _build_source_display_name("demo", "alice", "session_ok"),
          "NORMAL_ALICE_MEMORY",
      ),
      # Exact three-part legacy name for demo / alice.
      _rag_context("demo.alice.legacy_session", "LEGACY_ALICE_MEMORY"),
      # Three-part legacy name belonging to a different user.
      _rag_context("demo.bob.session_other", "BOB_MEMORY"),
  ]
  query_result = SimpleNamespace(
      contexts=SimpleNamespace(contexts=retrieved_contexts)
  )
  fake_rag = SimpleNamespace(
      retrieval_query=mocker.Mock(return_value=query_result)
  )
  mocker.patch("google.adk.dependencies.vertexai.rag", fake_rag)

  response = await memory_service.search_memory(
      app_name="demo", user_id="alice", query="secret"
  )

  texts = [memory.content.parts[0].text for memory in response.memories]
  assert texts == ["NORMAL_ALICE_MEMORY", "LEGACY_ALICE_MEMORY"]
71+
72+
73+
@pytest.mark.asyncio
async def test_add_and_search_memory_uses_unambiguous_display_names(mocker):
  """Checks a session whose IDs all contain dots is stored and found again."""
  memory_service = VertexAiRagMemoryService(rag_corpus="unused")
  upload_file = mocker.Mock()
  fake_rag = SimpleNamespace(upload_file=upload_file)
  mocker.patch("google.adk.dependencies.vertexai.rag", fake_rag)

  dotted_event = Event(
      id="event-1",
      author="user",
      timestamp=1,
      content=types.Content(parts=[types.Part(text="sensitive memory")]),
  )
  dotted_session = Session(
      app_name="demo.app",
      user_id="alice.smith",
      id="session.secret",
      last_update_time=1,
      events=[dotted_event],
  )
  await memory_service.add_session_to_memory(dotted_session)

  # The upload must use the prefixed format, not the raw dotted concatenation.
  display_name = upload_file.call_args.kwargs["display_name"]
  assert display_name.startswith(_SOURCE_DISPLAY_NAME_PREFIX)
  assert display_name != "demo.app.alice.smith.session.secret"

  # Feed the uploaded display name back through retrieval.
  stored_context = _rag_context(display_name, "sensitive memory")
  fake_rag.retrieval_query = mocker.Mock(
      return_value=SimpleNamespace(
          contexts=SimpleNamespace(contexts=[stored_context])
      )
  )

  response = await memory_service.search_memory(
      app_name="demo.app", user_id="alice.smith", query="sensitive"
  )

  found_texts = [memory.content.parts[0].text for memory in response.memories]
  assert found_texts == ["sensitive memory"]

0 commit comments

Comments
 (0)