99import KratosMultiphysics .CoSimulationApplication .co_simulation_tools as cs_tools
1010import KratosMultiphysics .CoSimulationApplication .colors as colors
1111
12- # other imports
12+ # Other imports
1313from KratosMultiphysics .CoSimulationApplication .utilities import data_communicator_utilities
1414from time import time
15+ from dataclasses import dataclass
16+ from typing import Union , Tuple
17+
18+ @dataclass
19+ class SolverData :
20+ model_part_origin : KM .ModelPart
21+ model_part_destination : KM .ModelPart
22+ model_part_origin_name : str
23+ model_part_destination_name : str
24+ variable_origin : Union [KM .DoubleVariable , KM .Array1DVariable3 ]
25+ variable_destination : Union [KM .DoubleVariable , KM .Array1DVariable3 ]
26+ mapper_flags : KM .Flags
27+ identifier_origin : str
28+ identifier_destination : str
29+ identifier_tuple : Tuple [str , str ]
30+ inverse_identifier_tuple : Tuple [str , str ]
1531
1632def Create (* args ):
1733 return KratosMappingDataTransferOperator (* args )
18-
1934class KratosMappingDataTransferOperator (CoSimulationDataTransferOperator ):
2035 """DataTransferOperator that maps values from one interface (ModelPart) to another.
2136 The mappers of the Kratos-MappingApplication are used
@@ -37,48 +52,34 @@ def __init__(self, settings, parent_coupled_solver_data_communicator):
3752 if not settings .Has ("mapper_settings" ):
3853 raise Exception ('No "mapper_settings" provided!' )
3954 super ().__init__ (settings , parent_coupled_solver_data_communicator )
40- self .__mappers = {}
55+ self ._mappers = {}
4156
4257 def _ExecuteTransferData (self , from_solver_data , to_solver_data , transfer_options ):
43- model_part_origin_name = from_solver_data .model_part_name
44- variable_origin = from_solver_data .variable
45- identifier_origin = from_solver_data .solver_name + "." + model_part_origin_name
46-
47- model_part_destination_name = to_solver_data .model_part_name
48- variable_destination = to_solver_data .variable
49- identifier_destination = to_solver_data .solver_name + "." + model_part_destination_name
50-
51- mapper_flags = self .__GetMapperFlags (transfer_options , from_solver_data , to_solver_data )
58+ solver_data = self ._PrepareSolverData (from_solver_data , to_solver_data , transfer_options )
5259
53- identifier_tuple = (identifier_origin , identifier_destination )
54- inverse_identifier_tuple = (identifier_destination , identifier_origin )
55-
56- if identifier_tuple in self .__mappers :
57- self .__mappers [identifier_tuple ].Map (variable_origin , variable_destination , mapper_flags )
58- elif inverse_identifier_tuple in self .__mappers :
59- self .__mappers [inverse_identifier_tuple ].InverseMap (variable_destination , variable_origin , mapper_flags )
60+ if solver_data .identifier_tuple in self ._mappers :
61+ self ._mappers [solver_data .identifier_tuple ].Map (solver_data .variable_origin , solver_data .variable_destination , solver_data .mapper_flags )
62+ elif solver_data .inverse_identifier_tuple in self ._mappers :
63+ self ._mappers [solver_data .inverse_identifier_tuple ].InverseMap (solver_data .variable_destination , solver_data .variable_origin , solver_data .mapper_flags )
6064 else :
61- model_part_origin = self .__GetModelPartFromInterfaceData (from_solver_data )
62- model_part_destination = self .__GetModelPartFromInterfaceData (to_solver_data )
65+ model_part_origin = self ._GetModelPartFromInterfaceData (from_solver_data )
66+ model_part_destination = self ._GetModelPartFromInterfaceData (to_solver_data )
6367
64- if model_part_origin .IsDistributed () or model_part_destination .IsDistributed ():
65- mapper_create_fct = python_mapper_factory .CreateMPIMapper
66- else :
67- mapper_create_fct = python_mapper_factory .CreateMapper
68+ mapper_create_fct = self ._DefineMapperFunction (model_part_origin , model_part_destination )
6869
6970 if self .echo_level > 0 :
7071 info_msg = "Creating Mapper:\n "
71- info_msg += ' Origin: ModelPart "{}" of solver "{}"\n ' .format (model_part_origin_name , from_solver_data .solver_name )
72- info_msg += ' Destination: ModelPart "{}" of solver "{}"' .format (model_part_destination_name , to_solver_data .solver_name )
72+ info_msg += ' Origin: ModelPart "{}" of solver "{}"\n ' .format (solver_data . model_part_origin_name , from_solver_data .solver_name )
73+ info_msg += ' Destination: ModelPart "{}" of solver "{}"' .format (solver_data . model_part_destination_name , to_solver_data .solver_name )
7374
7475 cs_tools .cs_print_info (colors .bold (self ._ClassName ()), info_msg )
7576
7677 mapper_creation_start_time = time ()
77- self .__mappers [ identifier_tuple ] = mapper_create_fct (model_part_origin , model_part_destination , self .settings ["mapper_settings" ].Clone ()) # Clone is necessary because the settings are validated and defaults assigned, which could influence the creation of other mappers
78+ self ._mappers [ solver_data . identifier_tuple ] = mapper_create_fct (model_part_origin , model_part_destination , self .settings ["mapper_settings" ].Clone ()) # Clone is necessary because the settings are validated and defaults assigned, which could influence the creation of other mappers
7879
7980 if self .echo_level > 2 :
8081 cs_tools .cs_print_info (colors .bold (self ._ClassName ()), "Creating Mapper took: {0:.{1}f} [s]" .format (time ()- mapper_creation_start_time ,2 ))
81- self .__mappers [ identifier_tuple ].Map (variable_origin , variable_destination , mapper_flags )
82+ self ._mappers [ solver_data . identifier_tuple ].Map (solver_data . variable_origin , solver_data . variable_destination , solver_data . mapper_flags )
8283
8384 def _Check (self , from_solver_data , to_solver_data ):
8485 def CheckData (data_to_check ):
@@ -102,19 +103,74 @@ def _GetDefaultParameters(cls):
102103 def _GetListAvailableTransferOptions (cls ):
103104 return cls .__mapper_flags_dict .keys ()
104105
105- def __GetMapperFlags (self , transfer_options , from_solver_data , to_solver_data ):
106- mapper_flags = KM .Flags ()
107- for flag_name in transfer_options .GetStringArray ():
108- mapper_flags |= self .__mapper_flags_dict [flag_name ]
109- if from_solver_data .location == "node_non_historical" :
110- mapper_flags |= KM .Mapper .FROM_NON_HISTORICAL
111- if to_solver_data .location == "node_non_historical" :
112- mapper_flags |= KM .Mapper .TO_NON_HISTORICAL
106+ def _DefineMapperFunction (self , model_part_origin , model_part_destination ):
107+ """
108+ Define the mapper function to be used
113109
114- return mapper_flags
110+ Args:
111+ self: The instance of the class.
112+ model_part_origin (ModelPart): The model part to transfer from.
113+ model_part_destination (ModelPart): The model part to transfer to.
114+
115+ Returns:
116+ function: The function to be used for the mapping
117+ """
118+ if model_part_origin .IsDistributed () or model_part_destination .IsDistributed ():
119+ return python_mapper_factory .CreateMPIMapper
120+ else :
121+ return python_mapper_factory .CreateMapper
122+
123+ def _PrepareSolverData (self , from_solver_data , to_solver_data , transfer_options ):
124+ """
125+ Prepare the solver data
126+
127+ Args:
128+ self: The instance of the class.
129+ from_solver_data (CoSimulationData): The data from the solver to transfer from.
130+ to_solver_data (CoSimulationData): The data from the solver to transfer to.
131+ transfer_options (KM.Flags): The flags to be used for the mapping.
132+
133+ Returns:
134+ SolverData: Named tuple containing all the extracted values.
135+ """
136+ # Get the from solver data
137+ model_part_origin_name = from_solver_data .model_part_name
138+ variable_origin = from_solver_data .variable
139+ identifier_origin = from_solver_data .solver_name + "." + model_part_origin_name
140+
141+ # Get the to solver data
142+ model_part_destination_name = to_solver_data .model_part_name
143+ variable_destination = to_solver_data .variable
144+ identifier_destination = to_solver_data .solver_name + "." + model_part_destination_name
145+
146+ # Get the mapper flags
147+ mapper_flags = self .__GetMapperFlags (transfer_options , from_solver_data , to_solver_data )
148+
149+ # Get the identifier tuple
150+ identifier_tuple = (identifier_origin , identifier_destination )
151+ inverse_identifier_tuple = (identifier_destination , identifier_origin )
152+
153+ # Get the model parts
154+ model_part_origin = self ._GetModelPartFromInterfaceData (from_solver_data )
155+ model_part_destination = self ._GetModelPartFromInterfaceData (to_solver_data )
156+
157+ # Return the solver data as data class
158+ return SolverData (
159+ model_part_origin = model_part_origin ,
160+ model_part_destination = model_part_destination ,
161+ model_part_origin_name = model_part_origin_name ,
162+ model_part_destination_name = model_part_destination_name ,
163+ variable_origin = variable_origin ,
164+ variable_destination = variable_destination ,
165+ mapper_flags = mapper_flags ,
166+ identifier_origin = identifier_origin ,
167+ identifier_destination = identifier_destination ,
168+ identifier_tuple = identifier_tuple ,
169+ inverse_identifier_tuple = inverse_identifier_tuple
170+ )
115171
116172 @staticmethod
117- def __GetModelPartFromInterfaceData (interface_data ):
173+ def _GetModelPartFromInterfaceData (interface_data ):
118174 """If the solver does not exist on this rank, then pass a
119175 dummy ModelPart to the Mapper that has a DataCommunicator
120176 that is not defined on this rank
@@ -124,6 +180,17 @@ def __GetModelPartFromInterfaceData(interface_data):
124180 else :
125181 return KratosMappingDataTransferOperator .__GetRankZeroModelPart ()
126182
183+ def __GetMapperFlags (self , transfer_options , from_solver_data , to_solver_data ):
184+ mapper_flags = KM .Flags ()
185+ for flag_name in transfer_options .GetStringArray ():
186+ mapper_flags |= self .__mapper_flags_dict [flag_name ]
187+ if from_solver_data .location == "node_non_historical" :
188+ mapper_flags |= KM .Mapper .FROM_NON_HISTORICAL
189+ if to_solver_data .location == "node_non_historical" :
190+ mapper_flags |= KM .Mapper .TO_NON_HISTORICAL
191+
192+ return mapper_flags
193+
127194 @staticmethod
128195 def __GetRankZeroModelPart ():
129196 if not KM .IsDistributedRun ():
0 commit comments