@@ -45,9 +45,9 @@ def get_device_list():
4545 if hasattr (torch , "cuda" ) and hasattr (torch .cuda , "is_available" ) and torch .cuda .is_available ():
4646 device_count = torch .cuda .device_count ()
4747 devs += [f"cuda:{ i } " for i in range (device_count )]
48- logger .debug (f"[MultiGPU ] Found { device_count } CUDA device(s)" )
48+ logger .debug (f"[MultiGPU_Device_Utils ] Found { device_count } CUDA device(s)" )
4949 except Exception as e :
50- logger .debug (f"[MultiGPU ] CUDA detection failed: { e } " )
50+ logger .debug (f"[MultiGPU_Device_Utils ] CUDA detection failed: { e } " )
5151
5252 # XPU devices (Intel GPUs)
5353 try :
@@ -59,47 +59,47 @@ def get_device_list():
5959 if hasattr (torch , "xpu" ) and hasattr (torch .xpu , "is_available" ) and torch .xpu .is_available ():
6060 device_count = torch .xpu .device_count ()
6161 devs += [f"xpu:{ i } " for i in range (device_count )]
62- logger .debug (f"[MultiGPU ] Found { device_count } XPU device(s)" )
62+ logger .debug (f"[MultiGPU_Device_Utils ] Found { device_count } XPU device(s)" )
6363 except Exception as e :
64- logger .debug (f"[MultiGPU ] XPU detection failed: { e } " )
64+ logger .debug (f"[MultiGPU_Device_Utils ] XPU detection failed: { e } " )
6565
6666 # NPU devices (Ascend NPUs from Huawei)
6767 try :
6868 import torch_npu
6969 if hasattr (torch , "npu" ) and hasattr (torch .npu , "is_available" ) and torch .npu .is_available ():
7070 device_count = torch .npu .device_count ()
7171 devs += [f"npu:{ i } " for i in range (device_count )]
72- logger .debug (f"[MultiGPU ] Found { device_count } NPU device(s)" )
72+ logger .debug (f"[MultiGPU_Device_Utils ] Found { device_count } NPU device(s)" )
7373 except Exception as e :
74- logger .debug (f"[MultiGPU ] NPU detection failed: { e } " )
74+ logger .debug (f"[MultiGPU_Device_Utils ] NPU detection failed: { e } " )
7575
7676 # MLU devices (Cambricon MLUs)
7777 try :
7878 import torch_mlu
7979 if hasattr (torch , "mlu" ) and hasattr (torch .mlu , "is_available" ) and torch .mlu .is_available ():
8080 device_count = torch .mlu .device_count ()
8181 devs += [f"mlu:{ i } " for i in range (device_count )]
82- logger .debug (f"[MultiGPU ] Found { device_count } MLU device(s)" )
82+ logger .debug (f"[MultiGPU_Device_Utils ] Found { device_count } MLU device(s)" )
8383 except Exception as e :
84- logger .debug (f"[MultiGPU ] MLU detection failed: { e } " )
84+ logger .debug (f"[MultiGPU_Device_Utils ] MLU detection failed: { e } " )
8585
8686 # MPS device (Apple Metal - single device only)
8787 try :
8888 if hasattr (torch .backends , "mps" ) and torch .backends .mps .is_available ():
8989 devs .append ("mps" )
90- logger .debug ("[MultiGPU ] Found MPS device" )
90+ logger .debug ("[MultiGPU_Device_Utils ] Found MPS device" )
9191 except Exception as e :
92- logger .debug (f"[MultiGPU ] MPS detection failed: { e } " )
92+ logger .debug (f"[MultiGPU_Device_Utils ] MPS detection failed: { e } " )
9393
9494 # DirectML devices (Windows DirectML for AMD/Intel/NVIDIA)
9595 try :
9696 import torch_directml
9797 adapter_count = torch_directml .device_count ()
9898 if adapter_count > 0 :
9999 devs += [f"directml:{ i } " for i in range (adapter_count )]
100- logger .debug (f"[MultiGPU ] Found { adapter_count } DirectML adapter(s)" )
100+ logger .debug (f"[MultiGPU_Device_Utils ] Found { adapter_count } DirectML adapter(s)" )
101101 except Exception as e :
102- logger .debug (f"[MultiGPU ] DirectML detection failed: { e } " )
102+ logger .debug (f"[MultiGPU_Device_Utils ] DirectML detection failed: { e } " )
103103
104104 # IXUCA/CoreX devices (special accelerator)
105105 try :
@@ -108,18 +108,18 @@ def get_device_list():
108108 if hasattr (torch .corex , "device_count" ):
109109 device_count = torch .corex .device_count ()
110110 devs += [f"corex:{ i } " for i in range (device_count )]
111- logger .debug (f"[MultiGPU ] Found { device_count } CoreX device(s)" )
111+ logger .debug (f"[MultiGPU_Device_Utils ] Found { device_count } CoreX device(s)" )
112112 else :
113113 devs .append ("corex:0" )
114- logger .debug ("[MultiGPU ] Found CoreX device" )
114+ logger .debug ("[MultiGPU_Device_Utils ] Found CoreX device" )
115115 except Exception as e :
116- logger .debug (f"[MultiGPU ] CoreX detection failed: { e } " )
116+ logger .debug (f"[MultiGPU_Device_Utils ] CoreX detection failed: { e } " )
117117
118118 # Cache the result for future calls
119119 _DEVICE_LIST_CACHE = devs
120120
121121 # Log only once when initially populated
122- logger .info (f"[MultiGPU ] Device list initialized: { devs } " )
122+ logger .info (f"[MultiGPU_Device_Utils ] Device list initialized: { devs } " )
123123
124124 return devs
125125
@@ -218,14 +218,54 @@ def get_device_type(device_string):
def parse_device_string(device_string):
    """
    Parse a device string into type and index.

    Args:
        device_string: Device identifier like "cuda:0", "cpu", "xpu:1", etc.

    Returns:
        Tuple of (device_type, device_index) where device_index is an int
        for indexed devices ("cuda:0" -> ("cuda", 0)) and None for
        non-indexed devices such as "cpu" or "mps".
    """
    # partition() splits at most once, so a malformed string with an extra
    # colon is not silently truncated the way split(":")[0:2] would do.
    device_type, sep, index_str = device_string.partition(":")
    if sep:
        return device_type, int(index_str)
    return device_string, None
