Skip to content

Commit 6050a6f

Browse files
Merge branch 'master' into servol_loading_with_periodic_boundary
2 parents 9f33833 + 2554f36 commit 6050a6f

365 files changed

Lines changed: 22862 additions & 4826 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

applications/CoSimulationApplication/python_scripts/data_transfer_operators/kratos_mapping.py

Lines changed: 106 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,28 @@
99
import KratosMultiphysics.CoSimulationApplication.co_simulation_tools as cs_tools
1010
import KratosMultiphysics.CoSimulationApplication.colors as colors
1111

12-
# other imports
12+
# Other imports
1313
from KratosMultiphysics.CoSimulationApplication.utilities import data_communicator_utilities
1414
from time import time
15+
from dataclasses import dataclass
16+
from typing import Union, Tuple
17+
18+
@dataclass
19+
class SolverData:
20+
model_part_origin: KM.ModelPart
21+
model_part_destination: KM.ModelPart
22+
model_part_origin_name: str
23+
model_part_destination_name: str
24+
variable_origin: Union[KM.DoubleVariable, KM.Array1DVariable3]
25+
variable_destination: Union[KM.DoubleVariable, KM.Array1DVariable3]
26+
mapper_flags: KM.Flags
27+
identifier_origin: str
28+
identifier_destination: str
29+
identifier_tuple: Tuple[str, str]
30+
inverse_identifier_tuple: Tuple[str, str]
1531

