|
2 | 2 | from warnings import warn |
3 | 3 |
|
4 | 4 | import numpy as np |
| 5 | +import os |
5 | 6 |
|
6 | 7 | from kernel_tuner.backends.backend import GPUBackend |
7 | 8 | from kernel_tuner.observers.nvcuda import CudaRuntimeObserver |
8 | 9 | from kernel_tuner.util import SkippableFailure |
9 | | -from kernel_tuner.utils.nvcuda import cuda_error_check, to_valid_nvrtc_gpu_arch_cc |
| 10 | +from kernel_tuner.utils.nvcuda import cuda_error_check, to_valid_nvrtc_gpu_arch_cc, find_cuda_home |
10 | 11 |
|
11 | 12 | # embedded in try block to be able to generate documentation |
12 | 13 | # and run tests without cuda-python installed |
@@ -74,9 +75,6 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None |
74 | 75 | self.current_module = None |
75 | 76 | self.func = None |
76 | 77 | self.compiler_options = compiler_options or [] |
77 | | - self.compiler_options_bytes = [] |
78 | | - for option in self.compiler_options: |
79 | | - self.compiler_options_bytes.append(str(option).encode("UTF-8")) |
80 | 78 |
|
81 | 79 | # create a stream and events |
82 | 80 | err, self.stream = driver.cuStreamCreate(0) |
@@ -155,45 +153,54 @@ def compile(self, kernel_instance): |
155 | 153 | kernel_string = kernel_instance.kernel_string |
156 | 154 | kernel_name = kernel_instance.name |
157 | 155 | expression_name = str.encode(kernel_name) |
| 156 | + compiler_options = list(self.compiler_options) |
158 | 157 |
|
159 | | - compiler_options = self.compiler_options_bytes |
160 | | - if not any([b"--std=" in opt for opt in compiler_options]): |
161 | | - compiler_options.append(b"--std=c++11") |
162 | | - if not any(["--std=" in opt for opt in self.compiler_options]): |
163 | | - self.compiler_options.append("--std=c++11") |
164 | | - if not any([b"--gpu-architecture=" in opt or b"-arch" in opt for opt in compiler_options]): |
165 | | - compiler_options.append(f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}".encode("UTF-8")) |
166 | | - if not any(["--gpu-architecture=" in opt or "-arch" in opt for opt in self.compiler_options]): |
167 | | - self.compiler_options.append(f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}") |
| 158 | + # Add -std=c++11 |
| 159 | + if not any(opt.startswith(("-std=", "--std=")) for opt in self.compiler_options): |
| 160 | + compiler_options.append("--std=c++11") |
| 161 | + |
| 162 | + # Add -arch |
| 163 | + if not any(opt.startswith(("-arch", "--arch", "--gpu-architecture=")) for opt in self.compiler_options): |
| 164 | + arch_val = to_valid_nvrtc_gpu_arch_cc(self.cc) |
| 165 | + compiler_options.append(f"--gpu-architecture=compute_{arch_val}") |
| 166 | + |
| 167 | + # Add CUDA home to include path |
| 168 | + cuda_home = find_cuda_home() |
| 169 | + if cuda_home: |
| 170 | + cuda_include = os.path.join(cuda_home, "include") |
| 171 | + compiler_options.append(f"-I{cuda_include}") |
| 172 | + |
| 173 | + # nvrtcCompileProgram requires bytes instead of str |
| 174 | + compiler_options = [str(opt).encode("UTF-8") for opt in compiler_options] |
168 | 175 |
|
169 | 176 | err, program = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), b"CUDAProgram", 0, [], []) |
170 | 177 | try: |
171 | 178 | # Add the kernel as an expression. This is necessary for templated kernels to ensure that the |
172 | 179 | # compiler actually instantiates the kernel that we want to compile. |
173 | 180 | cuda_error_check(err) |
174 | 181 | err = nvrtc.nvrtcAddNameExpression(program, expression_name) |
175 | | - |
| 182 | + |
176 | 183 | # Compile the program |
177 | 184 | cuda_error_check(err) |
178 | 185 | err = nvrtc.nvrtcCompileProgram(program, len(compiler_options), compiler_options) |
179 | | - |
| 186 | + |
180 | 187 | # Get the PTX |
181 | 188 | cuda_error_check(err) |
182 | 189 | err, size = nvrtc.nvrtcGetPTXSize(program) |
183 | 190 | cuda_error_check(err) |
184 | 191 | buff = b" " * size |
185 | 192 | err = nvrtc.nvrtcGetPTX(program, buff) |
186 | 193 | cuda_error_check(err) |
187 | | - |
| 194 | + |
188 | 195 | # Load the module |
189 | 196 | err, self.current_module = driver.cuModuleLoadData(np.char.array(buff)) |
190 | 197 | if err == driver.CUresult.CUDA_ERROR_INVALID_PTX: |
191 | 198 | raise SkippableFailure("uses too much shared data") |
192 | 199 | else: |
193 | 200 | cuda_error_check(err) |
194 | | - |
| 201 | + |
195 | 202 | # First, get the "lowered" name of the kernel (i.e., the name inside the PTX). |
196 | | - # After, we can use the lowered name to lookup the kernel in the module. |
| 203 | + # After, we can use the lowered name to lookup the kernel in the module. |
197 | 204 | err, lowered_name = nvrtc.nvrtcGetLoweredName(program, expression_name) |
198 | 205 | cuda_error_check(err) |
199 | 206 | err, self.func = driver.cuModuleGetFunction( |
|
0 commit comments