@@ -58,51 +58,149 @@ def register_patched_safetensor_modelpatcher():
5858 # Patch ComfyUI's ModelPatcher
5959 if not hasattr (comfy .model_patcher .ModelPatcher , '_distorch_patched' ):
6060
# PATCH load_models_gpu with correct memory calculations per model flags.
# Replaces the previous approach of patching LoadedModel.model_memory_required;
# instead the whole top-level loader is wrapped so DisTorch models can report
# adjusted memory needs and optionally force a full-device eviction.
original_load_models_gpu = mm.load_models_gpu  # kept for potential restoration elsewhere in this module

def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
    """Drop-in replacement for comfy.model_management.load_models_gpu.

    Mirrors ComfyUI core's loading sequence (dedupe -> unload clones ->
    per-device free_memory -> minimum-memory eviction -> model_load) but
    computes per-device memory requirements from DisTorch flags stamped on
    the inner model:

    - ``_mgpu_virtual_vram_gb`` present: the model is a DisTorch model; its
      compute-device requirement is reduced by the virtual VRAM allocation.
    - ``_mgpu_eject_models`` present: the model's device is evicted using the
      device's TOTAL memory (maximum eviction).

    NOTE(review): both flags are detected with hasattr() only — the stored
    value is never read, so setting the attribute to False still triggers the
    behavior. Confirm the flags are only ever set when intended to be active.

    Signature matches ComfyUI core so existing callers are unaffected.
    """
    from comfy.model_management import cleanup_models_gc, get_free_memory, free_memory, current_loaded_models
    from comfy.model_management import VRAMState, vram_state, lowvram_available, MIN_WEIGHT_MEMORY_RATIO
    from comfy.model_management import minimum_inference_memory, extra_reserved_memory, is_device_cpu

    multigpu_memory_log("load_models_gpu_top_level", "start")

    cleanup_models_gc()

    # Reserve headroom: the larger of the global inference floor and the
    # caller's request plus the extra reserved margin (same math as core).
    inference_memory = minimum_inference_memory()
    extra_reserved_mem = extra_reserved_memory()
    memory_required_total = memory_required + extra_reserved_mem
    extra_mem = max(inference_memory, memory_required_total)
    if minimum_memory_required is None:
        minimum_memory_required = extra_mem
    else:
        minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_mem)

    # Expand the request to include each model's patch models, deduplicated.
    models_temp = set()
    for m in models:
        models_temp.add(m)
        for patch_model in m.model_patches_models():
            models_temp.add(patch_model)
    models = models_temp

    # Reuse already-loaded entries where possible; wrap the rest.
    models_to_load = []
    for x in models:
        loaded_model = mm.LoadedModel(x)
        try:
            loaded_model_index = current_loaded_models.index(loaded_model)
        except ValueError:  # fix: was a bare except; index() raises ValueError when absent
            loaded_model_index = None

        if loaded_model_index is not None:
            loaded = current_loaded_models[loaded_model_index]
            loaded.currently_used = True
            models_to_load.append(loaded)
        else:
            if hasattr(x, "model"):
                logging.info(f"Requested to load {x.model.__class__.__name__}")
            models_to_load.append(loaded_model)

    # Detach any already-loaded clones of the models about to be loaded.
    # Indices are collected in descending order so pop() stays valid.
    for loaded_model in models_to_load:
        to_unload = []
        for i in range(len(current_loaded_models)):
            if loaded_model.model.is_clone(current_loaded_models[i].model):
                to_unload = [i] + to_unload
        for i in to_unload:
            model_to_unload = current_loaded_models.pop(i)
            model_to_unload.model.detach(unpatch_all=False)
            model_to_unload.model_finalizer.detach()

    # DisTorch processing: per-device memory tally driven by model flags.
    total_memory_required = {}
    eject_device = None

    for loaded_model in models_to_load:
        device = loaded_model.device
        base_memory = loaded_model.model_memory_required(device)

        # Flags are stamped on the inner torch model by the DisTorch nodes.
        is_distorch = hasattr(loaded_model.model.model, '_mgpu_virtual_vram_gb')
        has_eject = hasattr(loaded_model.model.model, '_mgpu_eject_models')

        if has_eject:
            eject_device = device
            logger.mgpu_mm_log("DisTorch eject_models=True, is_distorch=True - MAX memory eviction")

        if is_distorch:
            # is_distorch=True: only the compute-device share counts; the
            # virtual VRAM portion lives on donor devices.
            virtual_vram_gb = loaded_model.model.model._mgpu_virtual_vram_gb
            virtual_vram_bytes = virtual_vram_gb * (1024 ** 3)
            adjusted_memory = max(0, base_memory - virtual_vram_bytes)
            total_memory_required[device] = total_memory_required.get(device, 0) + adjusted_memory
            logger.mgpu_mm_log(f"DisTorch is_distorch=True, model adjusted {(base_memory - virtual_vram_bytes) / (1024 ** 3):.2f}GB for device {device}")
        else:
            # is_distorch=False: charge the full model size to its device.
            total_memory_required[device] = total_memory_required.get(device, 0) + base_memory
            logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] Standard model {base_memory / (1024 ** 3):.2f}GB for device {device}")

    # Log the per-device request (core's 1.1x safety factor plus headroom).
    for device in total_memory_required:
        if device != torch.device("cpu"):
            requested_mem = total_memory_required[device] * 1.1 + extra_mem
            logger.mgpu_mm_log(f"[FREE_MEMORY_CALL] Device {device}: requesting {requested_mem / (1024 ** 3):.2f}GB = {total_memory_required[device] / (1024 ** 3):.2f}GB * 1.1 + {extra_mem / (1024 ** 3):.2f}GB inference")

    multigpu_memory_log("free_memory", "pre")

    # Free memory per device; the eject device requests its TOTAL memory so
    # everything resident there is evicted.
    for device in total_memory_required:
        if device != torch.device("cpu"):
            if device == eject_device:
                total_device_memory = mm.get_total_memory(device)
                logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] eject_models=1, is_distorch=1 → using MAX memory ({total_device_memory / (1024 ** 3):.2f}GB) for eviction")
                free_memory(total_device_memory, device)
            else:
                logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] eject_models=0, using Comfy Core Computed memory ({(total_memory_required[device] * 1.1 + extra_mem) / (1024 ** 3):.2f}GB) for eviction")
                free_memory(total_memory_required[device] * 1.1 + extra_mem, device)

    multigpu_memory_log("free_memory/minimum_memory_required", "post/pre")

    # Second pass: if a device is still below the minimum, evict further.
    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_mem = get_free_memory(device)
            free_mem_gb = free_mem / (1024 ** 3)
            min_required_gb = minimum_memory_required / (1024 ** 3)
            logger.mgpu_mm_log(f"[MIN_MEMORY_CHECK] Device {device}: free={free_mem_gb:.2f}GB, required={min_required_gb:.2f}GB, will_evict={free_mem < minimum_memory_required}")

            if free_mem < minimum_memory_required:
                models_l = free_memory(minimum_memory_required, device)
                logger.mgpu_mm_log(f"[EVICTION] Device {device}: unloaded {len(models_l)} models due to insufficient memory")
                logging.info("{} models unloaded.".format(len(models_l)))

    multigpu_memory_log("minimum_memory_required", "post")

    # Finally load each model, computing the lowvram budget exactly as core does.
    for loaded_model in models_to_load:
        model = loaded_model.model
        torch_dev = model.load_device
        if is_device_cpu(torch_dev):
            vram_set_state = VRAMState.DISABLED
        else:
            vram_set_state = vram_state
        lowvram_model_memory = 0
        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
            loaded_memory = loaded_model.model_loaded_memory()
            current_free_mem = get_free_memory(torch_dev) + loaded_memory

            lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
            # 0.1 (not 0) keeps lowvram mode active even when nothing fits.
            lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)

        if vram_set_state == VRAMState.NO_VRAM:
            lowvram_model_memory = 0.1

        loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
        current_loaded_models.insert(0, loaded_model)

# Replace the module function
mm.load_models_gpu = patched_load_models_gpu
106204
107205 original_partially_load = comfy .model_patcher .ModelPatcher .partially_load
108206
0 commit comments