Skip to content

Commit 291a4a4

Browse files
committed
feat: Add support for Apple MPS devices
Update the `get_device_list` function to detect and include the 'mps' (Metal Performance Shaders) backend if it's available through PyTorch. This allows users on Apple Silicon hardware to see and select their GPU for accelerated computations.
1 parent 545da7f commit 291a4a4

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def get_device_list():
4747
devs += [f"xpu:{i}" for i in range(torch.xpu.device_count())]
4848
except Exception:
4949
pass
50+
try:
51+
if torch.backends.mps.is_available():
52+
devs += ["mps"]
53+
except Exception:
54+
pass
5055
return devs
5156

5257
def set_current_device(device):

0 commit comments

Comments
 (0)