Skip to content

Commit dde3480

Browse files
authored
Merge pull request #117 - refactor: Replace keep_loaded with eject_models Boolean for improved memory management
2 parents a1b7b1f + 0919fd4 commit dde3480

22 files changed

Lines changed: 215 additions & 250 deletions

__init__.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
)
2222

2323
WEB_DIRECTORY = "./web"
24-
MGPU_MM_LOG = False
24+
MGPU_MM_LOG = False
2525
DEBUG_LOG = False
2626

2727
logger = logging.getLogger("MultiGPU")
@@ -85,12 +85,12 @@ def text_encoder_device_patched():
8585
else:
8686
devs = set(get_device_list())
8787
device = torch.device(current_text_encoder_device) if str(current_text_encoder_device) in devs else torch.device("cpu")
88-
logger.debug(f"[MultiGPU Core Patching] text_encoder_device_patched returning device: {device} (current_text_encoder_device={current_text_encoder_device})")
88+
logger.info(f"[MultiGPU Core Patching] text_encoder_device_patched returning device: {device} (current_text_encoder_device={current_text_encoder_device})")
8989
return device
9090

9191
logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device and mm.text_encoder_device")
92-
logger.debug(f"[MultiGPU DEBUG] Initial current_device: {current_device}")
93-
logger.debug(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}")
92+
logger.info(f"[MultiGPU DEBUG] Initial current_device: {current_device}")
93+
logger.info(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}")
9494
mm.get_torch_device = get_torch_device_patched
9595
mm.text_encoder_device = text_encoder_device_patched
9696

distorch_2.py

Lines changed: 138 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -58,51 +58,149 @@ def register_patched_safetensor_modelpatcher():
5858
# Patch ComfyUI's ModelPatcher
5959
if not hasattr(comfy.model_patcher.ModelPatcher, '_distorch_patched'):
6060

61-
# Patch LoadedModel.model_memory_required to drive behavior purely by Phase 2 = unload_distorch_model flag
62-
from comfy.model_management import current_loaded_models
6361

64-
original_loaded_model_memory_required = None
65-
for cls in current_loaded_models.__class__.__mro__:
66-
if hasattr(cls, 'model_memory_required'):
67-
original_loaded_model_memory_required = cls.model_memory_required
68-
break
69-
70-
if original_loaded_model_memory_required is None:
71-
# Global patch of LoadedModel class if available
72-
import comfy.model_management as mm
73-
74-
original_loaded_model_memory_required = mm.LoadedModel.model_memory_required
75-
76-
def patched_loaded_model_memory_required(self, device):
77-
"""Drive unload behavior purely by unload_distorch_model flag"""
78-
multigpu_memory_log("unload_distorch_model_memory_check", "start")
79-
logger.mgpu_mm_log(f"[IS_DISTORCH_MODEL] Memory assessment requested for model on device: {device}")
80-
81-
# Check if this is a DisTorch model with unload_distorch_model flag
82-
is_distorch_model = hasattr(getattr(getattr(self, 'model', None), 'model', None), '_mgpu_unload_distorch_model')
62+
# PATCH load_models_gpu with correct memory calculations per model flags
63+
original_load_models_gpu = mm.load_models_gpu
8364

