@@ -662,3 +662,196 @@ async def test_mcp_toolset_outside_workflow_no_not_in_workflow_toolset():
662662 match = "not_in_workflow_toolset" ,
663663 ):
664664 await toolset .get_tools ()
665+
666+
667+ complex_activity_inputs_seen : dict [str , object ] = {}
668+
669+
670+ @activity .defn
671+ async def book_trip (origin : str , destination : str , passengers : int ) -> str :
672+ """Activity that formats multiple discrete arguments."""
673+ complex_activity_inputs_seen ["book_trip" ] = (origin , destination , passengers )
674+ return f"{ origin } ->{ destination } :{ passengers } "
675+
676+
677+ @activity .defn
678+ async def summarize_payload (
679+ name : str , metadata : dict [str , str | int | list [str ]]
680+ ) -> str :
681+ """Activity that formats compound map input."""
682+ complex_activity_inputs_seen ["summarize_payload" ] = (name , metadata )
683+ tags = metadata .get ("tags" , [])
684+ assert isinstance (tags , list )
685+ return f"{ name } :{ metadata ['count' ]} :{ metadata ['owner' ]} :" + "," .join (
686+ str (tag ) for tag in tags
687+ )
688+
689+
690+ class ComplexActivityMethodHolder :
691+ def __init__ (self , prefix : str ) -> None :
692+ self .prefix = prefix
693+
694+ @activity .defn
695+ async def annotate_trip (self , trip : str ) -> str :
696+ complex_activity_inputs_seen ["annotate_trip" ] = trip
697+ return f"{ self .prefix } :{ trip } "
698+
699+
700+ @workflow .defn
701+ class ComplexActivityInputAgent :
702+ @workflow .run
703+ async def run (self , prompt : str , model_name : str ) -> str :
704+ logger .info ("Workflow started." )
705+ method_holder = ComplexActivityMethodHolder ("method" )
706+
707+ agent = Agent (
708+ name = "complex_input_agent" ,
709+ model = TemporalModel (model_name ),
710+ tools = [
711+ temporalio .contrib .google_adk_agents .workflow .activity_tool (
712+ book_trip , start_to_close_timeout = timedelta (seconds = 60 )
713+ ),
714+ temporalio .contrib .google_adk_agents .workflow .activity_tool (
715+ summarize_payload , start_to_close_timeout = timedelta (seconds = 60 )
716+ ),
717+ temporalio .contrib .google_adk_agents .workflow .activity_tool (
718+ method_holder .annotate_trip ,
719+ start_to_close_timeout = timedelta (seconds = 60 ),
720+ ),
721+ ],
722+ )
723+
724+ runner = InMemoryRunner (
725+ agent = agent ,
726+ app_name = "complex_input_app" ,
727+ )
728+
729+ session = await runner .session_service .create_session (
730+ app_name = "complex_input_app" , user_id = "test"
731+ )
732+
733+ final_text = ""
734+ async with Aclosing (
735+ runner .run_async (
736+ user_id = "test" ,
737+ session_id = session .id ,
738+ new_message = types .Content (role = "user" , parts = [types .Part (text = prompt )]),
739+ )
740+ ) as agen :
741+ async for event in agen :
742+ logger .info (f"Event: { event } " )
743+ if event .content and event .content .parts :
744+ for part in event .content .parts :
745+ if part .text is not None :
746+ final_text = part .text
747+
748+ return final_text
749+
750+
751+ class ComplexActivityInputModel (TestModel ):
752+ def responses (self ) -> list [LlmResponse ]:
753+ return [
754+ LlmResponse (
755+ content = Content (
756+ role = "model" ,
757+ parts = [
758+ Part (
759+ function_call = FunctionCall (
760+ name = "book_trip" ,
761+ args = {
762+ "origin" : "SFO" ,
763+ "destination" : "LAX" ,
764+ "passengers" : 3 ,
765+ },
766+ )
767+ )
768+ ],
769+ )
770+ ),
771+ LlmResponse (
772+ content = Content (
773+ role = "model" ,
774+ parts = [
775+ Part (
776+ function_call = FunctionCall (
777+ name = "summarize_payload" ,
778+ args = {
779+ "name" : "fixture" ,
780+ "metadata" : {
781+ "count" : 2 ,
782+ "owner" : "team-a" ,
783+ "tags" : ["alpha" , "beta" ],
784+ },
785+ },
786+ )
787+ )
788+ ],
789+ )
790+ ),
791+ LlmResponse (
792+ content = Content (
793+ role = "model" ,
794+ parts = [
795+ Part (
796+ function_call = FunctionCall (
797+ name = "annotate_trip" ,
798+ args = {"trip" : "SFO->LAX:3" },
799+ )
800+ )
801+ ],
802+ )
803+ ),
804+ LlmResponse (
805+ content = Content (
806+ role = "model" ,
807+ parts = [Part (text = "completed complex input tool calls" )],
808+ )
809+ ),
810+ ]
811+
812+ @classmethod
813+ def supported_models (cls ) -> list [str ]:
814+ return ["complex_activity_input_model" ]
815+
816+
817+ @pytest .mark .asyncio
818+ async def test_activity_tool_supports_complex_inputs_via_adk (client : Client ):
819+ new_config = client .config ()
820+ new_config ["plugins" ] = [GoogleAdkPlugin ()]
821+ client = Client (** new_config )
822+ complex_activity_inputs_seen .clear ()
823+ method_holder = ComplexActivityMethodHolder ("method" )
824+
825+ async with Worker (
826+ client ,
827+ task_queue = "adk-task-queue-complex-inputs" ,
828+ activities = [
829+ book_trip ,
830+ summarize_payload ,
831+ method_holder .annotate_trip ,
832+ ],
833+ workflows = [ComplexActivityInputAgent ],
834+ max_cached_workflows = 0 ,
835+ ):
836+ LLMRegistry .register (ComplexActivityInputModel )
837+
838+ handle = await client .start_workflow (
839+ ComplexActivityInputAgent .run ,
840+ args = [
841+ "Run every registered tool using structured inputs." ,
842+ "complex_activity_input_model" ,
843+ ],
844+ id = f"complex-activity-input-workflow-{ uuid .uuid4 ()} " ,
845+ task_queue = "adk-task-queue-complex-inputs" ,
846+ execution_timeout = timedelta (seconds = 60 ),
847+ )
848+ result = await handle .result ()
849+ assert result == "completed complex input tool calls"
850+ assert complex_activity_inputs_seen == {
851+ "book_trip" : ("SFO" , "LAX" , 3 ),
852+ "summarize_payload" : (
853+ "fixture" ,
854+ {"count" : 2 , "owner" : "team-a" , "tags" : ["alpha" , "beta" ]},
855+ ),
856+ "annotate_trip" : "SFO->LAX:3" ,
857+ }
0 commit comments