Skip to content

Commit d6137f3

Browse files
authored
Merge pull request #122 from pollockjj/wvw
Pull Request: Refresh WanVideoWrapper MultiGPU support to latest upstream
2 parents b0930e2 + c77a4b2 commit d6137f3

21 files changed

Lines changed: 3616 additions & 4201 deletions

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
__pycache__/
33
.clinerules
44
.vscode
5-
memory-bank/
5+
memory-bank/
6+
.github/

__init__.py

Lines changed: 132 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import weakref
44
import os
55
import copy
6+
import json
7+
from datetime import datetime
68
from pathlib import Path
79
import folder_paths
810
import comfy.model_management as mm
@@ -21,12 +23,22 @@
2123
)
2224

2325
WEB_DIRECTORY = "./web"
24-
MGPU_MM_LOG = False
26+
MGPU_MM_LOG = True
2527
DEBUG_LOG = False
2628

2729
logger = logging.getLogger("MultiGPU")
2830
logger.propagate = False
2931

32+
FOCUS_LOG_LEVEL = logging.INFO + 5
33+
logging.addLevelName(FOCUS_LOG_LEVEL, "FOCUS")
34+
35+
if not hasattr(logging.Logger, "focus"):
36+
def focus(self, message, *args, **kwargs):
37+
if self.isEnabledFor(FOCUS_LOG_LEVEL):
38+
self._log(FOCUS_LOG_LEVEL, message, args, **kwargs)
39+
40+
logging.Logger.focus = focus # type: ignore[attr-defined]
41+
3042
if not logger.handlers:
3143
log_level = logging.DEBUG if DEBUG_LOG else logging.INFO
3244
handler = logging.StreamHandler()
@@ -35,10 +47,93 @@
3547
logger.addHandler(handler)
3648
logger.setLevel(log_level)
3749

50+
json_log_path = os.environ.get("MGPU_JSON_LOG_PATH")
51+
json_static_fields = {}
52+
if json_log_path:
53+
try:
54+
json_static_fields = json.loads(os.environ.get("MGPU_JSON_STATIC_FIELDS", "{}"))
55+
except json.JSONDecodeError:
56+
json_static_fields = {}
57+
58+
level_aliases = {
59+
"CRITICAL": logging.CRITICAL,
60+
"ERROR": logging.ERROR,
61+
"WARNING": logging.WARNING,
62+
"FOCUS": FOCUS_LOG_LEVEL,
63+
"INFO": logging.INFO,
64+
"DEBUG": logging.DEBUG,
65+
}
66+
67+
json_min_level = FOCUS_LOG_LEVEL
68+
configured_min_level = os.environ.get("MGPU_JSON_MIN_LEVEL")
69+
if configured_min_level:
70+
value = configured_min_level.strip()
71+
upper_value = value.upper()
72+
if upper_value in level_aliases:
73+
json_min_level = level_aliases[upper_value]
74+
else:
75+
try:
76+
json_min_level = int(value)
77+
except ValueError:
78+
json_min_level = FOCUS_LOG_LEVEL
79+
80+
class JsonLineFileHandler(logging.Handler):
81+
def __init__(self, path, static_fields, min_level, overwrite):
82+
super().__init__()
83+
self.path = Path(path)
84+
self.path.parent.mkdir(parents=True, exist_ok=True)
85+
self.static_fields = static_fields
86+
self.setLevel(min_level)
87+
if overwrite:
88+
try:
89+
with self.path.open("w", encoding="utf-8") as handle:
90+
handle.write("")
91+
except OSError:
92+
pass
93+
94+
def emit(self, record):
95+
message = record.getMessage()
96+
category = None
97+
if message.startswith("[") and "]" in message:
98+
bracket_split = message.split("]", 1)
99+
category = bracket_split[0].strip("[]")
100+
payload = {
101+
"timestamp": datetime.utcnow().isoformat() + "Z",
102+
"level": record.levelname,
103+
"name": record.name,
104+
"message": message,
105+
}
106+
if category:
107+
payload["event_category"] = category
108+
if hasattr(record, "mgpu_context") and isinstance(record.mgpu_context, dict):
109+
payload.update(record.mgpu_context)
110+
workflow_id = os.environ.get("MGPU_JSON_WORKFLOW")
111+
prompt_id = os.environ.get("MGPU_JSON_PROMPT")
112+
if workflow_id:
113+
payload.setdefault("workflow_id", workflow_id)
114+
if prompt_id:
115+
payload.setdefault("prompt_id", prompt_id)
116+
if self.static_fields:
117+
payload.update(self.static_fields)
118+
try:
119+
with self.path.open("a", encoding="utf-8") as handle:
120+
handle.write(json.dumps(payload, ensure_ascii=True) + "\n")
121+
except OSError:
122+
# Fail silently for JSON logging so primary logging continues.
123+
pass
124+
125+
overwrite_value = os.environ.get("MGPU_JSON_OVERWRITE", "true").strip().lower()
126+
overwrite_enabled = overwrite_value not in {"0", "false", "no"}
127+
128+
logger.addHandler(JsonLineFileHandler(json_log_path, json_static_fields, json_min_level, overwrite_enabled))
129+
38130
def mgpu_mm_log_method(self, msg):
39131
"""Add MultiGPU model management logging method to logger instance."""
40132
if MGPU_MM_LOG:
41-
self.info(f"[MultiGPU Model Management] {msg}")
133+
self.focus(
134+
f"[MultiGPU Model Management] {msg}",
135+
extra={"mgpu_context": {"component": "model_management"}},
136+
)
42137
logger.mgpu_mm_log = mgpu_mm_log_method.__get__(logger, type(logger))
43138

