Skip to content

Commit aa00a68

Browse files
committed
feat: add model inspection utilities for tracking and analysis
- Extend module docstring to include inspection capabilities
- Add create_model_identifier() to generate unique hashes from model type and size
- Add analyze_tensor_locations() to analyze tensor device placement and memory usage
- Include imports for hashlib, psutil, and comfy.model_management to support new features

These utilities enable end-to-end tracking of model state and placement for better debugging and management in multi-GPU setups.
1 parent edc8a4d commit aa00a68

1 file changed

Lines changed: 276 additions & 2 deletions

File tree

device_utils.py

Lines changed: 276 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""
2-
Device detection and management utilities for ComfyUI-MultiGPU.
3-
Single source of truth for all device enumeration and compatibility checks.
2+
Device detection, management, and inspection utilities for ComfyUI-MultiGPU.
3+
Single source of truth for all device enumeration, compatibility checks, and state inspection.
44
Handles all device types supported by ComfyUI core.
55
"""
66

77
import torch
88
import logging
9+
import hashlib
10+
import psutil
11+
import comfy.model_management as mm
912

1013
logger = logging.getLogger("MultiGPU")
1114

@@ -269,3 +272,274 @@ def soft_empty_cache_multigpu(logger):
269272
elif device_str.startswith("corex:"):
270273
torch.corex.empty_cache() # Hypothetical based on ComfyUI's ixuca support
271274
logger.debug("[MultiGPU_Device_Utils] Cleared cache for CoreX")
275+
276+
277+
# ==========================================================================================
278+
# Model Management Inspection Utilities (End-to-End Tracking)
279+
# ==========================================================================================
280+
281+
def create_model_identifier(model_patcher):
    """Creates a concise, unique identifier for a model patcher based on type and size."""
    if not model_patcher or not model_patcher.model:
        return "N/A (Detached)"

    model = model_patcher.model
    model_type = type(model).__name__

    # Fast path: reuse the size ModelPatcher already computed for this model.
    try:
        model_size = model_patcher.model_size()
    except Exception:
        model_size = 0

    # Slow path: walk the tensors ourselves when the fast path failed or gave 0.
    if model_size == 0:
        try:
            # Eject first so the walk cannot trigger hooks or weight loads.
            with model_patcher.use_ejected(skip_and_inject_on_exit_only=True):
                counted_ptrs = set()
                # Parameters AND buffers both occupy memory; tied/shared
                # tensors are counted once by deduplicating on data_ptr().
                for tensor in list(model.parameters()) + list(model.buffers()):
                    ptr = tensor.data_ptr()
                    if ptr in counted_ptrs:
                        continue
                    counted_ptrs.add(ptr)
                    model_size += tensor.numel() * tensor.element_size()
        except Exception as e:
            logger.debug(f"[MultiGPU_Inspection] Error during safe size calculation for identifier: {e}")
            return f"{model_type} (ID_Err)"

    # Hash type + size so identical architectures at identical sizes share an ID.
    identifier = f"{model_type}_{model_size}"
    model_hash = hashlib.sha256(identifier.encode()).hexdigest()
    return f"{model_type} ({model_hash[:8]})"
316+
317+
318+
def analyze_tensor_locations(model_patcher):
    """
    Analyzes the physical device placement of model tensors (parameters and buffers).
    This provides the Ground Truth location of the data, handling shared weights correctly.
    """
    if not model_patcher or not model_patcher.model:
        return {"error": "Model not available"}, 0

    model = model_patcher.model
    device_summary = {}
    counted_ptrs = set()
    total_memory = 0

    # Eject before touching weights so we don't collide with injections,
    # hooks, or on-demand loads (e.g. standard LowVRAM mode).
    try:
        with model_patcher.use_ejected(skip_and_inject_on_exit_only=True):
            # Parameters (weights/biases) come first, then buffers (e.g.
            # batch-norm running stats) — both occupy memory.
            for tensor in list(model.parameters()) + list(model.buffers()):
                # data_ptr() identifies the underlying storage, so tied or
                # shared weights are only counted once.
                ptr = tensor.data_ptr()
                if ptr in counted_ptrs:
                    continue
                counted_ptrs.add(ptr)

                if tensor.numel() == 0:
                    continue

                tensor_mem = tensor.numel() * tensor.element_size()
                total_memory += tensor_mem

                # Custom tensor subclasses (e.g. NF4 quantization) may not
                # expose a plain device attribute.
                if hasattr(tensor, 'device'):
                    device = str(tensor.device)
                else:
                    device = "Unknown/Managed"

                entry = device_summary.setdefault(device, {'tensors': 0, 'memory': 0})
                entry['tensors'] += 1
                entry['memory'] += tensor_mem
    except Exception as e:
        logger.error(f"[MultiGPU_Inspection] Error during tensor location analysis: {e}")
        return {"error": str(e)}, 0

    return device_summary, total_memory
373+
374+
375+
def inspect_model_management_state(context_description=""):
    """
    Provides a detailed, structured overview of the current state of ComfyUI's model management,
    including memory usage across all devices and the status, location, and patching of all loaded models.

    Call this function anywhere in the code to get an immediate snapshot of the system state.

    Args:
        context_description (str): Free-form label echoed in the report header so
            log output can be correlated with the call site.

    Returns:
        None. All output goes to the "MultiGPU" logger.
    """

    # Ensure logger configuration (handles calls before full MultiGPU init if needed)
    if not logger.handlers:
        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        # Default to INFO if log level isn't set by main __init__.py
        if logger.level == logging.NOTSET:
            logger.setLevel(logging.INFO)

    # We inspect the state without forcing GC or cache clearing, which might alter the state we want to observe.

    logger.info("\n" + "=" * 100)
    logger.info(f" INSPECTION: ComfyUI Model Management State [Context: {context_description}]")
    logger.info("=" * 100)

    # 1. Device Memory Overview
    # Provides context on available resources across the system.
    logger.info("--- [1] System Device Memory Overview (GB) ---")
    # Sys Free: Memory available to the OS. Torch Alloc: Memory reserved by PyTorch (Active + Cache).
    fmt_mem = "{:<12} | {:>10} | {:>10} | {:>10} | {:>15}"
    logger.info(fmt_mem.format("Device", "Total", "Sys Free", "Used", "Torch Alloc"))
    logger.info("-" * 70)

    all_devices = get_device_list()
    # Sort devices for consistent display (CPU last)
    sorted_devices = sorted(all_devices, key=lambda d: (d == 'cpu', d))

    for dev_str in sorted_devices:
        try:
            device = torch.device(dev_str)

            if dev_str == "cpu":
                vm = psutil.virtual_memory()
                mem_total, mem_free_sys, mem_used = vm.total, vm.available, vm.used
                torch_alloc = 0  # Difficult to track accurately for CPU globally
            else:
                # Use ComfyUI's management functions which account for different backends (CUDA, XPU, etc.)
                mem_total = mm.get_total_memory(device)

                # get_free_memory returns (system_free, torch_cache_free)
                free_info = mm.get_free_memory(device, torch_free_too=True)
                if isinstance(free_info, tuple):
                    mem_free_sys = free_info[0]
                else:
                    mem_free_sys = free_info  # Fallback for backends that return single value (like MPS)

                mem_used = mem_total - mem_free_sys

                # Determine Torch Allocation (Reserved memory) - Specific checks for known backends
                torch_alloc = 0
                if device.type == 'cuda' and hasattr(torch.cuda, 'memory_stats'):
                    stats = torch.cuda.memory_stats(device)
                    torch_alloc = stats.get('reserved_bytes.all.current', 0)
                elif device.type == 'xpu' and hasattr(torch, 'xpu') and hasattr(torch.xpu, 'memory_stats'):
                    stats = torch.xpu.memory_stats(device)
                    torch_alloc = stats.get('reserved_bytes.all.current', 0)
                elif device.type == 'npu' and hasattr(torch, 'npu') and hasattr(torch.npu, 'memory_stats'):
                    stats = torch.npu.memory_stats(device)
                    torch_alloc = stats.get('reserved_bytes.all.current', 0)
                elif device.type == 'mlu' and hasattr(torch, 'mlu') and hasattr(torch.mlu, 'memory_stats'):
                    stats = torch.mlu.memory_stats(device)
                    torch_alloc = stats.get('reserved_bytes.all.current', 0)
                # MPS, DirectML, CoreX do not always expose detailed reserved memory stats easily.

            logger.info(fmt_mem.format(
                dev_str,
                f"{mem_total / (1024**3):.2f}",
                f"{mem_free_sys / (1024**3):.2f}",
                f"{mem_used / (1024**3):.2f}",
                f"{torch_alloc / (1024**3):.2f}"
            ))
        except Exception as e:
            # A single misbehaving backend should not abort the whole report.
            logger.debug(f"Could not retrieve memory stats for {dev_str}: {e}")

    logger.info("-" * 70)

    # 2. Loaded Models Inspection (Logical and Physical View)
    # mm.current_loaded_models holds the list of models ComfyUI is managing.
    loaded_models = mm.current_loaded_models
    logger.info(f"\n--- [2] Loaded Models Inspection (Count: {len(loaded_models)}) ---")

    if not loaded_models:
        logger.info("No models currently managed by comfy.model_management.")
        logger.info("=" * 100)
        return

    for i, lm in enumerate(loaded_models):
        logger.info(f"\nModel {i+1}/{len(loaded_models)}:")

        # Check lifecycle status
        mp = lm.model  # weakref call to ModelPatcher
        if mp is None:
            # ModelPatcher is gone. Check if the underlying model is still alive (potential leak)
            if lm.is_dead() and lm.real_model() is not None:
                logger.warning(f" [!] Status: LEAK DETECTED (Patcher GC'd, but underlying model {lm.real_model().__class__.__name__} persists)")
            else:
                logger.info(f" Status: Cleaned Up (Patcher and Model GC'd)")
            continue

        model_id = create_model_identifier(mp)
        logger.info(f" Identifier: {model_id}")
        logger.info(f" Status: {'Active (In Use)' if lm.currently_used else 'Idle (Cache)'}")

        # A. Logical View (What ComfyUI intends/tracks)
        logger.info(" [A] Logical View (ComfyUI Tracking):")

        # Devices: Target (Compute) vs Offload (Storage)
        logger.info(f" Devices: Target={lm.device} | Offload={mp.offload_device} | Current (Model.device)={mp.current_loaded_device()}")

        # Memory Footprint
        mem_total = lm.model_memory()
        mem_loaded = lm.model_loaded_memory()
        mem_offloaded = lm.model_offloaded_memory()
        logger.info(f" Memory (MB): Total={mem_total/(1024**2):.2f} | Loaded (on Target)={mem_loaded/(1024**2):.2f} | Offloaded={mem_offloaded/(1024**2):.2f}")

        # Management Mode (LowVRAM/DisTorch)
        # model_lowvram indicates if ComfyUI is managing this model partially
        is_lowvram = getattr(mp.model, 'model_lowvram', False)
        lowvram_patches_pending = mp.lowvram_patch_counter()
        logger.info(f" Mode: {'Partial Load (LowVRAM/DisTorch)' if is_lowvram else 'Full Load'}")
        if is_lowvram:
            # This indicates how many weights are being managed by the partial loading system
            logger.info(f" Weights Managed by LowVRAM/DisTorch System: {lowvram_patches_pending}")

        # Patching (LoRAs, etc.) - Tracking Attach/Detach
        num_weight_patches = len(mp.patches)
        # Check the UUID applied to the actual weights vs the UUID defined in the patcher
        current_weight_uuid = getattr(mp.model, 'current_weight_patches_uuid', None)
        weights_synced = (mp.patches_uuid == current_weight_uuid) and (current_weight_uuid is not None)

        if num_weight_patches > 0:
            status = 'Applied & Synced' if weights_synced else 'Pending/Mismatch (Re-patch needed)'
            logger.info(f" Patches: {num_weight_patches} weight patches defined | Status: {status}")
            logger.info(f" UUIDs: Defined={str(mp.patches_uuid)[:8]}... | Applied={str(current_weight_uuid)[:8] if current_weight_uuid else 'None'}...")

        # B. Physical View (Ground Truth Tensor Locations)
        logger.info(" [B] Physical View (Ground Truth Tensor Locations):")
        device_summary, calculated_total_mem = analyze_tensor_locations(mp)

        if "error" in device_summary:
            logger.error(f" Analysis Error: {device_summary['error']}")
            continue

        if not device_summary:
            logger.info(" No tensors found (e.g., fully offloaded CLIP or utility object).")
        else:
            # Sort devices (CPU last)
            sorted_devices = sorted(device_summary.keys(), key=lambda d: (d.startswith("cpu"), d))
            fmt_loc = " {:<15} | Tensors: {:>6} | Memory (MB): {:>10.2f} | Percent: {:>6.1f}%"
            for device in sorted_devices:
                data = device_summary[device]
                percent = (data['memory'] / calculated_total_mem) * 100 if calculated_total_mem > 0 else 0
                logger.info(fmt_loc.format(device, data['tensors'], data['memory']/(1024**2), percent))

        # Verification Check
        # NOTE(review): compares the physical tensor walk against ComfyUI's
        # logical accounting; presumably a mismatch > 1MB signals stale
        # tracking or externally-managed tensors — confirm against
        # LoadedModel.model_memory() semantics.
        if abs(calculated_total_mem - mem_total) > (1024*1024):  # Allow 1MB difference
            logger.warning(f" [!] Verification WARNING: Physical memory ({calculated_total_mem/(1024**2):.2f}MB) differs from logical memory ({mem_total/(1024**2):.2f}MB).")

        logger.info("-" * 100)

    logger.info("End of Inspection")
    logger.info("=" * 100)

0 commit comments

Comments
 (0)