|
15 | 15 | class nvml: |
16 | 16 | """Class that gathers the NVML functionality for one device.""" |
17 | 17 |
|
18 | | - def __init__(self, device_id=0, nvidia_smi_fallback="nvidia-smi", use_locked_clocks=False): |
| 18 | + def __init__( |
| 19 | + self, |
| 20 | + device_id=None, |
| 21 | + device_uuid=None, |
| 22 | + device_pci_bus=None, |
| 23 | + nvidia_smi_fallback=None, |
| 24 | + use_locked_clocks=False |
| 25 | + ): |
19 | 26 | """Create object to control device using NVML.""" |
20 | 27 | pynvml.nvmlInit() |
21 | | - self.dev = pynvml.nvmlDeviceGetHandleByIndex(device_id) |
22 | | - self.id = device_id |
23 | | - self.nvidia_smi = nvidia_smi_fallback |
| 28 | + |
| 29 | + if sum(x is not None for x in [device_id, device_uuid, device_pci]) != 1: |
| 30 | + raise ValueError("invalid device: specify either the index, the UUID, or the PCI-bus") |
| 31 | + elif device_id: |
| 32 | + self.dev = pynvml.nvmlDeviceGetHandleByIndex(device_id) |
| 33 | + elif device_uuid: |
| 34 | + self.dev = pynvml.nvmlDeviceGetHandleByUUID(device_uuid) |
| 35 | + elif device_pci_bus: |
| 36 | + self.dev = pynvml.nvmlDeviceGetHandleByPciBusId_v2(device_pci_bus) |
| 37 | + |
| 38 | + |
| 39 | + self.id = pynvml.nvmlDeviceGetIndex(self.dev) |
| 40 | + self.nvidia_smi = nvidia_smi_fallback or "nvidia-smi" |
24 | 41 |
|
25 | 42 | try: |
26 | 43 | self.pwr_limit_default = pynvml.nvmlDeviceGetPowerManagementLimit(self.dev) |
@@ -326,15 +343,11 @@ def __init__( |
326 | 343 | continuous_duration=1, |
327 | 344 | ): |
328 | 345 | """Create an NVMLObserver.""" |
329 | | - if nvidia_smi_fallback: |
330 | | - self.nvml = nvml( |
331 | | - device, |
332 | | - nvidia_smi_fallback=nvidia_smi_fallback, |
333 | | - use_locked_clocks=use_locked_clocks, |
334 | | - ) |
335 | | - else: |
336 | | - self.nvml = nvml(device, use_locked_clocks=use_locked_clocks) |
337 | 346 | self.save_all = save_all |
| 347 | + self.device = device |
| 348 | + self.nvml_kwargs = dict() |
| 349 | + self.nvml_kwargs["use_locked_clocks"] = use_locked_clocks |
| 350 | + self.nvml_kwargs["nvidia_smi_fallback"] = nvidia_smi_fallback |
338 | 351 |
|
339 | 352 | supported = [ |
340 | 353 | "power_readings", |
@@ -373,6 +386,23 @@ def __init__( |
373 | 386 | self.during_obs = [obs for obs in observables if obs in ["core_freq", "mem_freq", "temperature"]] |
374 | 387 | self.iteration = {obs: [] for obs in self.during_obs} |
375 | 388 |
|
| 389 | + def register_device(self, dev): |
| 390 | + if self.device is not None: |
| 391 | + self.nvml = nvml(device_id=self.device, **self.nvml_kwargs) |
| 392 | + else: |
| 393 | + env = getattr(dev, "env", dict()) |
| 394 | + uuid = env.get("uuid") |
| 395 | + pci_bus = env.get("pci_bus_id") |
| 396 | + |
| 397 | + if uuid: |
| 398 | + self.nvml = nvml(device_uuid=uuid, **self.nvml_kwargs) |
| 399 | + elif pci_bus: |
| 400 | + self.nvml = nvml(device_pci_bus=pci_bus, **self.nvml_kwargs) |
| 401 | + else: |
| 402 | + raise ValueError("failed to detect NVIDIA device: no UUID or PCI-bus-id in environment") |
| 403 | + |
| 404 | + |
| 405 | + |
376 | 406 | def read_power(self): |
377 | 407 | """ Return power in Watt """ |
378 | 408 | return self.nvml.pwr_usage() / 1e3 |
|
0 commit comments