232+
233+
def soft_empty_cache_multigpu(logger):
    """
    Replicate ComfyUI's cache clearing but for ALL devices in MultiGPU.
    MultiGPU adaptation of ComfyUI's soft_empty_cache() functionality.

    Args:
        logger: Logger used for progress and debug output.
    """
    import gc

    logger.info("[MultiGPU_Device_Utils] Preparing devices for optimized safetensor loading")

    # Python GC (same as all implementations)
    gc.collect()
    logger.debug("[MultiGPU_Device_Utils] Performed garbage collection before safetensor loading")

    # Clear cache for ALL devices (not just ComfyUI's single device)
    all_devices = get_device_list()

    for device_str in all_devices:
        # Best-effort per device: a failure on one backend must not prevent
        # cache clearing on the remaining devices (mirrors the try/except
        # style used during device detection in get_device_list()).
        try:
            if device_str.startswith("cuda:"):
                device_idx = int(device_str.split(":")[1])
                # Use the device context manager rather than set_device() so
                # the caller's current CUDA device is restored afterwards;
                # a bare set_device() would leak the last iterated device as
                # the process-wide current device.
                with torch.cuda.device(device_idx):
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()  # ComfyUI's CUDA optimization
                logger.debug(f"[MultiGPU_Device_Utils] Cleared cache + IPC for {device_str}")
            elif device_str == "mps":
                torch.mps.empty_cache()
                logger.debug("[MultiGPU_Device_Utils] Cleared cache for MPS")
            elif device_str.startswith("xpu:"):
                torch.xpu.empty_cache()
                logger.debug("[MultiGPU_Device_Utils] Cleared cache for Intel XPU")
            elif device_str.startswith("npu:"):
                torch.npu.empty_cache()
                logger.debug("[MultiGPU_Device_Utils] Cleared cache for Ascend NPU")
            elif device_str.startswith("mlu:"):
                torch.mlu.empty_cache()
                logger.debug("[MultiGPU_Device_Utils] Cleared cache for Cambricon MLU")
            elif device_str.startswith("corex:"):
                torch.corex.empty_cache()  # Hypothetical based on ComfyUI's ixuca support
                logger.debug("[MultiGPU_Device_Utils] Cleared cache for CoreX")
        except Exception as e:
            logger.debug(f"[MultiGPU_Device_Utils] Cache clear failed for {device_str}: {e}")
0 commit comments