Skip to content

Commit 7763da9

Browse files
committed
feat: dyn dispatch plan
1 parent 3c49fae commit 7763da9

7 files changed

Lines changed: 254 additions & 251 deletions

File tree

vortex-cuda/benches/dynamic_dispatch_cuda.rs

Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@ use vortex_cuda::CudaExecutionCtx;
2626
use vortex_cuda::CudaSession;
2727
use vortex_cuda::bitpacked_cuda_kernel;
2828
use vortex_cuda::bitpacked_cuda_launch_config;
29-
use vortex_cuda::dynamic_dispatch_op::DynamicOp;
30-
use vortex_cuda::dynamic_dispatch_op::DynamicOpCode_ALP;
31-
use vortex_cuda::dynamic_dispatch_op::DynamicOpCode_BITUNPACK;
32-
use vortex_cuda::dynamic_dispatch_op::DynamicOpCode_FOR;
29+
use vortex_cuda::dynamic_dispatch::DynamicDispatchPlan;
30+
use vortex_cuda::dynamic_dispatch::ScalarOp;
31+
use vortex_cuda::dynamic_dispatch::SourceOp;
3332
use vortex_cuda_macros::cuda_available;
3433
use vortex_cuda_macros::cuda_not_available;
3534
use vortex_dtype::PType;
@@ -54,10 +53,6 @@ const ALP_E: f32 = 1.0;
5453
// Helpers
5554
// ---------------------------------------------------------------------------
5655

