Skip to content

Commit c63b539

Browse files
committed
Additional garbage/cache collection (#101); addresses DisTorch2 device issue for CLIP, hopefully closing #99 and #104
Add comprehensive memory cache clearing aligned with ComfyUI patterns to improve stability and reduce OOM incidents in multi-device scenarios. **Addresses Memory/Garbage Collection Issues:** - Created `soft_empty_cache_multigpu()` function in device_utils.py - Replicates ComfyUI's cache clearing for all devices (CUDA, MPS, XPU, NPU, MLU) - Includes CUDA IPC collect optimization like ComfyUI - Strategically placed calls before major memory allocations **Addresses CLIP loading issues:** - Fixed DisTorch2 device variable management before text encoder operations **`soft_empty_cache_multigpu()` implementation aligned with ComfyUI's patterns:** - Called after GC operations - Placed before major memory allocations - Matches ComfyUI's proven memory management strategy - Same device clearing logic for multi-device scenarios
1 parent 0adf219 commit c63b539

5 files changed

Lines changed: 185 additions & 39 deletions

File tree

__init__.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,15 @@
2929
# Global device state management
3030
current_device = mm.get_torch_device()
3131
current_text_encoder_device = mm.text_encoder_device()
32+
current_text_encoder_initial_device = mm.text_encoder_device()
3233

3334
def set_current_device(device):
3435
global current_device
3536
current_device = device
3637
logger.info(f"[MultiGPU Initialization] current_device set to: {device}")
3738

3839
def set_current_text_encoder_device(device):
39-
global current_text_encoder_device
40+
global current_text_encoder_device, current_text_encoder_initial_device
4041
current_text_encoder_device = device
4142
current_text_encoder_initial_device = device
4243
logger.info(f"[MultiGPU Initialization] current_text_encoder_device and current_text_encoder_initial_device set to: {device}")
@@ -192,7 +193,8 @@ def check_module_exists(module_path):
192193
register_patched_safetensor_modelpatcher,
193194
analyze_safetensor_loading,
194195
calculate_safetensor_vvram_allocation,
195-
override_class_with_distorch_safetensor_v2
196+
override_class_with_distorch_safetensor_v2,
197+
override_class_with_distorch_safetensor_v2_clip
196198
)
197199