1632
def Create(*args):
1733
return KratosMappingDataTransferOperator(*args)
18-
1934
class KratosMappingDataTransferOperator(CoSimulationDataTransferOperator):
2035
"""DataTransferOperator that maps values from one interface (ModelPart) to another.
2136
The mappers of the Kratos-MappingApplication are used
@@ -37,48 +52,34 @@ def __init__(self, settings, parent_coupled_solver_data_communicator):
3752
if not settings.Has("mapper_settings"):
3853
raise Exception('No "mapper_settings" provided!')
3954
super().__init__(settings, parent_coupled_solver_data_communicator)
40-
self.__mappers = {}
55+
self._mappers = {}
4156

4257
def _ExecuteTransferData(self, from_solver_data, to_solver_data, transfer_options):
43-
model_part_origin_name = from_solver_data.model_part_name
44-
variable_origin = from_solver_data.variable
45-
identifier_origin = from_solver_data.solver_name + "." + model_part_origin_name
46-
47-
model_part_destination_name = to_solver_data.model_part_name
48-
variable_destination = to_solver_data.variable
49-
identifier_destination = to_solver_data.solver_name + "." + model_part_destination_name
50-
51-
mapper_flags = self.__GetMapperFlags(transfer_options, from_solver_data, to_solver_data)
58+
solver_data = self._PrepareSolverData(from_solver_data, to_solver_data, transfer_options)
5259

53-
identifier_tuple = (identifier_origin, identifier_destination)
54-
inverse_identifier_tuple = (identifier_destination, identifier_origin)
55-
56-
if identifier_tuple in self.__mappers:
57-
self.__mappers[identifier_tuple].Map(variable_origin, variable_destination, mapper_flags)
58-
elif inverse_identifier_tuple in self.__mappers:
59-
self.__mappers[inverse_identifier_tuple].InverseMap(variable_destination, variable_origin, mapper_flags)
60+
if solver_data.identifier_tuple in self._mappers:
61+
self._mappers[solver_data.identifier_tuple].Map(solver_data.variable_origin, solver_data.variable_destination, solver_data.mapper_flags)
62+
elif solver_data.inverse_identifier_tuple in self._mappers:
63+
self._mappers[solver_data.inverse_identifier_tuple].InverseMap(solver_data.variable_destination, solver_data.variable_origin, solver_data.mapper_flags)
6064
else:
61-
model_part_origin = self.__GetModelPartFromInterfaceData(from_solver_data)
62-
model_part_destination = self.__GetModelPartFromInterfaceData(to_solver_data)
65+
model_part_origin = self._GetModelPartFromInterfaceData(from_solver_data)
66+
model_part_destination = self._GetModelPartFromInterfaceData(to_solver_data)
6367

64-
if model_part_origin.IsDistributed() or model_part_destination.IsDistributed():
65-
mapper_create_fct = python_mapper_factory.CreateMPIMapper
66-
else:
67-
mapper_create_fct = python_mapper_factory.CreateMapper
68+
mapper_create_fct = self._DefineMapperFunction(model_part_origin, model_part_destination)
6869

6970
if self.echo_level > 0:
7071
info_msg = "Creating Mapper:\n"
71-
info_msg += ' Origin: ModelPart "{}" of solver "{}"\n'.format(model_part_origin_name, from_solver_data.solver_name)
72-
info_msg += ' Destination: ModelPart "{}" of solver "{}"'.format(model_part_destination_name, to_solver_data.solver_name)
72+
info_msg += ' Origin: ModelPart "{}" of solver "{}"\n'.format(solver_data.model_part_origin_name, from_solver_data.solver_name)
73+
info_msg += ' Destination: ModelPart "{}" of solver "{}"'.format(solver_data.model_part_destination_name, to_solver_data.solver_name)
7374

7475
cs_tools.cs_print_info(colors.bold(self._ClassName()), info_msg)
7576

7677
mapper_creation_start_time = time()
77-
self.__mappers[identifier_tuple] = mapper_create_fct(model_part_origin, model_part_destination, self.settings["mapper_settings"].Clone()) # Clone is necessary because the settings are validated and defaults assigned, which could influence the creation of other mappers
78+
self._mappers[solver_data.identifier_tuple] = mapper_create_fct(model_part_origin, model_part_destination, self.settings["mapper_settings"].Clone()) # Clone is necessary because the settings are validated and defaults assigned, which could influence the creation of other mappers
7879

7980
if self.echo_level > 2:
8081
cs_tools.cs_print_info(colors.bold(self._ClassName()), "Creating Mapper took: {0:.{1}f} [s]".format(time()-mapper_creation_start_time,2))
81-
self.__mappers[identifier_tuple].Map(variable_origin, variable_destination, mapper_flags)
82+
self._mappers[solver_data.identifier_tuple].Map(solver_data.variable_origin, solver_data.variable_destination, solver_data.mapper_flags)
8283

8384
def _Check(self, from_solver_data, to_solver_data):
8485
def CheckData(data_to_check):
@@ -102,19 +103,74 @@ def _GetDefaultParameters(cls):
102103
def _GetListAvailableTransferOptions(cls):
103104
return cls.__mapper_flags_dict.keys()
104105

105-
def __GetMapperFlags(self, transfer_options, from_solver_data, to_solver_data):
106-
mapper_flags = KM.Flags()
107-
for flag_name in transfer_options.GetStringArray():
108-
mapper_flags |= self.__mapper_flags_dict[flag_name]
109-
if from_solver_data.location == "node_non_historical":
110-
mapper_flags |= KM.Mapper.FROM_NON_HISTORICAL
111-
if to_solver_data.location == "node_non_historical":
112-
mapper_flags |= KM.Mapper.TO_NON_HISTORICAL
106+
def _DefineMapperFunction(self, model_part_origin, model_part_destination):
107+
"""
108+
Define the mapper function to be used
113109
114-
return mapper_flags
110+
Args:
111+
self: The instance of the class.
112+
model_part_origin (ModelPart): The model part to transfer from.
113+
model_part_destination (ModelPart): The model part to transfer to.
114+
115+
Returns:
116+
function: The function to be used for the mapping
117+
"""
118+
if model_part_origin.IsDistributed() or model_part_destination.IsDistributed():
119+
return python_mapper_factory.CreateMPIMapper
120+
else:
121+
return python_mapper_factory.CreateMapper
122+
123+
def _PrepareSolverData(self, from_solver_data, to_solver_data, transfer_options):
124+
"""
125+
Prepare the solver data
126+
127+
Args:
128+
self: The instance of the class.
129+
from_solver_data (CoSimulationData): The data from the solver to transfer from.
130+
to_solver_data (CoSimulationData): The data from the solver to transfer to.
131+
transfer_options (KM.Flags): The flags to be used for the mapping.
132+
133+
Returns:
134+
SolverData: Named tuple containing all the extracted values.
135+
"""
136+
# Get the from solver data
137+
model_part_origin_name = from_solver_data.model_part_name
138+
variable_origin = from_solver_data.variable
139+
identifier_origin = from_solver_data.solver_name + "." + model_part_origin_name
140+
141+
# Get the to solver data
142+
model_part_destination_name = to_solver_data.model_part_name
143+
variable_destination = to_solver_data.variable
144+
identifier_destination = to_solver_data.solver_name + "." + model_part_destination_name
145+
146+
# Get the mapper flags
147+
mapper_flags = self.__GetMapperFlags(transfer_options, from_solver_data, to_solver_data)
148+
149+
# Get the identifier tuple
150+
identifier_tuple = (identifier_origin, identifier_destination)
151+
inverse_identifier_tuple = (identifier_destination, identifier_origin)
152+
153+
# Get the model parts
154+
model_part_origin = self._GetModelPartFromInterfaceData(from_solver_data)
155+
model_part_destination = self._GetModelPartFromInterfaceData(to_solver_data)
156+
157+
# Return the solver data as data class
158+
return SolverData(
159+
model_part_origin=model_part_origin,
160+
model_part_destination=model_part_destination,
161+
model_part_origin_name=model_part_origin_name,
162+
model_part_destination_name=model_part_destination_name,
163+
variable_origin=variable_origin,
164+
variable_destination=variable_destination,
165+
mapper_flags=mapper_flags,
166+
identifier_origin=identifier_origin,
167+
identifier_destination=identifier_destination,
168+
identifier_tuple=identifier_tuple,
169+
inverse_identifier_tuple=inverse_identifier_tuple
170+
)
115171

116172
@staticmethod
117-
def __GetModelPartFromInterfaceData(interface_data):
173+
def _GetModelPartFromInterfaceData(interface_data):
118174
"""If the solver does not exist on this rank, then pass a
119175
dummy ModelPart to the Mapper that has a DataCommunicator
120176
that is not defined on this rank
@@ -124,6 +180,17 @@ def __GetModelPartFromInterfaceData(interface_data):
124180
else:
125181
return KratosMappingDataTransferOperator.__GetRankZeroModelPart()
126182

183+
def __GetMapperFlags(self, transfer_options, from_solver_data, to_solver_data):
184+
mapper_flags = KM.Flags()
185+
for flag_name in transfer_options.GetStringArray():
186+
mapper_flags |= self.__mapper_flags_dict[flag_name]
187+
if from_solver_data.location == "node_non_historical":
188+
mapper_flags |= KM.Mapper.FROM_NON_HISTORICAL
189+
if to_solver_data.location == "node_non_historical":
190+
mapper_flags |= KM.Mapper.TO_NON_HISTORICAL
191+
192+
return mapper_flags
193+
127194
@staticmethod
128195
def __GetRankZeroModelPart():
129196
if not KM.IsDistributedRun():

0 commit comments

Comments
 (0)