Skip to content

Commit 1951513

Browse files
committed
refactor: migrate DisTorch2 allocation tracking to per-model metadata
- Replace global safetensor_allocation_store/safetensor_settings_store and create_safetensor_model_hash with a per-model annotation (_distorch_v2_meta) stored directly on the inner model object.
- Update distorch_2 to remove global stores and hash creation; parse and consume allocation strings from inner_model._distorch_v2_meta during model registration and loading.
- Update wrappers, checkpoint_multigpu, device_utils, and __init__ to set and read the new metadata instead of writing/reading global stores.
- Simplify detection of DisTorch-managed models (check inner_model._distorch_v2_meta) and adjust logging to surface inner model ids and allocation info.
- Clean up related imports and dead code paths.

Files changed: distorch_2.py, wrappers.py, checkpoint_multigpu.py, device_utils.py, model_management_mgpu.py, __init__.py
1 parent bdd612c commit 1951513

6 files changed

Lines changed: 84 additions & 157 deletions

File tree

__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,6 @@ def unet_offload_device_patched():
269269
override_class_with_distorch_safetensor_v2_clip_no_device,
270270
)
271271
from .distorch_2 import (
272-
safetensor_allocation_store,
273-
create_safetensor_model_hash,
274272
register_patched_safetensor_modelpatcher,
275273
analyze_safetensor_loading,
276274
calculate_safetensor_vvram_allocation,

checkpoint_multigpu.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from comfy.sd import VAE, CLIP
1010
from .device_utils import get_device_list, soft_empty_cache_multigpu
1111
from .model_management_mgpu import multigpu_memory_log
12-
from .distorch_2 import safetensor_allocation_store, safetensor_settings_store, create_safetensor_model_hash, register_patched_safetensor_modelpatcher
12+
from .distorch_2 import register_patched_safetensor_modelpatcher
1313

1414
logger = logging.getLogger("MultiGPU")
1515

@@ -108,12 +108,10 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
108108

109109
if distorch_config and 'unet_allocation' in distorch_config:
110110
register_patched_safetensor_modelpatcher()
111-
model_hash = create_safetensor_model_hash(model_patcher, "checkpoint_loader_unet")
112-
safetensor_allocation_store[model_hash] = distorch_config['unet_allocation']
113-
safetensor_settings_store[model_hash] = distorch_config.get('unet_settings','')
114-
model.is_distorch = True
111+
inner_model = model_patcher.model
112+
inner_model._distorch_v2_meta = {"full_allocation": distorch_config['unet_allocation']}
113+
logger.info(f"[CHECKPOINT_META] UNET inner_model id=0x{id(inner_model):x}")
115114
model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)
116-
logger.mgpu_mm_log(f"Stored DisTorch2 config for UNet (hash {model_hash[:8]}): {distorch_config['unet_allocation']}")
117115

118116
model.load_model_weights(sd, diffusion_model_prefix)
119117
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-weights")
@@ -145,12 +143,10 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
145143
if distorch_config and 'clip_allocation' in distorch_config:
146144
if hasattr(clip, 'patcher'):
147145
register_patched_safetensor_modelpatcher()
148-
clip_hash = create_safetensor_model_hash(clip.patcher, "checkpoint_loader_clip")
149-
safetensor_allocation_store[clip_hash] = distorch_config['clip_allocation']
150-
safetensor_settings_store[clip_hash] = distorch_config.get('clip_settings','')
151-
clip.patcher.model.is_distorch = True
146+
inner_clip = clip.patcher.model
147+
inner_clip._distorch_v2_meta = {"full_allocation": distorch_config['clip_allocation']}
148+
logger.info(f"[CHECKPOINT_META] CLIP inner_model id=0x{id(inner_clip):x}")
152149
clip.patcher.model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)
153-
logger.info(f"Stored DisTorch2 config for CLIP (hash {clip_hash[:8]}): {distorch_config['clip_allocation']}")
154150

155151
m, u = clip.load_sd(clip_sd, full_model=True) # This respects the patched text_encoder_device
156152
if len(m) > 0: logger.warning(f"CLIP missing keys: {m}")

device_utils.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -235,30 +235,17 @@ def soft_empty_cache_multigpu():
235235
def soft_empty_cache_distorch2_patched(force=False):
236236
"""Patched mm.soft_empty_cache managing VRAM across all devices, CPU RAM with adaptive thresholding, and DisTorch store pruning."""
237237
from .model_management_mgpu import multigpu_memory_log, check_cpu_memory_threshold, trigger_executor_cache_reset
238-
from .distorch_2 import safetensor_allocation_store, create_safetensor_model_hash
239238

240239
is_distorch_active = False
241240