44139
def check_module_exists(module_path):
@@ -95,8 +190,6 @@ def text_encoder_device_patched():
95190
mm.text_encoder_device = text_encoder_device_patched
96191

97192
from .nodes import (
98-
DeviceSelectorMultiGPU,
99-
HunyuanVideoEmbeddingsAdapter,
100193
UnetLoaderGGUF,
101194
UnetLoaderGGUFAdvanced,
102195
CLIPLoaderGGUF,
@@ -114,21 +207,30 @@ def text_encoder_device_patched():
114207
PulidModelLoader,
115208
PulidInsightFaceLoader,
116209
PulidEvaClipLoader,
117-
HyVideoModelLoader,
118-
HyVideoVAELoader,
119-
DownloadAndLoadHyVideoTextEncoder,
120210
UNetLoaderLP,
121211
)
122212

123213
from .wanvideo import (
124-
WanVideoModelLoader,
125-
WanVideoModelLoader_2,
126-
WanVideoVAELoader,
127214
LoadWanVideoT5TextEncoder,
128-
LoadWanVideoClipTextEncoder,
129215
WanVideoTextEncode,
216+
WanVideoTextEncodeCached,
217+
WanVideoTextEncodeSingle,
218+
WanVideoVAELoader,
219+
WanVideoTinyVAELoader,
130220
WanVideoBlockSwap,
131-
WanVideoSampler
221+
WanVideoImageToVideoEncode,
222+
WanVideoDecode,
223+
WanVideoModelLoader,
224+
WanVideoSampler,
225+
WanVideoVACEEncode,
226+
WanVideoEncode,
227+
LoadWanVideoClipTextEncoder,
228+
WanVideoClipVisionEncode,
229+
WanVideoControlnetLoader,
230+
FantasyTalkingModelLoader,
231+
Wav2VecModelLoader,
232+
WanVideoUni3C_ControlnetLoader,
233+
DownloadAndLoadWav2VecModel,
132234
)
133235

134236
from .wrappers import (
@@ -158,8 +260,6 @@ def text_encoder_device_patched():
158260
)
159261

160262
NODE_CLASS_MAPPINGS = {
161-
"DeviceSelectorMultiGPU": DeviceSelectorMultiGPU,
162-
"HunyuanVideoEmbeddingsAdapter": HunyuanVideoEmbeddingsAdapter,
163263
"CheckpointLoaderAdvancedMultiGPU": CheckpointLoaderAdvancedMultiGPU,
164264
"CheckpointLoaderAdvancedDisTorch2MultiGPU": CheckpointLoaderAdvancedDisTorch2MultiGPU,
165265
"UNetLoaderLP": UNetLoaderLP,
@@ -266,27 +366,32 @@ def register_and_count(module_names, node_map):
266366
}
267367
register_and_count(["PuLID_ComfyUI", "pulid_comfyui"], pulid_nodes)
268368

