|
20 | 20 | from .model_management_mgpu import multigpu_memory_log, force_full_system_cleanup |
21 | 21 |
|
22 | 22 |
|
23 | | -safetensor_allocation_store = {} |
24 | | -safetensor_settings_store = {} |
25 | | - |
26 | | - |
27 | | -def create_safetensor_model_hash(model, caller): |
28 | | - """Create a unique hash for a safetensor model to track allocations""" |
29 | | - if hasattr(model, 'model'): |
30 | | - # For ModelPatcher objects |
31 | | - actual_model = model.model |
32 | | - model_type = type(actual_model).__name__ |
33 | | - # Use ComfyUI's model_size if available |
34 | | - if hasattr(model, 'model_size'): |
35 | | - model_size = model.model_size() |
36 | | - else: |
37 | | - model_size = sum(p.numel() * p.element_size() for p in actual_model.parameters()) |
38 | | - if hasattr(model, 'model_state_dict'): |
39 | | - first_layers = str(list(model.model_state_dict().keys())[:3]) |
40 | | - else: |
41 | | - first_layers = str(list(actual_model.state_dict().keys())[:3]) |
42 | | - else: |
43 | | - # Direct model |
44 | | - model_type = type(model).__name__ |
45 | | - model_size = sum(p.numel() * p.element_size() for p in model.parameters()) |
46 | | - first_layers = str(list(model.state_dict().keys())[:3]) |
47 | | - |
48 | | - identifier = f"{model_type}_{model_size}_{first_layers}" |
49 | | - final_hash = hashlib.sha256(identifier.encode()).hexdigest() |
50 | | - |
51 | | - # DEBUG STATEMENT - ALWAYS LOG THE HASH |
52 | | - logger.debug(f"[MultiGPU DisTorch V2] Created hash for {caller}: {final_hash[:8]}...") |
53 | | - return final_hash |
54 | | - |
55 | 23 | def register_patched_safetensor_modelpatcher(): |
56 | 24 | """Register and patch the ModelPatcher for distributed safetensor loading""" |
57 | 25 | from comfy.model_patcher import wipe_lowvram_weight, move_weight_functions |
@@ -128,23 +96,35 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False |
128 | 96 | device = loaded_model.device |
129 | 97 | base_memory = loaded_model.model_memory_required(device) |
130 | 98 |
|
131 | | - # Check DisTorch flags |
132 | | - is_distorch = hasattr(loaded_model.model.model, '_mgpu_virtual_vram_gb') |
133 | | - has_eject = hasattr(loaded_model.model.model, '_mgpu_eject_models') |
134 | | - |
135 | | - if has_eject: |
136 | | - eject_device = device |
137 | | - logger.mgpu_mm_log("DisTorch eject_models=True, is_distorch=True - MAX memory eviction") |
138 | | - |
139 | | - if is_distorch: |
140 | | - # is_distorch=True: use compute device allocation size |
141 | | - virtual_vram_gb = loaded_model.model.model._mgpu_virtual_vram_gb |
| 99 | + inner_model = loaded_model.model.model |
| 100 | + |
| 101 | + if hasattr(inner_model, '_distorch_v2_meta'): |
| 102 | + meta = inner_model._distorch_v2_meta |
| 103 | + allocation_str = meta['full_allocation'] |
| 104 | + |
| 105 | + # Parse allocation string: "expert#compute_device;virtual_vram_gb;donors" |
| 106 | + parts = allocation_str.split('#') |
| 107 | + virtual_vram_gb = 0.0 |
| 108 | + has_eject = False |
| 109 | + |
| 110 | + if len(parts) > 1: |
| 111 | + virtual_vram_str = parts[1] |
| 112 | + virtual_info = virtual_vram_str.split(';') |
| 113 | + if len(virtual_info) > 1: |
| 114 | + virtual_vram_gb = float(virtual_info[1]) |
| 115 | + if len(virtual_info) > 2 and virtual_info[2]: |
| 116 | + has_eject = True |
| 117 | + |
| 118 | + if has_eject: |
| 119 | + eject_device = device |
| 120 | + logger.mgpu_mm_log("DisTorch eject_models detected - MAX memory eviction") |
| 121 | + |
142 | 122 | virtual_vram_bytes = virtual_vram_gb * (1024**3) |
143 | 123 | adjusted_memory = max(0, base_memory - virtual_vram_bytes) |
144 | 124 | total_memory_required[device] = total_memory_required.get(device, 0) + adjusted_memory |
145 | | - logger.mgpu_mm_log(f"DisTorch is_distorch=True, model adjusted {(base_memory - virtual_vram_bytes)/(1024**3):.2f}GB for device {device}") |
| 125 | + logger.mgpu_mm_log(f"DisTorch model adjusted {(base_memory - virtual_vram_bytes)/(1024**3):.2f}GB for device {device}") |
146 | 126 | else: |
147 | | - # is_distorch=False: use full model size |
| 127 | + # Standard model: use full model size |
148 | 128 | total_memory_required[device] = total_memory_required.get(device, 0) + base_memory |
149 | 129 | logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] Standard model {(base_memory)/(1024**3):.2f}GB for device {device}") |
150 | 130 |
|
@@ -209,23 +189,24 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False |
209 | 189 | original_partially_load = comfy.model_patcher.ModelPatcher.partially_load |
210 | 190 |
|
211 | 191 | def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_patch_weights=False, **kwargs): |
212 | | - """Override to use our static device assignments""" |
213 | | - global safetensor_allocation_store |
214 | | - |
215 | | - debug_hash = create_safetensor_model_hash(self, "partial_load") |
216 | | - multigpu_memory_log(f"safetensor:{debug_hash[:8]}", "pre-load") |
217 | | - allocations = safetensor_allocation_store.get(debug_hash) |
218 | | - |
219 | | - # Set default precision flag before checking |
220 | | - if not hasattr(self.model, '_distorch_high_precision_loras'): |
221 | | - self.model._distorch_high_precision_loras = True |
222 | | - |
223 | | - if not allocations: |
| 192 | + """Override to use direct model annotation for allocation""" |
| 193 | + |
| 194 | + mp_id = id(self) |
| 195 | + mp_patches_uuid = self.patches_uuid |
| 196 | + inner_model = self.model |
| 197 | + inner_model_id = id(inner_model) |
| 198 | + |
| 199 | + if not hasattr(inner_model, "_distorch_v2_meta"): |
| 200 | + logger.debug(f"[DISTORCH_SKIP] ModelPatcher=0x{mp_id:x} inner_model=0x{inner_model_id:x} type={type(inner_model).__name__} - no metadata, using standard loading") |
224 | 201 | result = original_partially_load(self, device_to, extra_memory, force_patch_weights) |
225 | | - multigpu_memory_log(f"safetensor:{debug_hash[:8]}", "post-load") |
226 | 202 | if hasattr(self, '_distorch_block_assignments'): |
227 | 203 | del self._distorch_block_assignments |
228 | 204 | return result |
| 205 | + |
| 206 | + allocations = inner_model._distorch_v2_meta['full_allocation'] |
| 207 | + |
| 208 | + if not hasattr(self.model, '_distorch_high_precision_loras'): |
| 209 | + self.model._distorch_high_precision_loras = True |
229 | 210 |
|
230 | 211 | if not hasattr(self.model, 'current_weight_patches_uuid'): |
231 | 212 | self.model.current_weight_patches_uuid = None |
@@ -308,7 +289,6 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p |
308 | 289 |
|
309 | 290 | logger.info("[MultiGPU DisTorch V2] DisTorch loading completed.") |
310 | 291 | logger.info(f"[MultiGPU DisTorch V2] Total memory: {mem_counter / (1024 * 1024):.2f}MB") |
311 | | - multigpu_memory_log(f"safetensor:{debug_hash[:8]}", "post-load") |
312 | 292 |
|
313 | 293 | return 0 |
314 | 294 |
|
|
0 commit comments