Skip to content

Commit 0d5e9d8

Browse files
committed
Let cache file get device name from runner.get_device_info() instead of runner.dev.name
1 parent 0486fa5 commit 0d5e9d8

2 files changed

Lines changed: 7 additions & 5 deletions

File tree

kernel_tuner/runners/parallel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ def __init__(
217217
print(f" - worker {worker}")
218218

219219
def get_device_info(self):
220-
# TODO: Get this from the device?
221-
return Options({"max_threads": 1024})
220+
# TODO: Get max_threads from the device?
221+
return Options({"name": self.device_name, "max_threads": 1024})
222222

223223
def get_environment(self, tuning_options):
224224
return {"device_name": self.device_name, "workers": [w.env for w in self.workers]}

kernel_tuner/util.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,13 +1259,15 @@ def process_cache(cachefile, kernel_options, tuning_options, runner):
12591259
if not isinstance(tuning_options.tune_params, dict):
12601260
raise ValueError("Caching only works correctly when tunable parameters are stored in a dictionary")
12611261

1262+
device_name = runner.get_device_info().name
1263+
12621264
# if file does not exist, create new cache
12631265
if not os.path.isfile(cachefile):
12641266
if tuning_options.simulation_mode:
12651267
raise ValueError(f"Simulation mode requires an existing cachefile: file {cachefile} does not exist")
12661268

12671269
c = dict()
1268-
c["device_name"] = runner.dev.name
1270+
c["device_name"] = device_name
12691271
c["kernel_name"] = kernel_options.kernel_name
12701272
c["problem_size"] = kernel_options.problem_size if not callable(kernel_options.problem_size) else "callable"
12711273
c["tune_params_keys"] = list(tuning_options.tune_params.keys())
@@ -1290,9 +1292,9 @@ def process_cache(cachefile, kernel_options, tuning_options, runner):
12901292
runner.dev.name = cached_data["device_name"]
12911293

12921294
# check if it is safe to continue tuning from this cache
1293-
if cached_data["device_name"] != runner.dev.name:
1295+
if cached_data["device_name"] != device_name:
12941296
raise ValueError(
1295-
f"Cannot load cache which contains results for different device (cache: {cached_data['device_name']}, actual: {runner.dev.name})"
1297+
f"Cannot load cache which contains results for different device (cache: {cached_data['device_name']}, actual: {device_name})"
12961298
)
12971299
if cached_data["kernel_name"] != kernel_options.kernel_name:
12981300
raise ValueError(

0 commit comments

Comments
 (0)