
Commit 0b438f7

Refactor code for improved readability and performance
- Cleaned up unnecessary whitespace and comments in model_management_mgpu.py, nodes.py, wanvideo.py, and wrappers.py for better code clarity.
- Replaced list comprehensions with direct list conversions in nodes.py for efficiency.
- Updated memory logging format in model_management_mgpu.py to streamline data capture.
- Enhanced device management in wanvideo.py by ensuring consistent device setting and loading.
- Added linting configurations in pyproject.toml to enforce code quality standards.
- Removed unused imports and optimized existing ones across multiple files.
1 parent 24e4e17 commit 0b438f7

12 files changed: 354 additions & 277 deletions
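The list-conversion cleanup mentioned above follows a standard Python pattern. nodes.py itself is not shown in this section, so the snippet below is a hypothetical sketch of the before/after shape, not the actual change:

    # Hypothetical illustration; `iterable` stands in for whatever nodes.py iterates.
    iterable = ("cuda:0", "cuda:1", "cpu")

    # Before: an identity comprehension copies elements one at a time in Python
    result = [item for item in iterable]

    # After: list() states the intent directly and runs the copy in C
    result = list(iterable)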

__init__.py

Lines changed: 15 additions & 16 deletions
@@ -1,13 +1,12 @@
 import torch
 import logging
-import weakref
 import os
-import copy
 import json
 import importlib
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
+from types import MethodType
 import folder_paths
 import comfy.model_management as mm
 import comfy.memory_management
@@ -17,13 +16,13 @@
 from .device_utils import (
     get_device_list,
     is_accelerator_available,
-    soft_empty_cache_multigpu,
+    soft_empty_cache_multigpu as soft_empty_cache_multigpu,
 )
 from .model_management_mgpu import (
-    trigger_executor_cache_reset,
-    check_cpu_memory_threshold,
-    multigpu_memory_log,
-    force_full_system_cleanup,
+    trigger_executor_cache_reset as trigger_executor_cache_reset,
+    check_cpu_memory_threshold as check_cpu_memory_threshold,
+    multigpu_memory_log as multigpu_memory_log,
+    force_full_system_cleanup as force_full_system_cleanup,
 )

 WEB_DIRECTORY = "./web"
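The seemingly redundant `import name as name` aliases above are the conventional way to mark explicit re-exports: PEP 484 gives the self-alias that meaning, and linters such as Ruff honor it when checking for unused imports (rule F401). A minimal sketch of the idiom:

    # package/__init__.py
    # The self-alias marks a deliberate re-export, so strict linting does not
    # flag the name as unused (Ruff F401 / PEP 484 convention).
    from .device_utils import soft_empty_cache_multigpu as soft_empty_cache_multigpu

    # Callers can then import the name from the package itself:
    #   from package import soft_empty_cache_multigpu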
@@ -138,7 +137,7 @@ def mgpu_mm_log_method(self, msg):
         f"[MultiGPU Model Management] {msg}",
         extra={"mgpu_context": {"component": "model_management"}},
     )
-logger.mgpu_mm_log = mgpu_mm_log_method.__get__(logger, type(logger))
+logger.mgpu_mm_log = MethodType(mgpu_mm_log_method, logger)

 def _normalize_module_name(module_name):
     """Normalize a custom node directory name for tolerant matching."""
@@ -416,7 +415,7 @@ def wrap_for_dlpack_with_device_guard(*args, **kwargs):
         logger.info("[MultiGPU] Applied comfy_kitchen CUDA DLPack device guard patch")
         return True

-    logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device")
+    logger.info("[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device")
     logger.info(f"[MultiGPU DEBUG] Initial current_device: {current_device}")
     logger.info(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}")
     logger.info(f"[MultiGPU DEBUG] Initial current_unet_offload_device: {current_unet_offload_device}")
@@ -457,18 +456,18 @@ def wrap_for_dlpack_with_device_guard(*args, **kwargs):
     override_class_clip,
     override_class_clip_no_device,
     override_class_with_distorch_gguf,
-    override_class_with_distorch_gguf_v2,
+    override_class_with_distorch_gguf_v2 as override_class_with_distorch_gguf_v2,
     override_class_with_distorch_clip,
     override_class_with_distorch_clip_no_device,
-    override_class_with_distorch,
+    override_class_with_distorch as override_class_with_distorch,
     override_class_with_distorch_safetensor_v2,
     override_class_with_distorch_safetensor_v2_clip,
     override_class_with_distorch_safetensor_v2_clip_no_device,
 )
 from .distorch_2 import (
-    register_patched_safetensor_modelpatcher,
-    analyze_safetensor_loading,
-    calculate_safetensor_vvram_allocation,
+    register_patched_safetensor_modelpatcher as register_patched_safetensor_modelpatcher,
+    analyze_safetensor_loading as analyze_safetensor_loading,
+    calculate_safetensor_vvram_allocation as calculate_safetensor_vvram_allocation,
 )

 from .checkpoint_multigpu import (
@@ -569,7 +568,7 @@ def register_and_count(module_names, node_map):
             if check_module_exists(name):
                 found = True
                 break
-
+
         count = 0
         if found:
             try:
@@ -582,7 +581,7 @@ def register_and_count(module_names, node_map):
             for key, value in resolved_node_map.items():
                 NODE_CLASS_MAPPINGS[key] = value
             count = len(NODE_CLASS_MAPPINGS) - initial_len
-
+
     registration_data.append({"name": module_names[0], "found": "Y" if found else "N", "count": count})
     return found

checkpoint_multigpu.py

Lines changed: 35 additions & 33 deletions
@@ -21,11 +21,11 @@
 def patch_load_state_dict_guess_config():
     """Monkey patch comfy.sd.load_state_dict_guess_config with MultiGPU-aware checkpoint loading."""
     global original_load_state_dict_guess_config
-
+
     if original_load_state_dict_guess_config is not None:
         logger.debug("[MultiGPU Checkpoint] load_state_dict_guess_config is already patched.")
         return
-
+
     logger.info("[MultiGPU Core Patching] Patching comfy.sd.load_state_dict_guess_config for advanced MultiGPU loading.")
     original_load_state_dict_guess_config = comfy.sd.load_state_dict_guess_config
     comfy.sd.load_state_dict_guess_config = patched_load_state_dict_guess_config
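The guard above is the usual idempotent monkey-patch recipe: stash the original callable in a module global exactly once and refuse to patch twice, so the original is never lost behind a second wrapper. Stripped of the MultiGPU specifics, the shape is roughly:

    original_fn = None  # module-level stash for the unpatched callable

    def apply_patch(target_module):
        global original_fn
        if original_fn is not None:
            return  # already patched; patching again would clobber the stash
        original_fn = target_module.load_state_dict_guess_config
        target_module.load_state_dict_guess_config = patched_fn

    def patched_fn(*args, **kwargs):
        # ... adjust device placement here ...
        return original_fn(*args, **kwargs)  # defer to the stock loader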
@@ -35,7 +35,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
                                          te_model_options={}, metadata=None):
     """Patched checkpoint loader with MultiGPU and DisTorch2 device placement support."""
     from . import set_current_device, set_current_text_encoder_device, get_current_device, get_current_text_encoder_device
-
+
     sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
     config_hash = str(sd_size)
     device_config = checkpoint_device_config.get(config_hash)
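`config_hash` is nothing more than the checkpoint's total parameter count rendered as a string; the loader nodes below compute the same value when they register a device config, which is how the patched loader finds its settings again. A minimal reproduction of the keying scheme:

    import torch

    # Same identity scheme as above: sum element counts of every tensor value.
    # The hasattr guard skips non-tensor entries such as metadata strings.
    sd = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4), "note": "skipped"}
    sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
    config_hash = str(sd_size)  # "20" (16 + 4)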
@@ -53,7 +53,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
     vae = None
     model = None
     model_patcher = None
-
+
     # Capture the current devices at runtime so we can restore them after loading
     original_main_device = get_current_device()
     original_clip_device = get_current_text_encoder_device()
@@ -68,7 +68,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
         sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)

         model_config = comfy.model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
-
+
         if model_config is None:
             logger.warning("[MultiGPU] Warning: Not a standard checkpoint file. Trying to load as diffusion model only.")
             # Simplified fallback for non-checkpoints
@@ -83,13 +83,13 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
         unet_weight_dtype = list(model_config.supported_inference_dtypes)
         if model_config.scaled_fp8 is not None:
             weight_dtype = None
-
+
         if custom_operations is not None:
             model_config.custom_operations = custom_operations
         unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
         if unet_dtype is None:
             unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
-
+
         unet_compute_device = device_config.get('unet_device', original_main_device)
         if model_config.scaled_fp8 is not None:
             manual_cast_dtype = mm.unet_manual_cast(None, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
@@ -104,7 +104,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,

         if output_model:
             unet_compute_device = device_config.get('unet_device', original_main_device)
-            set_current_device(unet_compute_device)
+            set_current_device(unet_compute_device)
             inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)

             multigpu_memory_log(f"unet:{config_hash[:8]}", "pre-load")
@@ -131,7 +131,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
             vae_target_device = torch.device(device_config.get('vae_device', original_main_device))
             set_current_device(vae_target_device)  # Use main device context for VAE
             multigpu_memory_log(f"vae:{config_hash[:8]}", "pre-load")
-
+
             vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
             vae_sd = model_config.process_vae_state_dict(vae_sd)
             vae = VAE(sd=vae_sd, metadata=metadata)
@@ -151,7 +151,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
                 for pref in scaled_fp8_list:
                     skip = skip or k.startswith(pref)
                 if not skip:
-                    out_sd[k] = sd[k]
+                    out_sd[k] = sd[k]

             for pref in scaled_fp8_list:
                 quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
@@ -161,7 +161,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,

             clip_target_device = device_config.get('clip_device', original_clip_device)
             set_current_text_encoder_device(clip_target_device)
-
+
             clip_target = model_config.clip_target(state_dict=sd)
             if clip_target is not None:
                 clip_sd = model_config.process_clip_state_dict(sd)
@@ -182,15 +182,17 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
                     clip.patcher.model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)

                 m, u = clip.load_sd(clip_sd, full_model=True)  # This respects the patched text_encoder_device
-                if len(m) > 0: logger.warning(f"CLIP missing keys: {m}")
-                if len(u) > 0: logger.debug(f"CLIP unexpected keys: {u}")
+                if len(m) > 0:
+                    logger.warning(f"CLIP missing keys: {m}")
+                if len(u) > 0:
+                    logger.debug(f"CLIP unexpected keys: {u}")
                 logger.info("CLIP Loaded.")
                 multigpu_memory_log(f"clip:{config_hash[:8]}", "post-load")
             else:
                 logger.warning("No CLIP/text encoder weights in checkpoint.")
         else:
             logger.warning("CLIP target not found in model config.")
-
+
     finally:
         set_current_device(original_main_device)
         set_current_text_encoder_device(original_clip_device)
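The `finally` block restores both device globals on every exit path, which is what lets the loader retarget UNet, CLIP, and VAE placement without leaking state into later nodes. The same save/restore pattern could also be phrased as a context manager; a sketch under that assumption (this helper is not part of the codebase):

    from contextlib import contextmanager

    @contextmanager
    def temporary_devices(main_device, clip_device):
        """Point the MultiGPU device globals elsewhere, restoring them on exit."""
        prev_main = get_current_device()
        prev_clip = get_current_text_encoder_device()
        set_current_device(main_device)
        set_current_text_encoder_device(clip_device)
        try:
            yield
        finally:
            # Runs on success or exception, mirroring the try/finally above
            set_current_device(prev_main)
            set_current_text_encoder_device(prev_clip)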
@@ -206,7 +208,7 @@ def INPUT_TYPES(s):
         import folder_paths
         devices = get_device_list()
         default_device = devices[1] if len(devices) > 1 else devices[0]
-
+
         return {
             "required": {
                 "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
@@ -215,27 +217,27 @@ def INPUT_TYPES(s):
                 "vae_device": (devices, {"default": default_device}),
             }
         }
-
+
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"
     CATEGORY = "multigpu"
     TITLE = "Checkpoint Loader Advanced (MultiGPU)"
-
+
     def load_checkpoint(self, ckpt_name, unet_device, clip_device, vae_device):
         patch_load_state_dict_guess_config()
-
+
         import folder_paths
         import comfy.utils
-
+
         ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
         sd = comfy.utils.load_torch_file(ckpt_path)
         sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
         config_hash = str(sd_size)
-
+
         checkpoint_device_config[config_hash] = {
             'unet_device': unet_device, 'clip_device': clip_device, 'vae_device': vae_device
         }
-
+
         # Load using standard loader, our patch will intercept
         from nodes import CheckpointLoaderSimple
         return CheckpointLoaderSimple().load_checkpoint(ckpt_name)
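Note the indirection in `load_checkpoint`: the node never loads the model itself. It records its device choices under the checkpoint's size hash and hands off to ComfyUI's stock loader, which in turn calls the now-patched `comfy.sd.load_state_dict_guess_config`. Roughly (device strings here are illustrative):

    # 1. Register per-checkpoint placement under the parameter-count hash
    checkpoint_device_config[config_hash] = {
        'unet_device': 'cuda:0', 'clip_device': 'cuda:1', 'vae_device': 'cpu',
    }
    # 2. Delegate to the stock node; its internal load call is intercepted
    model, clip, vae = CheckpointLoaderSimple().load_checkpoint(ckpt_name)
    # 3. The patch recomputes the same hash from the state dict it receives
    #    and applies the matching devices while building model, CLIP, and VAE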
@@ -247,7 +249,7 @@ def INPUT_TYPES(s):
         import folder_paths
         devices = get_device_list()
         compute_device = devices[1] if len(devices) > 1 else devices[0]
-
+
         return {
             "required": {
                 "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
@@ -265,18 +267,18 @@ def INPUT_TYPES(s):
                 "eject_models": ("BOOLEAN", {"default": True}),
             }
         }
-
+
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"
     CATEGORY = "multigpu/distorch_2"
     TITLE = "Checkpoint Loader Advanced (DisTorch2)"
-
+
     def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb, unet_donor_device,
                         clip_compute_device, clip_virtual_vram_gb, clip_donor_device, vae_device,
                         unet_expert_mode_allocations="", clip_expert_mode_allocations="", high_precision_loras=True, eject_models=True):
-
+
         if eject_models:
-            logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction")
+            logger.mgpu_mm_log("[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction")
             ejection_count = 0
             for i, lm in enumerate(mm.current_loaded_models):
                 model_name = type(getattr(lm.model, 'model', lm.model)).__name__ if lm.model else 'Unknown'
@@ -289,17 +291,17 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
                     logger.mgpu_mm_log(f"[EJECT_MARKED] Model {i}: {model_name} (direct patcher) → marked for eviction")
                     ejection_count += 1
             logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP_COMPLETE] Marked {ejection_count} models for Comfy Core eviction during load_models_gpu")
-
-        patch_load_state_dict_guess_config()
-
+
+        patch_load_state_dict_guess_config()
+
         import folder_paths
         import comfy.utils
-
+
         ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
         sd = comfy.utils.load_torch_file(ckpt_path)
         sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
         config_hash = str(sd_size)
-
+
         checkpoint_device_config[config_hash] = {
             'unet_device': unet_compute_device,
             'clip_device': clip_compute_device,
@@ -312,7 +314,7 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
         elif unet_expert_mode_allocations:
             unet_vram_str = unet_compute_device
         unet_alloc = f"{unet_expert_mode_allocations}#{unet_vram_str}" if unet_expert_mode_allocations or unet_vram_str else ""
-
+
         clip_vram_str = ""
         if clip_virtual_vram_gb > 0:
             clip_vram_str = f"{clip_compute_device};{clip_virtual_vram_gb};{clip_donor_device}"
@@ -327,6 +329,6 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
             'unet_settings': hashlib.sha256(f"{unet_alloc}{high_precision_loras}".encode()).hexdigest(),
             'clip_settings': hashlib.sha256(f"{clip_alloc}{high_precision_loras}".encode()).hexdigest(),
         }
-
+
         from nodes import CheckpointLoaderSimple
         return CheckpointLoaderSimple().load_checkpoint(ckpt_name)

ci/extract_allocation.py

Lines changed: 11 additions & 5 deletions
@@ -3,10 +3,16 @@

 import argparse
 import json
+import sys
 from pathlib import Path
 from typing import Iterable, Iterator, Dict, Any


+def _write_stdout(message: str = "") -> None:
+    sys.stdout.write(f"{message}\n")
+    sys.stdout.flush()
+
+
 def load_json_lines(path: Path) -> Iterator[Dict[str, Any]]:
     with path.open("r", encoding="utf-8") as handle:
         for line in handle:
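The new `_write_stdout` helper reproduces the behavior of a bare `print` with an explicit flush, presumably so the CI script satisfies a lint rule that bans `print` calls (Ruff's T201, for instance). For the simple cases in this script the two are interchangeable:

    import sys

    def _write_stdout(message: str = "") -> None:
        sys.stdout.write(f"{message}\n")
        sys.stdout.flush()

    _write_stdout("## header")  # like print("## header"), but always flushed
    _write_stdout()             # like print(): emits just a newline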
@@ -37,12 +43,12 @@ def main() -> int:

     entries = list(load_json_lines(args.logfile))
     if not entries:
-        print("No entries found in log file.")
+        _write_stdout("No entries found in log file.")
         return 0

     matched = [entry for entry in entries if is_allocation_event(entry, args.keywords)]
     if not matched:
-        print("No allocation events matched provided keywords.")
+        _write_stdout("No allocation events matched provided keywords.")
        return 0

     for entry in matched:
@@ -51,9 +57,9 @@ def main() -> int:
         component = entry.get("component", "")
         header_bits = [bit for bit in (timestamp, category, component) if bit]
         header = " | ".join(header_bits) if header_bits else "allocation"
-        print(f"## {header}")
-        print(entry.get("message", ""))
-        print()
+        _write_stdout(f"## {header}")
+        _write_stdout(entry.get("message", ""))
+        _write_stdout()

     return 0
