|
1 | 1 | """ |
2 | | -Device detection and management utilities for ComfyUI-MultiGPU. |
3 | | -Single source of truth for all device enumeration and compatibility checks. |
| 2 | +Device detection, management, and inspection utilities for ComfyUI-MultiGPU. |
| 3 | +Single source of truth for all device enumeration, compatibility checks, and state inspection. |
4 | 4 | Handles all device types supported by ComfyUI core. |
5 | 5 | """ |
6 | 6 |
|
7 | 7 | import torch |
8 | 8 | import logging |
| 9 | +import hashlib |
| 10 | +import psutil |
| 11 | +import comfy.model_management as mm |
9 | 12 |
|
10 | 13 | logger = logging.getLogger("MultiGPU") |
11 | 14 |
|
@@ -269,3 +272,274 @@ def soft_empty_cache_multigpu(logger): |
269 | 272 | elif device_str.startswith("corex:"): |
270 | 273 | torch.corex.empty_cache() # Hypothetical based on ComfyUI's ixuca support |
271 | 274 | logger.debug("[MultiGPU_Device_Utils] Cleared cache for CoreX") |
| 275 | + |
| 276 | + |
| 277 | +# ========================================================================================== |
| 278 | +# Model Management Inspection Utilities (End-to-End Tracking) |
| 279 | +# ========================================================================================== |
| 280 | + |
def create_model_identifier(model_patcher):
    """Build a short, stable identifier for a model patcher from its type and byte size.

    Returns "N/A (Detached)" when no underlying model is attached, and
    "<Type> (ID_Err)" when the size could not be determined safely.
    """
    if not model_patcher or not model_patcher.model:
        return "N/A (Detached)"

    inner_model = model_patcher.model
    type_name = type(inner_model).__name__

    # Fast path: the ModelPatcher usually already knows its own size.
    try:
        size_bytes = model_patcher.model_size()
    except Exception:
        size_bytes = 0

    # Slow path: if the fast path failed or reported 0, walk the tensors
    # ourselves without triggering hooks or loads.
    if size_bytes == 0:
        try:
            with model_patcher.use_ejected(skip_and_inject_on_exit_only=True):
                # Parameters AND buffers both consume memory; deduplicate by
                # data_ptr() so tied/shared weights are only counted once.
                counted_ptrs = set()
                for tensor in list(inner_model.parameters()) + list(inner_model.buffers()):
                    ptr = tensor.data_ptr()
                    if ptr in counted_ptrs:
                        continue
                    counted_ptrs.add(ptr)
                    size_bytes += tensor.numel() * tensor.element_size()
        except Exception as e:
            logger.debug(f"[MultiGPU_Inspection] Error during safe size calculation for identifier: {e}")
            return f"{type_name} (ID_Err)"

    # Hash type + size into a short fingerprint for compact log output.
    digest = hashlib.sha256(f"{type_name}_{size_bytes}".encode()).hexdigest()
    return f"{type_name} ({digest[:8]})"
| 316 | + |
| 317 | + |
def analyze_tensor_locations(model_patcher):
    """
    Report where a model's tensors physically live.

    Walks every parameter and buffer of the patcher's model and groups them by
    the device that actually holds the data, deduplicating shared storage so
    tied weights are counted once. This is the Ground Truth location of the data.

    Returns:
        tuple: (device_summary, total_memory_bytes) where device_summary maps a
        device string to {'tensors': count, 'memory': bytes}. On failure the
        first element is {"error": message} and the total is 0.
    """
    if not model_patcher or not model_patcher.model:
        return {"error": "Model not available"}, 0

    model = model_patcher.model
    placement = {}
    visited_ptrs = set()
    grand_total = 0

    # Eject first so we never interfere with injections/hooks or trigger
    # unintended loads (e.g. standard LowVRAM mode) while merely reading
    # tensor metadata.
    try:
        with model_patcher.use_ejected(skip_and_inject_on_exit_only=True):
            # Parameters and buffers both occupy memory; scan them together.
            for tensor in list(model.parameters()) + list(model.buffers()):
                ptr = tensor.data_ptr()
                if ptr in visited_ptrs:
                    # Shared/tied underlying storage — already accounted for.
                    continue
                visited_ptrs.add(ptr)

                if tensor.numel() == 0:
                    continue

                nbytes = tensor.numel() * tensor.element_size()
                grand_total += nbytes

                # Custom tensor subclasses (e.g. NF4 quantization) may not
                # expose a .device attribute.
                dev = str(tensor.device) if hasattr(tensor, 'device') else "Unknown/Managed"
                bucket = placement.setdefault(dev, {'tensors': 0, 'memory': 0})
                bucket['tensors'] += 1
                bucket['memory'] += nbytes
    except Exception as e:
        logger.error(f"[MultiGPU_Inspection] Error during tensor location analysis: {e}")
        return {"error": str(e)}, 0

    return placement, grand_total
