Skip to content

Commit 1bd9641

Browse files
authored
Fix Google ADK activity tool argument dispatch (#1421)
* Fix Google ADK activity tool argument dispatch * Format Google ADK activity tool tests
1 parent 6d3351a commit 1bd9641

2 files changed

Lines changed: 202 additions & 1 deletion

File tree

temporalio/contrib/google_adk_agents/workflow.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,15 @@ async def wrapper(*args: Any, **kw: Any):
3838
else:
3939
return result
4040

41-
return await workflow.execute_activity(activity_def, *activity_args, **options)
41+
if not activity_args:
42+
return await workflow.execute_activity(activity_def, **options)
43+
if len(activity_args) == 1:
44+
return await workflow.execute_activity(
45+
activity_def, activity_args[0], **options
46+
)
47+
return await workflow.execute_activity(
48+
activity_def, args=activity_args, **options
49+
)
4250

4351
# Copy metadata
4452
wrapper.__name__ = activity_def.__name__

tests/contrib/google_adk_agents/test_google_adk_agents.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,3 +662,196 @@ async def test_mcp_toolset_outside_workflow_no_not_in_workflow_toolset():
662662
match="not_in_workflow_toolset",
663663
):
664664
await toolset.get_tools()
665+
666+
667+
complex_activity_inputs_seen: dict[str, object] = {}
668+
669+
670+
@activity.defn
671+
async def book_trip(origin: str, destination: str, passengers: int) -> str:
672+
"""Activity that formats multiple discrete arguments."""
673+
complex_activity_inputs_seen["book_trip"] = (origin, destination, passengers)
674+
return f"{origin}->{destination}:{passengers}"
675+
676+
677+
@activity.defn
678+
async def summarize_payload(
679+
name: str, metadata: dict[str, str | int | list[str]]
680+
) -> str:
681+
"""Activity that formats compound map input."""
682+
complex_activity_inputs_seen["summarize_payload"] = (name, metadata)
683+
tags = metadata.get("tags", [])
684+
assert isinstance(tags, list)
685+
return f"{name}:{metadata['count']}:{metadata['owner']}:" + ",".join(
686+
str(tag) for tag in tags
687+
)
688+
689+
690+
class ComplexActivityMethodHolder:
691+
def __init__(self, prefix: str) -> None:
692+
self.prefix = prefix
693+
694+
@activity.defn
695+
async def annotate_trip(self, trip: str) -> str:
696+
complex_activity_inputs_seen["annotate_trip"] = trip
697+
return f"{self.prefix}:{trip}"
698+
699+
700+
@workflow.defn
701+
class ComplexActivityInputAgent:
702+
@workflow.run
703+
async def run(self, prompt: str, model_name: str) -> str:
704+
logger.info("Workflow started.")
705+
method_holder = ComplexActivityMethodHolder("method")
706+
707+
agent = Agent(
708+
name="complex_input_agent",
709+
model=TemporalModel(model_name),
710+
tools=[
711+
temporalio.contrib.google_adk_agents.workflow.activity_tool(
712+
book_trip, start_to_close_timeout=timedelta(seconds=60)
713+
),
714+
temporalio.contrib.google_adk_agents.workflow.activity_tool(
715+
summarize_payload, start_to_close_timeout=timedelta(seconds=60)
716+
),
717+
temporalio.contrib.google_adk_agents.workflow.activity_tool(
718+
method_holder.annotate_trip,
719+
start_to_close_timeout=timedelta(seconds=60),
720+
),
721+
],
722+
)
723+
724+
runner = InMemoryRunner(
725+
agent=agent,
726+
app_name="complex_input_app",
727+
)
728+
729+
session = await runner.session_service.create_session(
730+
app_name="complex_input_app", user_id="test"
731+
)
732+
733+
final_text = ""
734+
async with Aclosing(
735+
runner.run_async(
736+
user_id="test",
737+
session_id=session.id,
738+
new_message=types.Content(role="user", parts=[types.Part(text=prompt)]),
739+
)
740+
) as agen:
741+
async for event in agen:
742+
logger.info(f"Event: {event}")
743+
if event.content and event.content.parts:
744+
for part in event.content.parts:
745+
if part.text is not None:
746+
final_text = part.text
747+
748+
return final_text
749+
750+
751+
class ComplexActivityInputModel(TestModel):
752+
def responses(self) -> list[LlmResponse]:
753+
return [
754+
LlmResponse(
755+
content=Content(
756+
role="model",
757+
parts=[
758+
Part(
759+
function_call=FunctionCall(
760+
name="book_trip",
761+
args={
762+
"origin": "SFO",
763+
"destination": "LAX",
764+
"passengers": 3,
765+
},
766+
)
767+
)
768+
],
769+
)
770+
),
771+
LlmResponse(
772+
content=Content(
773+
role="model",
774+
parts=[
775+
Part(
776+
function_call=FunctionCall(
777+
name="summarize_payload",
778+
args={
779+
"name": "fixture",
780+
"metadata": {
781+
"count": 2,
782+
"owner": "team-a",
783+
"tags": ["alpha", "beta"],
784+
},
785+
},
786+
)
787+
)
788+
],
789+
)
790+
),
791+
LlmResponse(
792+
content=Content(
793+
role="model",
794+
parts=[
795+
Part(
796+
function_call=FunctionCall(
797+
name="annotate_trip",
798+
args={"trip": "SFO->LAX:3"},
799+
)
800+
)
801+
],
802+
)
803+
),
804+
LlmResponse(
805+
content=Content(
806+
role="model",
807+
parts=[Part(text="completed complex input tool calls")],
808+
)
809+
),
810+
]
811+
812+
@classmethod
813+
def supported_models(cls) -> list[str]:
814+
return ["complex_activity_input_model"]
815+
816+
817+
@pytest.mark.asyncio
818+
async def test_activity_tool_supports_complex_inputs_via_adk(client: Client):
819+
new_config = client.config()
820+
new_config["plugins"] = [GoogleAdkPlugin()]
821+
client = Client(**new_config)
822+
complex_activity_inputs_seen.clear()
823+
method_holder = ComplexActivityMethodHolder("method")
824+
825+
async with Worker(
826+
client,
827+
task_queue="adk-task-queue-complex-inputs",
828+
activities=[
829+
book_trip,
830+
summarize_payload,
831+
method_holder.annotate_trip,
832+
],
833+
workflows=[ComplexActivityInputAgent],
834+
max_cached_workflows=0,
835+
):
836+
LLMRegistry.register(ComplexActivityInputModel)
837+
838+
handle = await client.start_workflow(
839+
ComplexActivityInputAgent.run,
840+
args=[
841+
"Run every registered tool using structured inputs.",
842+
"complex_activity_input_model",
843+
],
844+
id=f"complex-activity-input-workflow-{uuid.uuid4()}",
845+
task_queue="adk-task-queue-complex-inputs",
846+
execution_timeout=timedelta(seconds=60),
847+
)
848+
result = await handle.result()
849+
assert result == "completed complex input tool calls"
850+
assert complex_activity_inputs_seen == {
851+
"book_trip": ("SFO", "LAX", 3),
852+
"summarize_payload": (
853+
"fixture",
854+
{"count": 2, "owner": "team-a", "tags": ["alpha", "beta"]},
855+
),
856+
"annotate_trip": "SFO->LAX:3",
857+
}

0 commit comments

Comments
 (0)