242-
# Detect DisTorch2-managed models
243-
# logger.mgpu_mm_log(f"[DETECT_DEBUG] Checking DisTorch2 active status - loaded models: {len(mm.current_loaded_models)}, store entries: {len(safetensor_allocation_store)}")
244-
245241
for i, lm in enumerate(mm.current_loaded_models):
246-
mp = lm.model # weakref call to ModelPatcher
242+
mp = lm.model
247243
if mp is not None:
248-
model_hash = create_safetensor_model_hash(mp, "cache_patch_check")
249-
in_store = model_hash in safetensor_allocation_store
250-
alloc_value = safetensor_allocation_store.get(model_hash, "")
251-
model_name = type(getattr(mp, 'model', mp)).__name__
252-
unload_distorch_model = getattr(getattr(mp, 'model', None), '_mgpu_unload_distorch_model', False)
253-
254-
#logger.mgpu_mm_log(f"[DETECT_DEBUG] Model {i}: {model_name}, hash={model_hash[:8]}, in_store={in_store}, alloc_value='{alloc_value}', unload_distorch_model={unload_distorch_model}")
244+
inner_model = mp.model
255245

256-
if in_store and alloc_value:
246+
if hasattr(inner_model, '_distorch_v2_meta'):
257247
is_distorch_active = True
258-
#logger.mgpu_mm_log(f"[DETECT_DEBUG] DisTorch2 ACTIVE detected on model: {model_name}")
259248
break
260-
261-
#logger.mgpu_mm_log(f"[DETECT_DEBUG] Final DisTorch2 active status: {is_distorch_active}")
262249

263250
# Phase 2: adaptive CPU memory management
264251
check_cpu_memory_threshold()

distorch_2.py

Lines changed: 39 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -20,38 +20,6 @@
2020
from .model_management_mgpu import multigpu_memory_log, force_full_system_cleanup
2121

2222

23-
safetensor_allocation_store = {}
24-
safetensor_settings_store = {}
25-
26-
27-
def create_safetensor_model_hash(model, caller):
28-
"""Create a unique hash for a safetensor model to track allocations"""
29-
if hasattr(model, 'model'):
30-
# For ModelPatcher objects
31-
actual_model = model.model
32-
model_type = type(actual_model).__name__
33-
# Use ComfyUI's model_size if available
34-
if hasattr(model, 'model_size'):
35-
model_size = model.model_size()
36-
else:
37-
model_size = sum(p.numel() * p.element_size() for p in actual_model.parameters())
38-
if hasattr(model, 'model_state_dict'):
39-
first_layers = str(list(model.model_state_dict().keys())[:3])
40-
else:
41-
first_layers = str(list(actual_model.state_dict().keys())[:3])
42-
else:
43-
# Direct model
44-
model_type = type(model).__name__
45-
model_size = sum(p.numel() * p.element_size() for p in model.parameters())
46-
first_layers = str(list(model.state_dict().keys())[:3])
47-
48-
identifier = f"{model_type}_{model_size}_{first_layers}"
49-
final_hash = hashlib.sha256(identifier.encode()).hexdigest()
50-
51-
# DEBUG STATEMENT - ALWAYS LOG THE HASH
52-
logger.debug(f"[MultiGPU DisTorch V2] Created hash for {caller}: {final_hash[:8]}...")
53-
return final_hash
54-
5523
def register_patched_safetensor_modelpatcher():
5624
"""Register and patch the ModelPatcher for distributed safetensor loading"""
5725
from comfy.model_patcher import wipe_lowvram_weight, move_weight_functions
@@ -128,23 +96,35 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False
12896
device = loaded_model.device
12997
base_memory = loaded_model.model_memory_required(device)
13098

131-
# Check DisTorch flags
132-
is_distorch = hasattr(loaded_model.model.model, '_mgpu_virtual_vram_gb')
133-
has_eject = hasattr(loaded_model.model.model, '_mgpu_eject_models')
134-
135-
if has_eject:
136-
eject_device = device
137-
logger.mgpu_mm_log("DisTorch eject_models=True, is_distorch=True - MAX memory eviction")
138-
139-
if is_distorch:
140-
# is_distorch=True: use compute device allocation size
141-
virtual_vram_gb = loaded_model.model.model._mgpu_virtual_vram_gb
99+
inner_model = loaded_model.model.model
100+
101+
if hasattr(inner_model, '_distorch_v2_meta'):
102+
meta = inner_model._distorch_v2_meta
103+
allocation_str = meta['full_allocation']
104+
105+
# Parse allocation string: "expert#compute_device;virtual_vram_gb;donors"
106+
parts = allocation_str.split('#')
107+
virtual_vram_gb = 0.0
108+
has_eject = False
109+
110+
if len(parts) > 1:
111+
virtual_vram_str = parts[1]
112+
virtual_info = virtual_vram_str.split(';')
113+
if len(virtual_info) > 1:
114+
virtual_vram_gb = float(virtual_info[1])
115+
if len(virtual_info) > 2 and virtual_info[2]:
116+
has_eject = True
117+
118+
if has_eject:
119+
eject_device = device
120+
logger.mgpu_mm_log("DisTorch eject_models detected - MAX memory eviction")
121+
142122
virtual_vram_bytes = virtual_vram_gb * (1024**3)
143123
adjusted_memory = max(0, base_memory - virtual_vram_bytes)
144124
total_memory_required[device] = total_memory_required.get(device, 0) + adjusted_memory
145-
logger.mgpu_mm_log(f"DisTorch is_distorch=True, model adjusted {(base_memory - virtual_vram_bytes)/(1024**3):.2f}GB for device {device}")
125+
logger.mgpu_mm_log(f"DisTorch model adjusted {(base_memory - virtual_vram_bytes)/(1024**3):.2f}GB for device {device}")
146126
else:
147-
# is_distorch=False: use full model size
127+
# Standard model: use full model size
148128
total_memory_required[device] = total_memory_required.get(device, 0) + base_memory
149129
logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] Standard model {(base_memory)/(1024**3):.2f}GB for device {device}")
150130