57-
fn pack_alp_f32_param(f: f32, e: f32) -> u64 {
58-
(e.to_bits() as u64) << 32 | f.to_bits() as u64
59-
}
60-
6156
/// Helper: launch a single FoR kernel on a device buffer (in-place).
6257
fn launch_for_kernel(
6358
cuda_ctx: &mut CudaExecutionCtx,
@@ -107,12 +102,11 @@ fn run_dynamic_dispatch_timed(
107102
input_ptr: u64,
108103
output_ptr: u64,
109104
array_len: usize,
110-
device_ops: &Arc<cudarc::driver::CudaSlice<DynamicOp>>,
111-
num_ops: u8,
105+
device_plan: &Arc<cudarc::driver::CudaSlice<DynamicDispatchPlan>>,
112106
) -> VortexResult<Duration> {
113107
let cuda_function = cuda_ctx.load_function("dynamic_dispatch", &["u32"])?;
114108
let array_len_u64 = array_len as u64;
115-
let ops_ptr = device_ops.device_ptr(cuda_ctx.stream()).0;
109+
let plan_ptr = device_plan.device_ptr(cuda_ctx.stream()).0;
116110

117111
let stream = cuda_ctx.stream();
118112
let ctx = stream.context();
@@ -127,8 +121,7 @@ fn run_dynamic_dispatch_timed(
127121
launch_builder.arg(&input_ptr);
128122
launch_builder.arg(&output_ptr);
129123
launch_builder.arg(&array_len_u64);
130-
launch_builder.arg(&ops_ptr);
131-
launch_builder.arg(&num_ops);
124+
launch_builder.arg(&plan_ptr);
132125

133126
let num_blocks = array_len.div_ceil(2048) as u32;
134127
let config = LaunchConfig {
@@ -275,15 +268,14 @@ fn bench_bitunpack_for_separate(c: &mut Criterion) {
275268
}
276269

277270
// ============================================================================
278-
// Benchmark: BitUnpack + FoR — single fused dynamic scalar_decode launch
271+
// Benchmark: BitUnpack + FoR — single fused dynamic dispatch launch
279272
// ============================================================================
280273

281274
/// Run a fused dynamic_dispatch launch on a bitpacked array, returning GPU time.
282275
fn run_dynamic_dispatch_bitpacked_timed(
283276
cuda_ctx: &mut CudaExecutionCtx,
284277
bitpacked_array: &BitPackedArray,
285-
device_ops: &Arc<cudarc::driver::CudaSlice<DynamicOp>>,
286-
num_ops: u8,
278+
device_plan: &Arc<cudarc::driver::CudaSlice<DynamicDispatchPlan>>,
287279
) -> VortexResult<Duration> {
288280
let packed = bitpacked_array.packed().clone();
289281
let len = bitpacked_array.len();
@@ -314,24 +306,17 @@ fn run_dynamic_dispatch_bitpacked_timed(
314306
.synchronize()
315307
.map_err(|e| vortex_err!("failed to synchronize stream: {:?}", e))?;
316308

317-
run_dynamic_dispatch_timed(cuda_ctx, input_ptr, output_ptr, len, device_ops, num_ops)
309+
run_dynamic_dispatch_timed(cuda_ctx, input_ptr, output_ptr, len, device_plan)
318310
}
319311

320312
fn bench_bitunpack_for_dynamic_dispatch(c: &mut Criterion) {
321313
let mut group = c.benchmark_group("bitunpack_for");
322314
group.sample_size(10);
323315

324-
// ops = [BITUNPACK(bit_width=BIT_WIDTH), FOR(REFERENCE_VALUE)]
325-
let ops = vec![
326-
DynamicOp {
327-
op: DynamicOpCode_BITUNPACK,
328-
param: BIT_WIDTH as u64,
329-
},
330-
DynamicOp {
331-
op: DynamicOpCode_FOR,
332-
param: REFERENCE_VALUE as u64,
333-
},
334-
];
316+
let plan = DynamicDispatchPlan::new(
317+
SourceOp::bitunpack(BIT_WIDTH),
318+
&[ScalarOp::frame_of_ref(REFERENCE_VALUE as u64)],
319+
);
335320

336321
for (len, len_str) in BENCH_ARGS {
337322
group.throughput(Throughput::Bytes((len * size_of::<u32>()) as u64));
@@ -350,11 +335,11 @@ fn bench_bitunpack_for_dynamic_dispatch(c: &mut Criterion) {
350335
.load_function("dynamic_dispatch", &["u32"])
351336
.vortex_expect("failed to preload dynamic_dispatch kernel");
352337

353-
let device_ops = Arc::new(
338+
let device_plan = Arc::new(
354339
cuda_ctx
355340
.stream()
356-
.clone_htod(ops.as_slice())
357-
.expect("failed to copy ops to device"),
341+
.clone_htod(std::slice::from_ref(&plan))
342+
.expect("failed to copy plan to device"),
358343
);
359344

360345
b.iter_custom(|iters| {
@@ -364,8 +349,7 @@ fn bench_bitunpack_for_dynamic_dispatch(c: &mut Criterion) {
364349
let kernel_time = run_dynamic_dispatch_bitpacked_timed(
365350
&mut cuda_ctx,
366351
array,
367-
&device_ops,
368-
ops.len() as u8,
352+
&device_plan,
369353
)
370354
.vortex_expect("bitunpack+for dynamic_dispatch failed");
371355
total_time += kernel_time;
@@ -388,21 +372,13 @@ fn bench_bitunpack_for_alp_dynamic_dispatch(c: &mut Criterion) {
388372
let mut group = c.benchmark_group("bitunpack_for_alp");
389373
group.sample_size(10);
390374

391-
// ops = [BITUNPACK(bit_width), FOR(reference), ALP(f, e)]
392-
let ops = vec![
393-
DynamicOp {
394-
op: DynamicOpCode_BITUNPACK,
395-
param: BIT_WIDTH as u64,
396-
},
397-
DynamicOp {
398-
op: DynamicOpCode_FOR,
399-
param: REFERENCE_VALUE as u64,
400-
},
401-
DynamicOp {
402-
op: DynamicOpCode_ALP,
403-
param: pack_alp_f32_param(ALP_F, ALP_E),
404-
},
405-
];
375+
let plan = DynamicDispatchPlan::new(
376+
SourceOp::bitunpack(BIT_WIDTH),
377+
&[
378+
ScalarOp::frame_of_ref(REFERENCE_VALUE as u64),
379+
ScalarOp::alp(ALP_F, ALP_E),
380+
],
381+
);
406382

407383
for (len, len_str) in BENCH_ARGS {
408384
group.throughput(Throughput::Bytes((len * size_of::<u32>()) as u64));
@@ -421,11 +397,11 @@ fn bench_bitunpack_for_alp_dynamic_dispatch(c: &mut Criterion) {
421397
.load_function("dynamic_dispatch", &["u32"])
422398
.vortex_expect("failed to preload dynamic_dispatch kernel");
423399

424-
let device_ops = Arc::new(
400+
let device_plan = Arc::new(
425401
cuda_ctx
426402
.stream()
427-
.clone_htod(ops.as_slice())
428-
.expect("failed to copy ops to device"),
403+
.clone_htod(std::slice::from_ref(&plan))
404+
.expect("failed to copy plan to device"),
429405
);
430406

431407
b.iter_custom(|iters| {
@@ -435,8 +411,7 @@ fn bench_bitunpack_for_alp_dynamic_dispatch(c: &mut Criterion) {
435411
let kernel_time = run_dynamic_dispatch_bitpacked_timed(
436412
&mut cuda_ctx,
437413
array,
438-
&device_ops,
439-
ops.len() as u8,
414+
&device_plan,
440415
)
441416
.vortex_expect("bitunpack+for+alp dynamic_dispatch failed");
442417
total_time += kernel_time;

vortex-cuda/build.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,24 +182,22 @@ fn nvcc_compile_ptx(
182182

183183
/// Generate bindings for the dynamic dispatch shared header.
184184
///
185-
/// `DynamicOp` and `DynamicOpCode` are shared between CUDA kernels
185+
/// `DynamicDispatchPlan` and related types are shared between CUDA kernels
186186
/// and Rust host code.
187187
fn generate_dynamic_dispatch_bindings(kernels_src: &Path, out_dir: &Path) {
188188
let header = kernels_src.join("dynamic_dispatch.h");
189189
println!("cargo:rerun-if-changed={}", header.display());
190190

191191
let bindings = bindgen::Builder::default()
192192
.header(header.to_string_lossy())
193-
.allowlist_type("DynamicOp")
194-
.allowlist_type("DynamicOpCode")
195193
.derive_copy(true)
196194
.derive_debug(true)
197195
.generate()
198196
.expect("Failed to generate dynamic_dispatch bindings");
199197

200198
bindings
201-
.write_to_file(out_dir.join("dynamic_dispatch_op.rs"))
202-
.expect("Failed to write dynamic_dispatch_op.rs");
199+
.write_to_file(out_dir.join("dynamic_dispatch.rs"))
200+
.expect("Failed to write dynamic_dispatch.rs");
203201
}
204202

205203
/// Check if CUDA is available based on nvcc.

vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 38 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
// Dynamic dispatch kernel: decodes an array by applying a sequence of operations
5-
// in a single kernel launch. The first op may optionally be a "source" op, e.g. bitunpack.
6-
// Subsequent transform ops are applied element-wise in registers.
5+
// in a single kernel launch. The source op fills shared memory (e.g. bitunpack),
6+
// then scalar ops are applied element-wise in registers (e.g. FoR, zigzag, ALP).
77

88
#include <assert.h>
99
#include <cuda.h>
@@ -17,32 +17,8 @@
1717
#include "dynamic_dispatch.h"
1818
#include "types.cuh"
1919

20-
constexpr uint8_t MAX_DECODE_OPS = 8;
2120
constexpr uint32_t FL_CHUNK_SIZE = 1024;
2221

23-
__device__ __forceinline__ bool is_source_op(enum DynamicOpCode op) {
24-
return op == BITUNPACK;
25-
}
26-
27-
template <typename T>
28-
__device__ __forceinline__ T apply_scalar_op(T value, const DynamicOp &op) {
29-
switch (op.op) {
30-
case FOR: {
31-
return value + static_cast<T>(op.param);
32-
}
33-
case ZIGZAG: {
34-
return (value >> 1) ^ static_cast<T>(-(value & 1));
35-
}
36-
case ALP: {
37-
float f_val = __uint_as_float(static_cast<uint32_t>(op.param));
38-
float e_val = __uint_as_float(static_cast<uint32_t>(op.param >> 32));
39-
float result = static_cast<float>(static_cast<int32_t>(value)) * f_val * e_val;
40-
return static_cast<T>(__float_as_uint(result));
41-
}
42-
default: __builtin_unreachable();
43-
}
44-
}
45-
4622
template <typename T>
4723
__device__ __forceinline__ void bitunpack_lane_to_smem(const T *__restrict packed_chunk, T *__restrict smem,
4824
unsigned int lane, uint32_t bit_width);
@@ -65,15 +41,15 @@ BITUNPACK_LANE(64, uint64_t, uint64_t)
6541
BITUNPACK_LANE(64, uint64_t, int64_t)
6642

6743
template <typename T>
68-
__device__ __forceinline__ void source_fill_op(const T *__restrict input, T *__restrict smem,
44+
__device__ __forceinline__ void dynamic_source_op(const T *__restrict input, T *__restrict smem,
6945
uint64_t chunk_start, uint32_t chunk_len,
70-
const DynamicOp &source_op) {
46+
const struct SourceOp &source_op) {
7147
constexpr uint32_t T_BITS = sizeof(T) * 8;
7248
constexpr uint32_t FL_LANES = FL_CHUNK_SIZE / T_BITS;
7349

74-
switch (source_op.op) {
75-
case BITUNPACK: {
76-
const uint32_t bit_width = static_cast<uint32_t>(source_op.param);
50+
switch (source_op.op_code) {
51+
case SourceOp::BITUNPACK: {
52+
const uint32_t bit_width = source_op.params.bitunpack.bit_width;
7753
const uint32_t packed_words_per_chunk = FL_LANES * bit_width;
7854
const uint64_t chunk_idx = chunk_start / FL_CHUNK_SIZE;
7955
const T *packed_chunk = input + chunk_idx * packed_words_per_chunk;
@@ -82,44 +58,50 @@ __device__ __forceinline__ void source_fill_op(const T *__restrict input, T *__r
8258
}
8359
break;
8460
}
85-
default:
86-
for (uint32_t elem_idx = threadIdx.x; elem_idx < chunk_len; elem_idx += blockDim.x) {
87-
smem[elem_idx] = input[chunk_start + elem_idx];
88-
}
89-
break;
61+
default: __builtin_unreachable();
9062
}
9163
}
9264

9365
template <typename T>
94-
__device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict output, uint64_t array_len,
95-
const DynamicOp *__restrict ops, uint8_t num_ops) {
96-
assert(num_ops <= MAX_DECODE_OPS);
66+
__device__ __forceinline__ T dynamic_scalar_op(T value, const struct ScalarOp &op) {
67+
switch (op.op_code) {
68+
case ScalarOp::FOR: {
69+
return value + static_cast<T>(op.params.frame_of_ref.reference);
70+
}
71+
case ScalarOp::ZIGZAG: {
72+
return (value >> 1) ^ static_cast<T>(-(value & 1));
73+
}
74+
case ScalarOp::ALP: {
75+
float result = static_cast<float>(static_cast<int32_t>(value)) * op.params.alp.f * op.params.alp.e;
76+
return static_cast<T>(__float_as_uint(result));
77+
}
78+
default: __builtin_unreachable();
79+
}
80+
}
9781

82+
template <typename T>
83+
__device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict output, uint64_t array_len,
84+
const struct DynamicDispatchPlan *__restrict plan) {
9885
constexpr uint32_t ELEMENTS_PER_BLOCK = 2048;
9986
constexpr uint32_t VALUES_PER_LOOP = 32 / sizeof(T);
10087

101-
__shared__ DynamicOp smem_ops[MAX_DECODE_OPS];
88+
__shared__ struct DynamicDispatchPlan smem_plan;
10289
__shared__ T smem_values[FL_CHUNK_SIZE];
10390

104-
// Cache ops in shared memory.
105-
if (threadIdx.x < num_ops) {
106-
smem_ops[threadIdx.x] = ops[threadIdx.x];
107-
}
91+
// Cache the plan in shared memory.
92+
if (threadIdx.x == 0) smem_plan = *plan;
10893
__syncthreads();
10994

11095
const uint64_t block_start = static_cast<uint64_t>(blockIdx.x) * ELEMENTS_PER_BLOCK;
11196
const uint64_t block_end = min(block_start + ELEMENTS_PER_BLOCK, array_len);
11297

11398
for (uint64_t chunk_start = block_start; chunk_start < block_end; chunk_start += FL_CHUNK_SIZE) {
114-
const uint32_t chunk_len =
115-
static_cast<uint32_t>(min(static_cast<uint64_t>(FL_CHUNK_SIZE), block_end - chunk_start));
116-
117-
source_fill_op<T>(input, smem_values, chunk_start, chunk_len, smem_ops[0]);
99+
const uint32_t chunk_len = min(FL_CHUNK_SIZE, static_cast<uint32_t>(block_end - chunk_start));
100+
dynamic_source_op<T>(input, smem_values, chunk_start, chunk_len, smem_plan.source);
118101
__syncthreads();
119102

120103
const uint32_t tile_size = blockDim.x * VALUES_PER_LOOP;
121104
const uint32_t num_full_tiles = chunk_len / tile_size;
122-
const uint8_t scalar_op_start_idx = is_source_op(smem_ops[0].op);
123105

124106
for (uint32_t tile = 0; tile < num_full_tiles; ++tile) {
125107
const uint32_t tile_base = tile * tile_size;
@@ -134,12 +116,12 @@ __device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict o
134116
values[idx] = smem_values[tile_base + idx * blockDim.x + threadIdx.x];
135117
}
136118

137-
for (uint8_t op_idx = scalar_op_start_idx; op_idx < num_ops; ++op_idx) {
138-
const DynamicOp &decode_op = smem_ops[op_idx];
119+
for (uint8_t op_idx = 0; op_idx < smem_plan.num_scalar_ops; ++op_idx) {
120+
const struct ScalarOp &scalar_op = smem_plan.scalar_ops[op_idx];
139121

140122
#pragma unroll
141123
for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) {
142-
values[idx] = apply_scalar_op(values[idx], decode_op);
124+
values[idx] = dynamic_scalar_op(values[idx], scalar_op);
143125
}
144126
}
145127

@@ -153,8 +135,8 @@ __device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict o
153135
const uint32_t rem_start = num_full_tiles * tile_size;
154136
for (uint32_t elem_idx = rem_start + threadIdx.x; elem_idx < chunk_len; elem_idx += blockDim.x) {
155137
T val = smem_values[elem_idx];
156-
for (uint8_t op_idx = scalar_op_start_idx; op_idx < num_ops; ++op_idx) {
157-
val = apply_scalar_op(val, smem_ops[op_idx]);
138+
for (uint8_t op_idx = 0; op_idx < smem_plan.num_scalar_ops; ++op_idx) {
139+
val = dynamic_scalar_op(val, smem_plan.scalar_ops[op_idx]);
158140
}
159141
output[chunk_start + elem_idx] = val;
160142
}
@@ -166,8 +148,8 @@ __device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict o
166148
#define GENERATE_DYNAMIC_DISPATCH_KERNEL(suffix, Type) \
167149
extern "C" __global__ void dynamic_dispatch_##suffix(const Type *__restrict input, \
168150
Type *__restrict output, uint64_t array_len, \
169-
const DynamicOp *__restrict ops, uint8_t num_ops) { \
170-
dynamic_dispatch_impl<Type>(input, output, array_len, ops, num_ops); \
151+
const struct DynamicDispatchPlan *__restrict plan) { \
152+
dynamic_dispatch_impl<Type>(input, output, array_len, plan); \
171153
}
172154

173155
FOR_EACH_INTEGER(GENERATE_DYNAMIC_DISPATCH_KERNEL)

0 commit comments

Comments
 (0)