198200
# Import advanced checkpoint loaders
@@ -229,13 +231,13 @@ def check_module_exists(module_path):
229231
# DisTorch 2 SafeTensor nodes for FLUX and other safetensor models
230232
NODE_CLASS_MAPPINGS["UNETLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["UNETLoader"])
231233
NODE_CLASS_MAPPINGS["VAELoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["VAELoader"])
232-
NODE_CLASS_MAPPINGS["CLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["CLIPLoader"])
233-
NODE_CLASS_MAPPINGS["DualCLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["DualCLIPLoader"])
234+
NODE_CLASS_MAPPINGS["CLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(GLOBAL_NODE_CLASS_MAPPINGS["CLIPLoader"])
235+
NODE_CLASS_MAPPINGS["DualCLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(GLOBAL_NODE_CLASS_MAPPINGS["DualCLIPLoader"])
234236
if "TripleCLIPLoader" in GLOBAL_NODE_CLASS_MAPPINGS:
235-
NODE_CLASS_MAPPINGS["TripleCLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["TripleCLIPLoader"])
237+
NODE_CLASS_MAPPINGS["TripleCLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(GLOBAL_NODE_CLASS_MAPPINGS["TripleCLIPLoader"])
236238
if "QuadrupleCLIPLoader" in GLOBAL_NODE_CLASS_MAPPINGS:
237-
NODE_CLASS_MAPPINGS["QuadrupleCLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["QuadrupleCLIPLoader"])
238-
NODE_CLASS_MAPPINGS["CLIPVisionLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["CLIPVisionLoader"])
239+
NODE_CLASS_MAPPINGS["QuadrupleCLIPLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(GLOBAL_NODE_CLASS_MAPPINGS["QuadrupleCLIPLoader"])
240+
NODE_CLASS_MAPPINGS["CLIPVisionLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2_clip(GLOBAL_NODE_CLASS_MAPPINGS["CLIPVisionLoader"])
239241
NODE_CLASS_MAPPINGS["CheckpointLoaderSimpleDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"])
240242
NODE_CLASS_MAPPINGS["ControlNetLoaderDisTorch2MultiGPU"] = override_class_with_distorch_safetensor_v2(GLOBAL_NODE_CLASS_MAPPINGS["ControlNetLoader"])
241243
if "DiffusersLoader" in GLOBAL_NODE_CLASS_MAPPINGS:
@@ -307,10 +309,10 @@ def register_and_count(module_names, node_map):
307309
"QuadrupleCLIPLoaderGGUFDisTorchMultiGPU": override_class_with_distorch_clip(QuadrupleCLIPLoaderGGUF),
308310
"UnetLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2(UnetLoaderGGUF),
309311
"UnetLoaderGGUFAdvancedDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2(UnetLoaderGGUFAdvanced),
310-
"CLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2(CLIPLoaderGGUF),
311-
"DualCLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2(DualCLIPLoaderGGUF),
312-
"TripleCLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2(TripleCLIPLoaderGGUF),
313-
"QuadrupleCLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2(QuadrupleCLIPLoaderGGUF),
312+
"CLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2_clip(CLIPLoaderGGUF),
313+
"DualCLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2_clip(DualCLIPLoaderGGUF),
314+
"TripleCLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2_clip(TripleCLIPLoaderGGUF),
315+
"QuadrupleCLIPLoaderGGUFDisTorch2MultiGPU": override_class_with_distorch_safetensor_v2_clip(QuadrupleCLIPLoaderGGUF),
314316
"UnetLoaderGGUFMultiGPU": override_class(UnetLoaderGGUF),
315317
"UnetLoaderGGUFAdvancedMultiGPU": override_class(UnetLoaderGGUFAdvanced),
316318
"CLIPLoaderGGUFMultiGPU": override_class_clip(CLIPLoaderGGUF),

checkpoint_multigpu.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import comfy.model_detection
1313
import comfy.clip_vision
1414
from comfy.sd import VAE, CLIP
15-
from .device_utils import get_device_list
15+
from .device_utils import get_device_list, soft_empty_cache_multigpu
1616
from .distorch_2 import safetensor_allocation_store, safetensor_settings_store, create_safetensor_model_hash, register_patched_safetensor_modelpatcher
1717

1818
logger = logging.getLogger("MultiGPU")
@@ -107,8 +107,9 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
107107

108108
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
109109

110+
soft_empty_cache_multigpu(logger)
110111
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=unet_compute_device, offload_device=mm.unet_offload_device())
111-
112+
112113
if distorch_config and 'unet_allocation' in distorch_config:
113114
register_patched_safetensor_modelpatcher()
114115
model_hash = create_safetensor_model_hash(model_patcher, "checkpoint_loader_unet")
@@ -136,6 +137,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
136137
if clip_target is not None:
137138
clip_sd = model_config.process_clip_state_dict(sd)
138139
if len(clip_sd) > 0:
140+
soft_empty_cache_multigpu(logger)
139141
clip_params = comfy.utils.calculate_parameters(clip_sd)
140142
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=clip_params, model_options=te_model_options)
141143

device_utils.py

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ def get_device_list():
4545
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available():
4646
device_count = torch.cuda.device_count()
4747
devs += [f"cuda:{i}" for i in range(device_count)]
48-
logger.debug(f"[MultiGPU] Found {device_count} CUDA device(s)")
48+
logger.debug(f"[MultiGPU_Device_Utils] Found {device_count} CUDA device(s)")
4949
except Exception as e:
50-
logger.debug(f"[MultiGPU] CUDA detection failed: {e}")
50+
logger.debug(f"[MultiGPU_Device_Utils] CUDA detection failed: {e}")
5151

5252
# XPU devices (Intel GPUs)
5353
try:
@@ -59,47 +59,47 @@ def get_device_list():
5959
if hasattr(torch, "xpu") and hasattr(torch.xpu, "is_available") and torch.xpu.is_available():
6060
device_count = torch.xpu.device_count()
6161
devs += [f"xpu:{i}" for i in range(device_count)]
62-
logger.debug(f"[MultiGPU] Found {device_count} XPU device(s)")
62+
logger.debug(f"[MultiGPU_Device_Utils] Found {device_count} XPU device(s)")
6363
except Exception as e:
64-
logger.debug(f"[MultiGPU] XPU detection failed: {e}")
64+
logger.debug(f"[MultiGPU_Device_Utils] XPU detection failed: {e}")
6565

6666
# NPU devices (Ascend NPUs from Huawei)
6767
try:
6868
import torch_npu
6969
if hasattr(torch, "npu") and hasattr(torch.npu, "is_available") and torch.npu.is_available():
7070
device_count = torch.npu.device_count()
7171
devs += [f"npu:{i}" for i in range(device_count)]
72-
logger.debug(f"[MultiGPU] Found {device_count} NPU device(s)")
72+
logger.debug(f"[MultiGPU_Device_Utils] Found {device_count} NPU device(s)")
7373
except Exception as e:
74-
logger.debug(f"[MultiGPU] NPU detection failed: {e}")
74+
logger.debug(f"[MultiGPU_Device_Utils] NPU detection failed: {e}")
7575

7676
# MLU devices (Cambricon MLUs)
7777
try:
7878
import torch_mlu
7979
if hasattr(torch, "mlu") and hasattr(torch.mlu, "is_available") and torch.mlu.is_available():
8080
device_count = torch.mlu.device_count()
8181
devs += [f"mlu:{i}" for i in range(device_count)]
82-
logger.debug(f"[MultiGPU] Found {device_count} MLU device(s)")
82+
logger.debug(f"[MultiGPU_Device_Utils] Found {device_count} MLU device(s)")
8383
except Exception as e:
84-
logger.debug(f"[MultiGPU] MLU detection failed: {e}")
84+
logger.debug(f"[MultiGPU_Device_Utils] MLU detection failed: {e}")
8585

8686
# MPS device (Apple Metal - single device only)
8787
try:
8888
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
8989
devs.append("mps")
90-
logger.debug("[MultiGPU] Found MPS device")
90+
logger.debug("[MultiGPU_Device_Utils] Found MPS device")
9191
except Exception as e:
92-
logger.debug(f"[MultiGPU] MPS detection failed: {e}")
92+
logger.debug(f"[MultiGPU_Device_Utils] MPS detection failed: {e}")
9393

9494
# DirectML devices (Windows DirectML for AMD/Intel/NVIDIA)
9595
try:
9696
import torch_directml
9797
adapter_count = torch_directml.device_count()
9898
if adapter_count > 0:
9999
devs += [f"directml:{i}" for i in range(adapter_count)]
100-
logger.debug(f"[MultiGPU] Found {adapter_count} DirectML adapter(s)")
100+
logger.debug(f"[MultiGPU_Device_Utils] Found {adapter_count} DirectML adapter(s)")
101101
except Exception as e:
102-
logger.debug(f"[MultiGPU] DirectML detection failed: {e}")
102+
logger.debug(f"[MultiGPU_Device_Utils] DirectML detection failed: {e}")
103103

104104
# IXUCA/CoreX devices (special accelerator)
105105
try:
@@ -108,18 +108,18 @@ def get_device_list():
108108
if hasattr(torch.corex, "device_count"):
109109
device_count = torch.corex.device_count()
110110
devs += [f"corex:{i}" for i in range(device_count)]
111-
logger.debug(f"[MultiGPU] Found {device_count} CoreX device(s)")
111+
logger.debug(f"[MultiGPU_Device_Utils] Found {device_count} CoreX device(s)")
112112
else:
113113
devs.append("corex:0")
114-
logger.debug("[MultiGPU] Found CoreX device")
114+
logger.debug("[MultiGPU_Device_Utils] Found CoreX device")
115115
except Exception as e:
116-
logger.debug(f"[MultiGPU] CoreX detection failed: {e}")
116+
logger.debug(f"[MultiGPU_Device_Utils] CoreX detection failed: {e}")
117117

118118
# Cache the result for future calls
119119
_DEVICE_LIST_CACHE = devs
120120

121121
# Log only once when initially populated
122-
logger.info(f"[MultiGPU] Device list initialized: {devs}")
122+
logger.info(f"[MultiGPU_Device_Utils] Device list initialized: {devs}")
123123

124124
return devs
125125

@@ -218,14 +218,54 @@ def get_device_type(device_string):
218218
def parse_device_string(device_string):
219219
"""
220220
Parse a device string into type and index.
221-
221+
222222
Args:
223223
device_string: Device identifier like "cuda:0", "cpu", "xpu:1", etc.
224-
224+
225225
Returns:
226226
Tuple of (device_type, device_index) where index is None for non-indexed devices
227227
"""
228228
if ":" in device_string:
229229
parts = device_string.split(":")
230230
return parts[0], int(parts[1])
231231
return device_string, None
232+
233+
234+
def soft_empty_cache_multigpu(logger):
235+
"""
236+
Replicate ComfyUI's cache clearing but for ALL devices in MultiGPU.
237+
MultiGPU adaptation of ComfyUI's soft_empty_cache() functionality.
238+
"""
239+
import gc
240+
241+
logger.info("[MultiGPU_Device_Utils] Preparing devices for optimized safetensor loading")
242+
243+
# Python GC (same as all implementations)
244+
gc.collect()
245+
logger.debug("[MultiGPU_Device_Utils] Performed garbage collection before safetensor loading")
246+
247+
# Clear cache for ALL devices (not just ComfyUI's single device)
248+
all_devices = get_device_list()
249+
250+
for device_str in all_devices:
251+
if device_str.startswith("cuda:"):
252+
device_idx = int(device_str.split(":")[1])
253+
torch.cuda.set_device(device_idx)
254+
torch.cuda.empty_cache()
255+
torch.cuda.ipc_collect() # ComfyUI's CUDA optimization
256+
logger.debug(f"[MultiGPU_Device_Utils] Cleared cache + IPC for {device_str}")
257+
elif device_str == "mps":
258+
torch.mps.empty_cache()
259+
logger.debug("[MultiGPU_Device_Utils] Cleared cache for MPS")
260+
elif device_str.startswith("xpu:"):
261+
torch.xpu.empty_cache()
262+
logger.debug("[MultiGPU_Device_Utils] Cleared cache for Intel XPU")
263+
elif device_str.startswith("npu:"):
264+
torch.npu.empty_cache()
265+
logger.debug("[MultiGPU_Device_Utils] Cleared cache for Ascend NPU")
266+
elif device_str.startswith("mlu:"):
267+
torch.mlu.empty_cache()
268+
logger.debug("[MultiGPU_Device_Utils] Cleared cache for Cambricon MLU")
269+
elif device_str.startswith("corex:"):
270+
torch.corex.empty_cache() # Hypothetical based on ComfyUI's ixuca support
271+
logger.debug("[MultiGPU_Device_Utils] Cleared cache for CoreX")

distorch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import copy
1313
from collections import defaultdict
1414
import comfy.model_management as mm
15-
from .device_utils import get_device_list
15+
from .device_utils import get_device_list, soft_empty_cache_multigpu
1616

1717
# Global store for model allocations
1818
model_allocation_store = {}
@@ -62,6 +62,7 @@ def new_load(self, *args, force_patch_weights=False, **kwargs):
6262
debug_hash = create_model_hash(self, "patcher")
6363
debug_allocations = model_allocation_store.get(debug_hash)
6464
if debug_allocations:
65+
soft_empty_cache_multigpu(logger)
6566
device_assignments = analyze_ggml_loading(self.model, debug_allocations)['device_assignments']
6667
for device, layers in device_assignments.items():
6768
target_device = torch.device(device)

0 commit comments

Comments
 (0)