@@ -209,23 +189,24 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False
209189
original_partially_load = comfy.model_patcher.ModelPatcher.partially_load
210190

211191
def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_patch_weights=False, **kwargs):
212-
"""Override to use our static device assignments"""
213-
global safetensor_allocation_store
214-
215-
debug_hash = create_safetensor_model_hash(self, "partial_load")
216-
multigpu_memory_log(f"safetensor:{debug_hash[:8]}", "pre-load")
217-
allocations = safetensor_allocation_store.get(debug_hash)
218-
219-
# Set default precision flag before checking
220-
if not hasattr(self.model, '_distorch_high_precision_loras'):
221-
self.model._distorch_high_precision_loras = True
222-
223-
if not allocations:
192+
"""Override to use direct model annotation for allocation"""
193+
194+
mp_id = id(self)
195+
mp_patches_uuid = self.patches_uuid
196+
inner_model = self.model
197+
inner_model_id = id(inner_model)
198+
199+
if not hasattr(inner_model, "_distorch_v2_meta"):
200+
logger.debug(f"[DISTORCH_SKIP] ModelPatcher=0x{mp_id:x} inner_model=0x{inner_model_id:x} type={type(inner_model).__name__} - no metadata, using standard loading")
224201
result = original_partially_load(self, device_to, extra_memory, force_patch_weights)
225-
multigpu_memory_log(f"safetensor:{debug_hash[:8]}", "post-load")
226202
if hasattr(self, '_distorch_block_assignments'):
227203
del self._distorch_block_assignments
228204
return result
205+
206+
allocations = inner_model._distorch_v2_meta['full_allocation']
207+
208+
if not hasattr(self.model, '_distorch_high_precision_loras'):
209+
self.model._distorch_high_precision_loras = True
229210

230211
if not hasattr(self.model, 'current_weight_patches_uuid'):
231212
self.model.current_weight_patches_uuid = None
@@ -308,7 +289,6 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
308289

309290
logger.info("[MultiGPU DisTorch V2] DisTorch loading completed.")
310291
logger.info(f"[MultiGPU DisTorch V2] Total memory: {mem_counter / (1024 * 1024):.2f}MB")
311-
multigpu_memory_log(f"safetensor:{debug_hash[:8]}", "post-load")
312292

313293
return 0
314294

model_management_mgpu.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,30 +22,9 @@
2222
# Model Analysis and Store Management (DisTorch V1 & V2)
2323
# ==========================================================================================
2424

25-
# DisTorch V2 SafeTensor stores
26-
safetensor_allocation_store = {}
27-
safetensor_settings_store = {}
28-
2925
# DisTorch V1 GGUF stores (backwards compatibility)
3026
model_allocation_store = {}
3127

32-
def create_safetensor_model_hash(model, caller):
33-
"""Create a unique hash for a safetensor model to track allocations"""
34-
if hasattr(model, 'model'):
35-
actual_model = model.model
36-
model_type = type(actual_model).__name__
37-
model_size = model.model_size() if hasattr(model, 'model_size') else sum(p.numel() * p.element_size() for p in actual_model.parameters())
38-
first_layers = str(list(model.model_state_dict().keys() if hasattr(model, 'model_state_dict') else actual_model.state_dict().keys())[:3])
39-
else:
40-
model_type = type(model).__name__
41-
model_size = sum(p.numel() * p.element_size() for p in model.parameters())
42-
first_layers = str(list(model.state_dict().keys())[:3])
43-
44-
identifier = f"{model_type}_{model_size}_{first_layers}"
45-
final_hash = hashlib.sha256(identifier.encode()).hexdigest()
46-
logger.debug(f"[MultiGPU DisTorch V2] Created hash for {caller}: {final_hash[:8]}...")
47-
return final_hash
48-
4928
def create_model_hash(model, caller):
5029
"""Create a unique hash for a GGUF model to track allocations (DisTorch V1)"""
5130
model_type = type(model.model).__name__

0 commit comments

Comments (0)