Skip to content

Commit fb6e2e6

Browse files
committed
refactor(distorch): Implement IS_CHANGED for robust model reloading
This commit refactors the model loading logic to properly integrate with ComfyUI's caching system.

- Implemented the `IS_CHANGED` class method, which creates a hash of the DisTorch-specific settings (e.g., `compute_device`, `virtual_vram_gb`).
- This allows ComfyUI to automatically detect when settings have changed and trigger a model reload, invalidating the cache correctly.
- Removed the previous manual and less reliable logic for unloading and reloading the model from within the `override` function.
- Set the default log level to "Engineering" to provide more detailed output during development.
1 parent e288152 commit fb6e2e6

2 files changed

Lines changed: 32 additions & 28 deletions

File tree

__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
# --- DisTorch V2 Logging Configuration ---
1111
# Set to "E" for Engineering (DEBUG) or "P" for Production (INFO)
12-
LOG_LEVEL = "P"
12+
LOG_LEVEL = "E"
1313

1414
# Configure logger
1515
log_level = logging.DEBUG if LOG_LEVEL == "E" else logging.INFO

distorch_2.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
import hashlib
1111
import copy
12+
import inspect
1213
from collections import defaultdict
1314
import comfy.model_management as mm
1415
import comfy.model_patcher
@@ -441,6 +442,13 @@ def INPUT_TYPES(s):
441442
FUNCTION = "override"
442443
TITLE = f"{cls.TITLE if hasattr(cls, 'TITLE') else cls.__name__} (DisTorch2)"
443444

445+
@classmethod
446+
def IS_CHANGED(s, *args, compute_device=None, virtual_vram_gb=4.0,
447+
donor_device="cpu", expert_mode_allocations="", **kwargs):
448+
# Create a hash of our specific settings
449+
settings_str = f"{compute_device}{virtual_vram_gb}{donor_device}{expert_mode_allocations}"
450+
return hashlib.sha256(settings_str.encode()).hexdigest()
451+
444452
def override(self, *args, compute_device=None, virtual_vram_gb=4.0,
445453
donor_device="cpu", expert_mode_allocations="", **kwargs):
446454
from . import set_current_device
@@ -449,57 +457,53 @@ def override(self, *args, compute_device=None, virtual_vram_gb=4.0,
449457

450458
# Register our patched ModelPatcher
451459
register_patched_safetensor_modelpatcher()
452-
453-
# Build allocation string - EXACTLY like GGUF
454-
vram_string = ""
455-
if virtual_vram_gb > 0:
456-
vram_string = f"{compute_device};{virtual_vram_gb};{donor_device}"
457-
458-
full_allocation = f"{expert_mode_allocations}#{vram_string}" if expert_mode_allocations or vram_string else ""
459460

460-
# --- Force Model Reload on Setting Change ---
461-
# Create a hash of the DisTorch settings
462-
settings_str = f"{compute_device}{virtual_vram_gb}{donor_device}{expert_mode_allocations}"
463-
settings_hash = hashlib.sha256(settings_str.encode()).hexdigest()[:8]
464-
465-
# Temporarily load the model to get its hash, without applying our patch yet
461+
# Call original function
466462
fn = getattr(super(), cls.FUNCTION)
467-
temp_out = fn(*args, **kwargs)
468463

464+
# --- Check if we need to unload the model due to settings change ---
465+
# This logic is a bit redundant with IS_CHANGED, but provides clear logging
466+
settings_str = f"{compute_device}{virtual_vram_gb}{donor_device}{expert_mode_allocations}"
467+
settings_hash = hashlib.sha256(settings_str.encode()).hexdigest()
468+
469+
# Temporarily load to get hash without applying our patch
470+
temp_out = fn(*args, **kwargs)
469471
model_to_check = None
470472
if hasattr(temp_out[0], 'model'):
471473
model_to_check = temp_out[0]
472474
elif hasattr(temp_out[0], 'patcher') and hasattr(temp_out[0].patcher, 'model'):
473475
model_to_check = temp_out[0].patcher
474-
476+
475477
if model_to_check:
476478
model_hash = create_safetensor_model_hash(model_to_check, "override_check")
477-
478479
last_settings_hash = safetensor_settings_store.get(model_hash)
479480

480481
if last_settings_hash != settings_hash:
481-
logging.info(f"[MultiGPU_DisTorch2] Settings changed for model {model_hash[:8]}. Forcing reload.")
482-
mm.unload_model(model_to_check)
483-
# Update the settings store *before* reloading
484-
safetensor_settings_store[model_hash] = settings_hash
485-
# Call the loader again now that the model is unloaded
486-
out = fn(*args, **kwargs)
482+
logging.info(f"[MultiGPU_DisTorch2] Settings changed for model {model_hash[:8]}. Previous settings hash: {last_settings_hash}, New settings hash: {settings_hash}. Forcing reload.")
483+
# The IS_CHANGED mechanism should handle the reload, this is for logging.
487484
else:
488-
out = temp_out # Use the already loaded model
489-
else:
490-
out = temp_out # Should not happen, but as a fallback
485+
logging.info(f"[MultiGPU_DisTorch2] Settings unchanged for model {model_hash[:8]}. Using cached model.")
486+
487+
out = fn(*args, **kwargs)
491488

489+
# Build allocation string - EXACTLY like GGUF
490+
vram_string = ""
491+
if virtual_vram_gb > 0:
492+
vram_string = f"{compute_device};{virtual_vram_gb};{donor_device}"
493+
494+
full_allocation = f"{expert_mode_allocations}#{vram_string}" if expert_mode_allocations or vram_string else ""
495+
492496
logging.info(f"[MULTIGPU_DISTORCHV2] Full allocation string: {full_allocation}")
493497

494498
# Store allocation for the model - EXACTLY like GGUF
495499
if hasattr(out[0], 'model'):
496500
model_hash = create_safetensor_model_hash(out[0], "override")
497501
safetensor_allocation_store[model_hash] = full_allocation
498-
safetensor_settings_store[model_hash] = settings_hash # Ensure it's set
502+
safetensor_settings_store[model_hash] = settings_hash
499503
elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'):
500504
model_hash = create_safetensor_model_hash(out[0].patcher, "override")
501505
safetensor_allocation_store[model_hash] = full_allocation
502-
safetensor_settings_store[model_hash] = settings_hash # Ensure it's set
506+
safetensor_settings_store[model_hash] = settings_hash
503507

504508
return out
505509

0 commit comments

Comments (0)