|
29 | 29 | # Global device state management |
30 | 30 | current_device = mm.get_torch_device() |
31 | 31 | current_text_encoder_device = mm.text_encoder_device() |
32 | | -current_text_encoder_initial_device = mm.text_encoder_device() |
33 | 32 |
|
34 | 33 | def set_current_device(device): |
35 | 34 | global current_device |
36 | 35 | current_device = device |
37 | 36 | logger.info(f"[MultiGPU Initialization] current_device set to: {device}") |
38 | 37 |
|
39 | 38 | def set_current_text_encoder_device(device): |
40 | | - global current_text_encoder_device, current_text_encoder_initial_device |
| 39 | + global current_text_encoder_device |
41 | 40 | current_text_encoder_device = device |
42 | | - current_text_encoder_initial_device = device |
43 | | - logger.info(f"[MultiGPU Initialization] current_text_encoder_device and current_text_encoder_initial_device set to: {device}") |
| 41 | + logger.info(f"[MultiGPU Initialization] current_text_encoder_device set to: {device}") |
44 | 42 |
|
45 | 43 | def override_class(cls): |
46 | 44 | class NodeOverride(cls): |
@@ -137,19 +135,12 @@ def text_encoder_device_patched(): |
137 | 135 | logger.debug(f"[MultiGPU Core Patching] text_encoder_device_patched returning device: {device} (current_text_encoder_device={current_text_encoder_device})") |
138 | 136 | return device |
139 | 137 |
|
140 | | -def text_encoder_initial_device_patched(*args, **kwargs): |
141 | | - logger.debug(f"[MultiGPU Core Patching] text_encoder_initial_device_patched called with args={args}, kwargs={kwargs}") |
142 | | - # look at this later - I am not convinced that this isn't the better choice: |
143 | | - # return text_encoder_device_patched() |
144 | | - return mm.text_encoder_device() |
145 | | - |
146 | 138 |
|
147 | | -logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, and mm.text_encoder_initial_device")
| 139 | +logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device and mm.text_encoder_device")
148 | 140 | logger.debug(f"[MultiGPU DEBUG] Initial current_device: {current_device}") |
149 | 141 | logger.debug(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}") |
150 | 142 | mm.get_torch_device = get_torch_device_patched |
151 | 143 | mm.text_encoder_device = text_encoder_device_patched |
152 | | -mm.text_encoder_initial_device = text_encoder_initial_device_patched |
153 | 144 |
|
154 | 145 | def check_module_exists(module_path): |
155 | 146 | full_path = os.path.join(folder_paths.get_folder_paths("custom_nodes")[0], module_path) |
|
0 commit comments