Skip to content

Commit 6dc0282

Browse files
committed
Add test cases for the nvcuda backend of template kernels
1 parent cc9a518 commit 6dc0282

3 files changed

Lines changed: 86 additions & 29 deletions

File tree

kernel_tuner/backends/nvcuda.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from warnings import warn
33

44
import numpy as np
5+
import os
56

67
from kernel_tuner.backends.backend import GPUBackend
78
from kernel_tuner.observers.nvcuda import CudaRuntimeObserver
89
from kernel_tuner.util import SkippableFailure
9-
from kernel_tuner.utils.nvcuda import cuda_error_check, to_valid_nvrtc_gpu_arch_cc
10+
from kernel_tuner.utils.nvcuda import cuda_error_check, to_valid_nvrtc_gpu_arch_cc, find_cuda_home
1011

1112
# embedded in try block to be able to generate documentation
1213
# and run tests without cuda-python installed
@@ -74,9 +75,6 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
7475
self.current_module = None
7576
self.func = None
7677
self.compiler_options = compiler_options or []
77-
self.compiler_options_bytes = []
78-
for option in self.compiler_options:
79-
self.compiler_options_bytes.append(str(option).encode("UTF-8"))
8078

8179
# create a stream and events
8280
err, self.stream = driver.cuStreamCreate(0)
@@ -155,45 +153,54 @@ def compile(self, kernel_instance):
155153
kernel_string = kernel_instance.kernel_string
156154
kernel_name = kernel_instance.name
157155
expression_name = str.encode(kernel_name)
156+
compiler_options = list(self.compiler_options)
158157

159-
compiler_options = self.compiler_options_bytes
160-
if not any([b"--std=" in opt for opt in compiler_options]):
161-
compiler_options.append(b"--std=c++11")
162-
if not any(["--std=" in opt for opt in self.compiler_options]):
163-
self.compiler_options.append("--std=c++11")
164-
if not any([b"--gpu-architecture=" in opt or b"-arch" in opt for opt in compiler_options]):
165-
compiler_options.append(f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}".encode("UTF-8"))
166-
if not any(["--gpu-architecture=" in opt or "-arch" in opt for opt in self.compiler_options]):
167-
self.compiler_options.append(f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}")
158+
# Add -std=c++11
159+
if not any(opt.startswith(("-std=", "--std=")) for opt in self.compiler_options):
160+
compiler_options.append("--std=c++11")
161+
162+
# Add -arch
163+
if not any(opt.startswith(("-arch", "--arch", "--gpu-architecture=")) for opt in self.compiler_options):
164+
arch_val = to_valid_nvrtc_gpu_arch_cc(self.cc)
165+
compiler_options.append(f"--gpu-architecture=compute_{arch_val}")
166+
167+
# Add CUDA home to include path
168+
cuda_home = find_cuda_home()
169+
if cuda_home:
170+
cuda_include = os.path.join(cuda_home, "include")
171+
compiler_options.append(f"-I{cuda_include}")
172+
173+
# nvrtcCompileProgram requires bytes instead of str
174+
compiler_options = [str(opt).encode("UTF-8") for opt in compiler_options]
168175

