Skip to content

Commit c713f63

Browse files
committed
fix: ensure model.device is set in ModelPatcher.partially_load
Assign self.model.device = device_to during DisTorch V2 partially_load so the model's device reflects the target allocation after loading. AssertionError: Input tensors must be on cuda. Fixes #119 Possible issue when used with custom samplers Fixes #130
1 parent 35e81e9 commit c713f63

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

distorch_2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
301301

302302
self.model.current_weight_patches_uuid = self.patches_uuid
303303

304+
self.model.device = device_to
305+
304306
logger.info("[MultiGPU DisTorch V2] DisTorch loading completed.")
305307
logger.info(f"[MultiGPU DisTorch V2] Total memory: {mem_counter / (1024 * 1024):.2f}MB")
306308

0 commit comments

Comments
 (0)