84-
model_name = type(getattr(getattr(self, 'model', None), 'model', None)).__name__ if getattr(getattr(self, 'model', None), 'model', None) else "Unknown"
85-
logger.mgpu_mm_log(f"[IS_DISTORCH_MODEL] DisTorch model: {model_name}, is_distorch_model={is_distorch_model}")
65+
def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
    """Load ``models`` after freeing device memory according to DisTorch flags.

    Installed in place of ``mm.load_models_gpu``. The per-device memory request
    is accumulated over every model to load:

    * a model whose inner model has ``_mgpu_virtual_vram_gb`` contributes its
      required memory minus that many GiB (clamped at zero);
    * any other model contributes its full required memory.

    A device hosting at least one model with a truthy ``_mgpu_eject_models``
    is asked to free its *total* memory instead of the computed request.

    Args:
        models: Iterable of model patchers to load; each must provide
            ``model_patches_models()``.
        memory_required: Extra bytes the caller expects to need for inference.
        force_patch_weights: Forwarded unchanged to ``LoadedModel.model_load``.
        minimum_memory_required: Bytes that must remain free on each non-CPU
            device; defaults to the computed inference headroom when ``None``.
        force_full_load: When true, skip the low-VRAM budget computation.

    Returns:
        None.
    """
    from comfy.model_management import cleanup_models_gc, get_free_memory, free_memory, current_loaded_models
    from comfy.model_management import VRAMState, vram_state, lowvram_available, MIN_WEIGHT_MEMORY_RATIO
    from comfy.model_management import minimum_inference_memory, extra_reserved_memory, is_device_cpu

    gib = 1024**3
    cpu_device = torch.device("cpu")

    multigpu_memory_log("load_models_gpu_top_level", "start")

    cleanup_models_gc()

    inference_memory = minimum_inference_memory()
    extra_reserved_mem = extra_reserved_memory()
    extra_mem = max(inference_memory, memory_required + extra_reserved_mem)
    if minimum_memory_required is None:
        minimum_memory_required = extra_mem
    else:
        minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_mem)

    # Expand the request with every model referenced by each patcher's patches.
    requested = set()
    for m in models:
        requested.add(m)
        requested.update(m.model_patches_models())
    models = requested

    # Reuse the existing LoadedModel wrapper when the model is already tracked.
    models_to_load = []
    for x in models:
        loaded_model = mm.LoadedModel(x)
        try:
            loaded_model_index = current_loaded_models.index(loaded_model)
        except ValueError:
            # list.index signals "not present" with ValueError; anything else
            # is a genuine error and must not be swallowed.
            loaded_model_index = None

        if loaded_model_index is not None:
            loaded = current_loaded_models[loaded_model_index]
            loaded.currently_used = True
            models_to_load.append(loaded)
        else:
            if hasattr(x, "model"):
                logging.info(f"Requested to load {x.model.__class__.__name__}")
            models_to_load.append(loaded_model)

    # Detach tracked clones of anything about to be (re)loaded.
    for loaded_model in models_to_load:
        clone_indices = [
            i for i, tracked in enumerate(current_loaded_models)
            if loaded_model.model.is_clone(tracked.model)
        ]
        # Pop from the highest index down so earlier pops do not shift later ones.
        for i in reversed(clone_indices):
            model_to_unload = current_loaded_models.pop(i)
            model_to_unload.model.detach(unpatch_all=False)
            model_to_unload.model_finalizer.detach()

    # DisTorch Processing
    total_memory_required = {}
    # Every device that hosts a model requesting ejection (not just the last one seen).
    eject_devices = set()

    for loaded_model in models_to_load:
        device = loaded_model.device
        base_memory = loaded_model.model_memory_required(device)

        # Check DisTorch flags. The patcher's inner ``.model`` is treated as
        # optional above (hasattr guard), so read it defensively here as well.
        inner_model = getattr(loaded_model.model, 'model', None)
        is_distorch = hasattr(inner_model, '_mgpu_virtual_vram_gb')
        # Honour the flag's value: an attribute explicitly set to False must not eject.
        wants_eject = bool(getattr(inner_model, '_mgpu_eject_models', False))

        if wants_eject:
            eject_devices.add(device)
            logger.mgpu_mm_log("DisTorch eject_models=True, is_distorch=True - MAX memory eviction")

        if is_distorch:
            # is_distorch=True: use compute device allocation size
            virtual_vram_bytes = inner_model._mgpu_virtual_vram_gb * gib
            adjusted_memory = max(0, base_memory - virtual_vram_bytes)
            total_memory_required[device] = total_memory_required.get(device, 0) + adjusted_memory
            # Report the clamped figure that is actually added to the request.
            logger.mgpu_mm_log(f"DisTorch is_distorch=True, model adjusted {adjusted_memory/(1024**3):.2f}GB for device {device}")
        else:
            # is_distorch=False: use full model size
            total_memory_required[device] = total_memory_required.get(device, 0) + base_memory
            logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] Standard model {(base_memory)/(1024**3):.2f}GB for device {device}")

    # Per-device request: accumulated model memory with a 10% margin plus inference headroom.
    requested_by_device = {
        device: required * 1.1 + extra_mem
        for device, required in total_memory_required.items()
        if device != cpu_device
    }

    for device, requested_mem in requested_by_device.items():
        logger.mgpu_mm_log(f"[FREE_MEMORY_CALL] Device {device}: requesting {requested_mem/(1024**3):.2f}GB = {total_memory_required[device]/(1024**3):.2f}GB * 1.1 + {extra_mem/(1024**3):.2f}GB inference")

    multigpu_memory_log("free_memory", "pre")

    for device, requested_mem in requested_by_device.items():
        if device in eject_devices:
            total_device_memory = mm.get_total_memory(device)
            logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] eject_models=1, is_distorch=1 → using MAX memory ({total_device_memory/(1024**3):.2f}GB) for eviction")
            free_memory(total_device_memory, device)
        else:
            logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] eject_models=0, using Comfy Core Computed memory ({requested_mem/(1024**3):.2f}GB) for eviction")
            free_memory(requested_mem, device)

    multigpu_memory_log("free_memory/minimum_memory_required", "post/pre")

    # Second pass: make sure the minimum headroom is really available.
    for device in requested_by_device:
        free_mem = get_free_memory(device)
        free_mem_gb = free_mem / gib
        min_required_gb = minimum_memory_required / gib
        logger.mgpu_mm_log(f"[MIN_MEMORY_CHECK] Device {device}: free={free_mem_gb:.2f}GB, required={min_required_gb:.2f}GB, will_evict={free_mem < minimum_memory_required}")

        if free_mem < minimum_memory_required:
            models_l = free_memory(minimum_memory_required, device)
            logger.mgpu_mm_log(f"[EVICTION] Device {device}: unloaded {len(models_l)} models due to insufficient memory")
            logging.info("{} models unloaded.".format(len(models_l)))

    multigpu_memory_log("minimum_memory_required", "post")

    for loaded_model in models_to_load:
        model = loaded_model.model
        torch_dev = model.load_device
        if is_device_cpu(torch_dev):
            vram_set_state = VRAMState.DISABLED
        else:
            vram_set_state = vram_state
        lowvram_model_memory = 0
        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
            loaded_memory = loaded_model.model_loaded_memory()
            # Memory this model already occupies counts as available to it.
            current_free_mem = get_free_memory(torch_dev) + loaded_memory

            lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
            lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)

        if vram_set_state == VRAMState.NO_VRAM:
            lowvram_model_memory = 0.1

        loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
        # Most recently loaded models go to the front of the tracking list.
        current_loaded_models.insert(0, loaded_model)
201+
202+
# Replace the module function
203+
mm.load_models_gpu = patched_load_models_gpu
106204

107205
original_partially_load = comfy.model_patcher.ModelPatcher.partially_load
108206

0 commit comments

Comments
 (0)