269-
hunyuan_nodes = {
270-
"HyVideoModelLoaderMultiGPU": override_class(HyVideoModelLoader),
271-
"HyVideoVAELoaderMultiGPU": override_class(HyVideoVAELoader),
272-
"DownloadAndLoadHyVideoTextEncoderMultiGPU": override_class(DownloadAndLoadHyVideoTextEncoder)
273-
}
274-
register_and_count(["ComfyUI-HunyuanVideoWrapper", "comfyui-hunyuanvideowrapper"], hunyuan_nodes)
275-
276369
wanvideo_nodes = {
277-
"WanVideoModelLoaderMultiGPU": WanVideoModelLoader,
278-
"WanVideoModelLoaderMultiGPU_2": WanVideoModelLoader_2,
279-
"WanVideoVAELoaderMultiGPU": WanVideoVAELoader,
280370
"LoadWanVideoT5TextEncoderMultiGPU": LoadWanVideoT5TextEncoder,
281-
"LoadWanVideoClipTextEncoderMultiGPU": LoadWanVideoClipTextEncoder,
282371
"WanVideoTextEncodeMultiGPU": WanVideoTextEncode,
372+
"WanVideoTextEncodeCachedMultiGPU": WanVideoTextEncodeCached,
373+
"WanVideoTextEncodeSingleMultiGPU": WanVideoTextEncodeSingle,
374+
"WanVideoVAELoaderMultiGPU": WanVideoVAELoader,
375+
"WanVideoTinyVAELoaderMultiGPU": WanVideoTinyVAELoader,
283376
"WanVideoBlockSwapMultiGPU": WanVideoBlockSwap,
284-
"WanVideoSamplerMultiGPU": WanVideoSampler
377+
"WanVideoImageToVideoEncodeMultiGPU": WanVideoImageToVideoEncode,
378+
"WanVideoDecodeMultiGPU": WanVideoDecode,
379+
"WanVideoModelLoaderMultiGPU": WanVideoModelLoader,
380+
"WanVideoSamplerMultiGPU": WanVideoSampler,
381+
"WanVideoVACEEncodeMultiGPU": WanVideoVACEEncode,
382+
"WanVideoEncodeMultiGPU": WanVideoEncode,
383+
"LoadWanVideoClipTextEncoderMultiGPU": LoadWanVideoClipTextEncoder,
384+
"WanVideoClipVisionEncodeMultiGPU": WanVideoClipVisionEncode,
385+
"WanVideoControlnetLoaderMultiGPU": WanVideoControlnetLoader,
386+
"FantasyTalkingModelLoaderMultiGPU": FantasyTalkingModelLoader,
387+
"Wav2VecModelLoaderMultiGPU": Wav2VecModelLoader,
388+
"WanVideoUni3C_ControlnetLoaderMultiGPU": WanVideoUni3C_ControlnetLoader,
389+
"DownloadAndLoadWav2VecModelMultiGPU": DownloadAndLoadWav2VecModel,
285390
}
286391
register_and_count(["ComfyUI-WanVideoWrapper", "comfyui-wanvideowrapper"], wanvideo_nodes)
287392

288393
for item in registration_data:
289394
logger.info(fmt_reg.format(item['name'], item['found'], str(item['count'])))
290395
logger.info(dash_line)
291396

292-
logger.info(f"[MultiGPU] Registration complete. Final mappings: {', '.join(NODE_CLASS_MAPPINGS.keys())}")
397+
logger.info(f"[MultiGPU] Registration complete. Final mappings: {', '.join(NODE_CLASS_MAPPINGS.keys())}")
738 KB
Loading
768 KB
Loading
547 KB
Loading
731 KB
Loading

assets/wan2_2_benchmark_v2.png

623 KB
Loading
1.61 MB
Loading

ci/extract_allocation.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#!/usr/bin/env python3
2+
"""Filter MultiGPU JSON logs for allocation summaries."""
3+
4+
import argparse
5+
import json
6+
from pathlib import Path
7+
from typing import Iterable, Iterator, Dict, Any
8+
9+
10+
def load_json_lines(path: Path) -> Iterator[Dict[str, Any]]:
11+
with path.open("r", encoding="utf-8") as handle:
12+
for line in handle:
13+
line = line.strip()
14+
if not line:
15+
continue
16+
try:
17+
yield json.loads(line)
18+
except json.JSONDecodeError:
19+
continue
20+
21+
22+
def is_allocation_event(entry: Dict[str, Any], keywords: Iterable[str]) -> bool:
23+
message = entry.get("message", "")
24+
return any(keyword in message for keyword in keywords)
25+
26+
27+
def main() -> int:
28+
parser = argparse.ArgumentParser(description="Extract allocation-related events from MultiGPU JSON logs")
29+
parser.add_argument("logfile", type=Path, help="Path to JSONL log produced by MGPU_JSON_LOG_PATH")
30+
parser.add_argument(
31+
"--keywords",
32+
nargs="*",
33+
default=["Final Allocation String", "Total memory", "Virtual VRAM"],
34+
help="Keywords that mark allocation events",
35+
)
36+
args = parser.parse_args()
37+
38+
entries = list(load_json_lines(args.logfile))
39+
if not entries:
40+
print("No entries found in log file.")
41+
return 0
42+
43+
matched = [entry for entry in entries if is_allocation_event(entry, args.keywords)]
44+
if not matched:
45+
print("No allocation events matched provided keywords.")
46+
return 0
47+
48+
for entry in matched:
49+
timestamp = entry.get("timestamp", "unknown")
50+
category = entry.get("event_category", "")
51+
component = entry.get("component", "")
52+
header_bits = [bit for bit in (timestamp, category, component) if bit]
53+
header = " | ".join(header_bits) if header_bits else "allocation"
54+
print(f"## {header}")
55+
print(entry.get("message", ""))
56+
print()
57+
58+
return 0
59+
60+
61+
if __name__ == "__main__":
62+
raise SystemExit(main())

0 commit comments

Comments
 (0)