Skip to content

Commit 0486fa5

Browse files
committed
Change NVMLObserver to perform initialization inside register_device instead of __init__
1 parent f3d84ff commit 0486fa5

1 file changed

Lines changed: 42 additions & 12 deletions

File tree

kernel_tuner/observers/nvml.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,29 @@
1515
class nvml:
1616
"""Class that gathers the NVML functionality for one device."""
1717

18-
def __init__(self, device_id=0, nvidia_smi_fallback="nvidia-smi", use_locked_clocks=False):
18+
def __init__(
19+
self,
20+
device_id=None,
21+
device_uuid=None,
22+
device_pci_bus=None,
23+
nvidia_smi_fallback=None,
24+
use_locked_clocks=False
25+
):
1926
"""Create object to control device using NVML."""
2027
pynvml.nvmlInit()
21-
self.dev = pynvml.nvmlDeviceGetHandleByIndex(device_id)
22-
self.id = device_id
23-
self.nvidia_smi = nvidia_smi_fallback
28+
29+
if sum(x is not None for x in [device_id, device_uuid, device_pci]) != 1:
30+
raise ValueError("invalid device: specify either the index, the UUID, or the PCI-bus")
31+
elif device_id:
32+
self.dev = pynvml.nvmlDeviceGetHandleByIndex(device_id)
33+
elif device_uuid:
34+
self.dev = pynvml.nvmlDeviceGetHandleByUUID(device_uuid)
35+
elif device_pci_bus:
36+
self.dev = pynvml.nvmlDeviceGetHandleByPciBusId_v2(device_pci_bus)
37+
38+
39+
self.id = pynvml.nvmlDeviceGetIndex(self.dev)
40+
self.nvidia_smi = nvidia_smi_fallback or "nvidia-smi"
2441

2542
try:
2643
self.pwr_limit_default = pynvml.nvmlDeviceGetPowerManagementLimit(self.dev)
@@ -326,15 +343,11 @@ def __init__(
326343
continuous_duration=1,
327344
):
328345
"""Create an NVMLObserver."""
329-
if nvidia_smi_fallback:
330-
self.nvml = nvml(
331-
device,
332-
nvidia_smi_fallback=nvidia_smi_fallback,
333-
use_locked_clocks=use_locked_clocks,
334-
)
335-
else:
336-
self.nvml = nvml(device, use_locked_clocks=use_locked_clocks)
337346
self.save_all = save_all
347+
self.device = device
348+
self.nvml_kwargs = dict()
349+
self.nvml_kwargs["use_locked_clocks"] = use_locked_clocks
350+
self.nvml_kwargs["nvidia_smi_fallback"] = nvidia_smi_fallback
338351

339352
supported = [
340353
"power_readings",
@@ -373,6 +386,23 @@ def __init__(
373386
self.during_obs = [obs for obs in observables if obs in ["core_freq", "mem_freq", "temperature"]]
374387
self.iteration = {obs: [] for obs in self.during_obs}
375388

389+
def register_device(self, dev):
390+
if self.device is not None:
391+
self.nvml = nvml(device_id=self.device, **self.nvml_kwargs)
392+
else:
393+
env = getattr(dev, "env", dict())
394+
uuid = env.get("uuid")
395+
pci_bus = env.get("pci_bus_id")
396+
397+
if uuid:
398+
self.nvml = nvml(device_uuid=uuid, **self.nvml_kwargs)
399+
elif pci_bus:
400+
self.nvml = nvml(device_pci_bus=pci_bus, **self.nvml_kwargs)
401+
else:
402+
raise ValueError("failed to detect NVIDIA device: no UUID or PCI-bus-id in environment")
403+
404+
405+
376406
def read_power(self):
377407
""" Return power in Watt """
378408
return self.nvml.pwr_usage() / 1e3

0 commit comments

Comments
 (0)