| 373 | + |
| 374 | + |
def inspect_model_management_state(context_description=""):
    """
    Provides a detailed, structured overview of the current state of ComfyUI's model management,
    including memory usage across all devices and the status, location, and patching of all loaded models.

    Call this function anywhere in the code to get an immediate snapshot of the system state.

    Args:
        context_description (str): Free-form label echoed in the log banner so
            readers can tell which call site produced this snapshot.

    Returns:
        None. All output goes through the module-level "MultiGPU" logger.
    """

    # Ensure logger configuration (handles calls before full MultiGPU init if needed)
    if not logger.handlers:
        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        # Default to INFO if log level isn't set by main __init__.py
        if logger.level == logging.NOTSET:
            logger.setLevel(logging.INFO)

    # We inspect the state without forcing GC or cache clearing, which might alter the state we want to observe.

    logger.info("\n" + "=" * 100)
    logger.info(f" INSPECTION: ComfyUI Model Management State [Context: {context_description}]")
    logger.info("=" * 100)

    # 1. Device Memory Overview
    # Provides context on available resources across the system.
    logger.info("--- [1] System Device Memory Overview (GB) ---")
    # Sys Free: Memory available to the OS. Torch Alloc: Memory reserved by PyTorch (Active + Cache).
    fmt_mem = "{:<12} | {:>10} | {:>10} | {:>10} | {:>15}"
    logger.info(fmt_mem.format("Device", "Total", "Sys Free", "Used", "Torch Alloc"))
    logger.info("-" * 70)

    all_devices = get_device_list()
    # Sort devices for consistent display (CPU last)
    sorted_devices = sorted(all_devices, key=lambda d: (d == 'cpu', d))

    for dev_str in sorted_devices:
        try:
            device = torch.device(dev_str)

            if dev_str == "cpu":
                # Host RAM comes from psutil; psutil reports bytes.
                vm = psutil.virtual_memory()
                mem_total, mem_free_sys, mem_used = vm.total, vm.available, vm.used
                torch_alloc = 0  # Difficult to track accurately for CPU globally
            else:
                # Use ComfyUI's management functions which account for different backends (CUDA, XPU, etc.)
                mem_total = mm.get_total_memory(device)

                # get_free_memory returns (system_free, torch_cache_free)
                # NOTE(review): assumed tuple shape when torch_free_too=True — confirm
                # against comfy.model_management.get_free_memory for each backend.
                free_info = mm.get_free_memory(device, torch_free_too=True)
                if isinstance(free_info, tuple):
                    mem_free_sys = free_info[0]
                else:
                    mem_free_sys = free_info  # Fallback for backends that return single value (like MPS)

                mem_used = mem_total - mem_free_sys

                # Determine Torch Allocation (Reserved memory) - Specific checks for known backends
                torch_alloc = 0
                if device.type == 'cuda' and hasattr(torch.cuda, 'memory_stats'):
                    stats = torch.cuda.memory_stats(device)
                    torch_alloc = stats.get('reserved_bytes.all.current', 0)
                elif device.type == 'xpu' and hasattr(torch, 'xpu') and hasattr(torch.xpu, 'memory_stats'):
                    stats = torch.xpu.memory_stats(device)
                    torch_alloc = stats.get('reserved_bytes.all.current', 0)
                elif device.type == 'npu' and hasattr(torch, 'npu') and hasattr(torch.npu, 'memory_stats'):
                    stats = torch.npu.memory_stats(device)
                    torch_alloc = stats.get('reserved_bytes.all.current', 0)
                elif device.type == 'mlu' and hasattr(torch, 'mlu') and hasattr(torch.mlu, 'memory_stats'):
                    stats = torch.mlu.memory_stats(device)
                    torch_alloc = stats.get('reserved_bytes.all.current', 0)
                # MPS, DirectML, CoreX do not always expose detailed reserved memory stats easily.

            logger.info(fmt_mem.format(
                dev_str,
                f"{mem_total / (1024**3):.2f}",
                f"{mem_free_sys / (1024**3):.2f}",
                f"{mem_used / (1024**3):.2f}",
                f"{torch_alloc / (1024**3):.2f}"
            ))
        except Exception as e:
            # Best effort: a device we cannot query should not abort the whole report.
            logger.debug(f"Could not retrieve memory stats for {dev_str}: {e}")

    logger.info("-" * 70)

    # 2. Loaded Models Inspection (Logical and Physical View)
    # mm.current_loaded_models holds the list of models ComfyUI is managing.
    loaded_models = mm.current_loaded_models
    logger.info(f"\n--- [2] Loaded Models Inspection (Count: {len(loaded_models)}) ---")

    if not loaded_models:
        logger.info("No models currently managed by comfy.model_management.")
        logger.info("=" * 100)
        return

    for i, lm in enumerate(loaded_models):
        logger.info(f"\nModel {i+1}/{len(loaded_models)}:")

        # Check lifecycle status
        mp = lm.model  # weakref call to ModelPatcher
        if mp is None:
            # ModelPatcher is gone. Check if the underlying model is still alive (potential leak)
            if lm.is_dead() and lm.real_model() is not None:
                logger.warning(f" [!] Status: LEAK DETECTED (Patcher GC'd, but underlying model {lm.real_model().__class__.__name__} persists)")
            else:
                logger.info(f" Status: Cleaned Up (Patcher and Model GC'd)")
            continue

        model_id = create_model_identifier(mp)
        logger.info(f" Identifier: {model_id}")
        logger.info(f" Status: {'Active (In Use)' if lm.currently_used else 'Idle (Cache)'}")

        # A. Logical View (What ComfyUI intends/tracks)
        logger.info(" [A] Logical View (ComfyUI Tracking):")

        # Devices: Target (Compute) vs Offload (Storage)
        logger.info(f" Devices: Target={lm.device} | Offload={mp.offload_device} | Current (Model.device)={mp.current_loaded_device()}")

        # Memory Footprint as tracked by the LoadedModel wrapper (bytes).
        mem_total = lm.model_memory()
        mem_loaded = lm.model_loaded_memory()
        mem_offloaded = lm.model_offloaded_memory()
        logger.info(f" Memory (MB): Total={mem_total/(1024**2):.2f} | Loaded (on Target)={mem_loaded/(1024**2):.2f} | Offloaded={mem_offloaded/(1024**2):.2f}")

        # Management Mode (LowVRAM/DisTorch)
        # model_lowvram indicates if ComfyUI is managing this model partially
        is_lowvram = getattr(mp.model, 'model_lowvram', False)
        lowvram_patches_pending = mp.lowvram_patch_counter()
        logger.info(f" Mode: {'Partial Load (LowVRAM/DisTorch)' if is_lowvram else 'Full Load'}")
        if is_lowvram:
            # This indicates how many weights are being managed by the partial loading system
            logger.info(f" Weights Managed by LowVRAM/DisTorch System: {lowvram_patches_pending}")

        # Patching (LoRAs, etc.) - Tracking Attach/Detach
        num_weight_patches = len(mp.patches)
        # Check the UUID applied to the actual weights vs the UUID defined in the patcher.
        # A mismatch (or None) means patches are defined but not yet baked into the weights.
        current_weight_uuid = getattr(mp.model, 'current_weight_patches_uuid', None)
        weights_synced = (mp.patches_uuid == current_weight_uuid) and (current_weight_uuid is not None)

        if num_weight_patches > 0:
            status = 'Applied & Synced' if weights_synced else 'Pending/Mismatch (Re-patch needed)'
            logger.info(f" Patches: {num_weight_patches} weight patches defined | Status: {status}")
            logger.info(f" UUIDs: Defined={str(mp.patches_uuid)[:8]}... | Applied={str(current_weight_uuid)[:8] if current_weight_uuid else 'None'}...")

        # B. Physical View (Ground Truth Tensor Locations)
        logger.info(" [B] Physical View (Ground Truth Tensor Locations):")
        device_summary, calculated_total_mem = analyze_tensor_locations(mp)

        # analyze_tensor_locations signals failure with an {"error": ...} dict;
        # real device keys are torch device strings so they never collide with "error".
        if "error" in device_summary:
            logger.error(f" Analysis Error: {device_summary['error']}")
            continue

        if not device_summary:
            logger.info(" No tensors found (e.g., fully offloaded CLIP or utility object).")
        else:
            # Sort devices (CPU last)
            sorted_devices = sorted(device_summary.keys(), key=lambda d: (d.startswith("cpu"), d))
            fmt_loc = " {:<15} | Tensors: {:>6} | Memory (MB): {:>10.2f} | Percent: {:>6.1f}%"
            for device in sorted_devices:
                data = device_summary[device]
                percent = (data['memory'] / calculated_total_mem) * 100 if calculated_total_mem > 0 else 0
                logger.info(fmt_loc.format(device, data['tensors'], data['memory']/(1024**2), percent))

        # Verification Check: physical tensor walk vs ComfyUI's logical accounting.
        if abs(calculated_total_mem - mem_total) > (1024*1024):  # Allow 1MB difference
            logger.warning(f" [!] Verification WARNING: Physical memory ({calculated_total_mem/(1024**2):.2f}MB) differs from logical memory ({mem_total/(1024**2):.2f}MB).")

        logger.info("-" * 100)

    logger.info("End of Inspection")
    logger.info("=" * 100)
0 commit comments