169176
err, program = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), b"CUDAProgram", 0, [], [])
170177
try:
171178
# Add the kernel as an expression. This is necessary for templated kernels to ensure that the
172179
# compiler actually instantiates the kernel that we want to compile.
173180
cuda_error_check(err)
174181
err = nvrtc.nvrtcAddNameExpression(program, expression_name)
175-
182+
176183
# Compile the program
177184
cuda_error_check(err)
178185
err = nvrtc.nvrtcCompileProgram(program, len(compiler_options), compiler_options)
179-
186+
180187
# Get the PTX
181188
cuda_error_check(err)
182189
err, size = nvrtc.nvrtcGetPTXSize(program)
183190
cuda_error_check(err)
184191
buff = b" " * size
185192
err = nvrtc.nvrtcGetPTX(program, buff)
186193
cuda_error_check(err)
187-
194+
188195
# Load the module
189196
err, self.current_module = driver.cuModuleLoadData(np.char.array(buff))
190197
if err == driver.CUresult.CUDA_ERROR_INVALID_PTX:
191198
raise SkippableFailure("uses too much shared data")
192199
else:
193200
cuda_error_check(err)
194-
201+
195202
# First, get the "lowered" name of the kernel (i.e., the name inside the PTX).
196-
# After, we can use the lowered name to lookup the kernel in the module.
203+
# After, we can use the lowered name to lookup the kernel in the module.
197204
err, lowered_name = nvrtc.nvrtcGetLoweredName(program, expression_name)
198205
cuda_error_check(err)
199206
err, self.func = driver.cuModuleGetFunction(

kernel_tuner/utils/nvcuda.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Module for kernel tuner cuda-python utility functions."""
22

33
import numpy as np
4+
import os
5+
import subprocess
6+
import shutil
7+
from typing import Optional
48

59
try:
610
from cuda.bindings import driver, runtime, nvrtc
@@ -56,12 +60,20 @@ def cuda_error_check(error):
5660
if error != nvrtc.nvrtcResult.NVRTC_SUCCESS:
5761
_, desc = nvrtc.nvrtcGetErrorString(error)
5862
raise RuntimeError(f"NVRTC error: {desc.decode()}")
59-
elif isinstance(error, tuple) and len(error) > 0:
60-
cuda_error_check(error[0])
61-
else:
62-
raise RuntimeError(f"unknown error type returned by CUDA: {error!r} (type: {type(error).__name__})")
6363

6464

6565
def to_valid_nvrtc_gpu_arch_cc(compute_capability: str) -> str:
    """Map a device Compute Capability onto one NVRTC accepts for `--gpu-architecture=`.

    Valid values are listed at
    https://docs.nvidia.com/cuda/nvrtc/index.html#group__options; the closest
    supported CC not exceeding *compute_capability* is returned, defaulting to "75".
    """
    supported = NVRTC_VALID_CC[NVRTC_VALID_CC <= compute_capability]
    return max(supported, default="75")
68+
69+
70+
def find_cuda_home() -> Optional[str]:
    """Locate the CUDA toolkit installation directory.

    Checks, in order:
      1. The ``CUDA_HOME``, ``CUDA_PATH`` and ``CUDA_ROOT`` environment variables.
      2. The location of the ``nvcc`` compiler on the ``PATH``
         (the toolkit root is the parent of nvcc's ``bin`` directory).
      3. The conventional default install prefix ``/usr/local/cuda``.

    Returns:
        The path to the CUDA home directory, or None if it cannot be found.
    """
    # 1. Explicit environment variables take precedence.
    for var in ["CUDA_HOME", "CUDA_PATH", "CUDA_ROOT"]:
        path = os.environ.get(var)
        if path and os.path.exists(path):
            return path

    # 2. Derive the install prefix from nvcc's location (<home>/bin/nvcc).
    nvcc = shutil.which("nvcc")
    if nvcc:
        return os.path.dirname(os.path.dirname(os.path.realpath(nvcc)))

    # 3. Fall back to the conventional default installation path.
    default_home = "/usr/local/cuda"
    if os.path.exists(default_home):
        return default_home

    return None

test/test_cuda_functions.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ def test_ready_argument_list():
3232
assert isinstance(gpu_args[2], driver.CUdeviceptr)
3333

3434

35+
def create_kernel_instance(kernel_name, kernel_string):
    """Build a minimal KernelInstance for *kernel_string* using the CUDA backend."""
    source = KernelSource(kernel_name, kernel_string, "cuda")
    return KernelInstance(kernel_name, source, kernel_string, [], None, None, dict(), [])
39+
40+
3541
@skip_if_no_cuda
3642
def test_compile():
3743

@@ -44,15 +50,47 @@ def test_compile():
4450
}
4551
"""
4652

47-
kernel_name = "vector_add"
48-
kernel_sources = KernelSource(kernel_name, kernel_string, "cuda")
49-
kernel_instance = KernelInstance(kernel_name, kernel_sources, kernel_string, [], None, None, dict(), [])
53+
kernel_instance = create_kernel_instance("vector_add", kernel_string)
5054
dev = nvcuda.CudaFunctions(0)
51-
try:
52-
dev.compile(kernel_instance)
53-
except Exception as e:
54-
pytest.fail("Did not expect any exception:" + str(e))
55+
dev.compile(kernel_instance)
56+
57+
@skip_if_no_cuda
def test_compile_template():
    """Compiling an explicitly instantiated templated kernel inside nested namespaces succeeds."""
    kernel_string = """
    namespace nested::namespaces {
    template <typename T, int N>
    __global__ void vector_add(T *c, T *a, T *b) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i<N) {
            c[i] = a[i] + b[i];
        }
    }
    }
    """

    # The fully-qualified, instantiated name selects the concrete kernel to compile.
    instance = create_kernel_instance("nested::namespaces::vector_add<float,10>", kernel_string)
    # Nested namespace definitions require at least C++17.
    dev = nvcuda.CudaFunctions(0, compiler_options=["-std=c++17"])
    dev.compile(instance)
76+
77+
@skip_if_no_cuda
def test_compile_include():
    """Compiling a kernel that #includes a CUDA toolkit header succeeds."""
    kernel_string = """
    #include <cuda_fp16.h>

    __global__ void vector_add(__nv_half *c, __nv_half *a, __nv_half *b, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i<n) {
            c[i] = __hadd(a[i], b[i]);
        }
    }
    """

    instance = create_kernel_instance("vector_add", kernel_string)
    dev = nvcuda.CudaFunctions(0, compiler_options=["-std=c++17"])
    dev.compile(instance)
5694

5795
@skip_if_no_cuda
5896
def test_tune_kernel(env):

0 commit comments

Comments
 (0)