Skip to content

Commit 76b5afc

Browse files
committed
feat: dyn dispatch plan
1 parent 3c49fae commit 76b5afc

7 files changed

Lines changed: 243 additions & 236 deletions

File tree

vortex-cuda/benches/dynamic_dispatch_cuda.rs

Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@ use vortex_cuda::CudaExecutionCtx;
2626
use vortex_cuda::CudaSession;
2727
use vortex_cuda::bitpacked_cuda_kernel;
2828
use vortex_cuda::bitpacked_cuda_launch_config;
29-
use vortex_cuda::dynamic_dispatch_op::DynamicOp;
30-
use vortex_cuda::dynamic_dispatch_op::DynamicOpCode_ALP;
31-
use vortex_cuda::dynamic_dispatch_op::DynamicOpCode_BITUNPACK;
32-
use vortex_cuda::dynamic_dispatch_op::DynamicOpCode_FOR;
29+
use vortex_cuda::dynamic_dispatch::DynamicDispatchPlan;
30+
use vortex_cuda::dynamic_dispatch::ScalarOp;
31+
use vortex_cuda::dynamic_dispatch::SourceOp;
3332
use vortex_cuda_macros::cuda_available;
3433
use vortex_cuda_macros::cuda_not_available;
3534
use vortex_dtype::PType;
@@ -54,10 +53,6 @@ const ALP_E: f32 = 1.0;
5453
// Helpers
5554
// ---------------------------------------------------------------------------
5655

57-
fn pack_alp_f32_param(f: f32, e: f32) -> u64 {
58-
(e.to_bits() as u64) << 32 | f.to_bits() as u64
59-
}
60-
6156
/// Helper: launch a single FoR kernel on a device buffer (in-place).
6257
fn launch_for_kernel(
6358
cuda_ctx: &mut CudaExecutionCtx,
@@ -107,12 +102,11 @@ fn run_dynamic_dispatch_timed(
107102
input_ptr: u64,
108103
output_ptr: u64,
109104
array_len: usize,
110-
device_ops: &Arc<cudarc::driver::CudaSlice<DynamicOp>>,
111-
num_ops: u8,
105+
device_plan: &Arc<cudarc::driver::CudaSlice<DynamicDispatchPlan>>,
112106
) -> VortexResult<Duration> {
113107
let cuda_function = cuda_ctx.load_function("dynamic_dispatch", &["u32"])?;
114108
let array_len_u64 = array_len as u64;
115-
let ops_ptr = device_ops.device_ptr(cuda_ctx.stream()).0;
109+
let plan_ptr = device_plan.device_ptr(cuda_ctx.stream()).0;
116110

117111
let stream = cuda_ctx.stream();
118112
let ctx = stream.context();
@@ -127,8 +121,7 @@ fn run_dynamic_dispatch_timed(
127121
launch_builder.arg(&input_ptr);
128122
launch_builder.arg(&output_ptr);
129123
launch_builder.arg(&array_len_u64);
130-
launch_builder.arg(&ops_ptr);
131-
launch_builder.arg(&num_ops);
124+
launch_builder.arg(&plan_ptr);
132125

133126
let num_blocks = array_len.div_ceil(2048) as u32;
134127
let config = LaunchConfig {
@@ -275,15 +268,14 @@ fn bench_bitunpack_for_separate(c: &mut Criterion) {
275268
}
276269

277270
// ============================================================================
278-
// Benchmark: BitUnpack + FoR — single fused dynamic scalar_decode launch
271+
// Benchmark: BitUnpack + FoR — single fused dynamic dispatch launch
279272
// ============================================================================
280273

281274
/// Run a fused dynamic_dispatch launch on a bitpacked array, returning GPU time.
282275
fn run_dynamic_dispatch_bitpacked_timed(
283276
cuda_ctx: &mut CudaExecutionCtx,
284277
bitpacked_array: &BitPackedArray,
285-
device_ops: &Arc<cudarc::driver::CudaSlice<DynamicOp>>,
286-
num_ops: u8,
278+
device_plan: &Arc<cudarc::driver::CudaSlice<DynamicDispatchPlan>>,
287279
) -> VortexResult<Duration> {
288280
let packed = bitpacked_array.packed().clone();
289281
let len = bitpacked_array.len();
@@ -314,24 +306,17 @@ fn run_dynamic_dispatch_bitpacked_timed(
314306
.synchronize()
315307
.map_err(|e| vortex_err!("failed to synchronize stream: {:?}", e))?;
316308

317-
run_dynamic_dispatch_timed(cuda_ctx, input_ptr, output_ptr, len, device_ops, num_ops)
309+
run_dynamic_dispatch_timed(cuda_ctx, input_ptr, output_ptr, len, device_plan)
318310
}
319311

320312
fn bench_bitunpack_for_dynamic_dispatch(c: &mut Criterion) {
321313
let mut group = c.benchmark_group("bitunpack_for");
322314
group.sample_size(10);
323315

324-
// ops = [BITUNPACK(bit_width=BIT_WIDTH), FOR(REFERENCE_VALUE)]
325-
let ops = vec![
326-
DynamicOp {
327-
op: DynamicOpCode_BITUNPACK,
328-
param: BIT_WIDTH as u64,
329-
},
330-
DynamicOp {
331-
op: DynamicOpCode_FOR,
332-
param: REFERENCE_VALUE as u64,
333-
},
334-
];
316+
let plan = DynamicDispatchPlan::new(
317+
SourceOp::bitunpack(BIT_WIDTH),
318+
&[ScalarOp::frame_of_ref(REFERENCE_VALUE as u64)],
319+
);
335320

336321
for (len, len_str) in BENCH_ARGS {
337322
group.throughput(Throughput::Bytes((len * size_of::<u32>()) as u64));
@@ -350,11 +335,11 @@ fn bench_bitunpack_for_dynamic_dispatch(c: &mut Criterion) {
350335
.load_function("dynamic_dispatch", &["u32"])
351336
.vortex_expect("failed to preload dynamic_dispatch kernel");
352337

353-
let device_ops = Arc::new(
338+
let device_plan = Arc::new(
354339
cuda_ctx
355340
.stream()
356-
.clone_htod(ops.as_slice())
357-
.expect("failed to copy ops to device"),
341+
.clone_htod(std::slice::from_ref(&plan))
342+
.expect("failed to copy plan to device"),
358343
);
359344

360345
b.iter_custom(|iters| {
@@ -364,8 +349,7 @@ fn bench_bitunpack_for_dynamic_dispatch(c: &mut Criterion) {
364349
let kernel_time = run_dynamic_dispatch_bitpacked_timed(
365350
&mut cuda_ctx,
366351
array,
367-
&device_ops,
368-
ops.len() as u8,
352+
&device_plan,
369353
)
370354
.vortex_expect("bitunpack+for dynamic_dispatch failed");
371355
total_time += kernel_time;
@@ -388,21 +372,13 @@ fn bench_bitunpack_for_alp_dynamic_dispatch(c: &mut Criterion) {
388372
let mut group = c.benchmark_group("bitunpack_for_alp");
389373
group.sample_size(10);
390374

391-
// ops = [BITUNPACK(bit_width), FOR(reference), ALP(f, e)]
392-
let ops = vec![
393-
DynamicOp {
394-
op: DynamicOpCode_BITUNPACK,
395-
param: BIT_WIDTH as u64,
396-
},
397-
DynamicOp {
398-
op: DynamicOpCode_FOR,
399-
param: REFERENCE_VALUE as u64,
400-
},
401-
DynamicOp {
402-
op: DynamicOpCode_ALP,
403-
param: pack_alp_f32_param(ALP_F, ALP_E),
404-
},
405-
];
375+
let plan = DynamicDispatchPlan::new(
376+
SourceOp::bitunpack(BIT_WIDTH),
377+
&[
378+
ScalarOp::frame_of_ref(REFERENCE_VALUE as u64),
379+
ScalarOp::alp(ALP_F, ALP_E),
380+
],
381+
);
406382

407383
for (len, len_str) in BENCH_ARGS {
408384
group.throughput(Throughput::Bytes((len * size_of::<u32>()) as u64));
@@ -421,11 +397,11 @@ fn bench_bitunpack_for_alp_dynamic_dispatch(c: &mut Criterion) {
421397
.load_function("dynamic_dispatch", &["u32"])
422398
.vortex_expect("failed to preload dynamic_dispatch kernel");
423399

424-
let device_ops = Arc::new(
400+
let device_plan = Arc::new(
425401
cuda_ctx
426402
.stream()
427-
.clone_htod(ops.as_slice())
428-
.expect("failed to copy ops to device"),
403+
.clone_htod(std::slice::from_ref(&plan))
404+
.expect("failed to copy plan to device"),
429405
);
430406

431407
b.iter_custom(|iters| {
@@ -435,8 +411,7 @@ fn bench_bitunpack_for_alp_dynamic_dispatch(c: &mut Criterion) {
435411
let kernel_time = run_dynamic_dispatch_bitpacked_timed(
436412
&mut cuda_ctx,
437413
array,
438-
&device_ops,
439-
ops.len() as u8,
414+
&device_plan,
440415
)
441416
.vortex_expect("bitunpack+for+alp dynamic_dispatch failed");
442417
total_time += kernel_time;

vortex-cuda/build.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,24 +182,32 @@ fn nvcc_compile_ptx(
182182

183183
/// Generate bindings for the dynamic dispatch shared header.
184184
///
185-
/// `DynamicOp` and `DynamicOpCode` are shared between CUDA kernels
185+
/// `DynamicDispatchPlan` and related types are shared between CUDA kernels
186186
/// and Rust host code.
187187
fn generate_dynamic_dispatch_bindings(kernels_src: &Path, out_dir: &Path) {
188188
let header = kernels_src.join("dynamic_dispatch.h");
189189
println!("cargo:rerun-if-changed={}", header.display());
190190

191191
let bindings = bindgen::Builder::default()
192192
.header(header.to_string_lossy())
193-
.allowlist_type("DynamicOp")
194-
.allowlist_type("DynamicOpCode")
193+
.allowlist_type("DynamicDispatchPlan")
194+
.allowlist_type("SourceOp")
195+
.allowlist_type("SourceOpCode")
196+
.allowlist_type("SourceParams")
197+
.allowlist_type("BitunpackParams")
198+
.allowlist_type("ScalarOp")
199+
.allowlist_type("ScalarOpCode")
200+
.allowlist_type("ScalarParams")
201+
.allowlist_type("FoRParams")
202+
.allowlist_type("AlpParams")
195203
.derive_copy(true)
196204
.derive_debug(true)
197205
.generate()
198206
.expect("Failed to generate dynamic_dispatch bindings");
199207

200208
bindings
201-
.write_to_file(out_dir.join("dynamic_dispatch_op.rs"))
202-
.expect("Failed to write dynamic_dispatch_op.rs");
209+
.write_to_file(out_dir.join("dynamic_dispatch.rs"))
210+
.expect("Failed to write dynamic_dispatch.rs");
203211
}
204212

205213
/// Check if CUDA is available based on nvcc.

vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
// Dynamic dispatch kernel: decodes an array by applying a sequence of operations
5-
// in a single kernel launch. The first op may optionally be a "source" op, e.g. bitunpack.
6-
// Subsequent transform ops are applied element-wise in registers.
5+
// in a single kernel launch. The source op fills shared memory (e.g. bitunpack),
6+
// then scalar ops are applied element-wise in registers (e.g. FoR, zigzag, ALP).
77

88
#include <assert.h>
99
#include <cuda.h>
@@ -17,26 +17,19 @@
1717
#include "dynamic_dispatch.h"
1818
#include "types.cuh"
1919

20-
constexpr uint8_t MAX_DECODE_OPS = 8;
2120
constexpr uint32_t FL_CHUNK_SIZE = 1024;
2221

23-
__device__ __forceinline__ bool is_source_op(enum DynamicOpCode op) {
24-
return op == BITUNPACK;
25-
}
26-
2722
template <typename T>
28-
__device__ __forceinline__ T apply_scalar_op(T value, const DynamicOp &op) {
29-
switch (op.op) {
30-
case FOR: {
31-
return value + static_cast<T>(op.param);
23+
__device__ __forceinline__ T apply_scalar_op(T value, const struct ScalarOp &op) {
24+
switch (op.op_code) {
25+
case ScalarOp::FOR: {
26+
return value + static_cast<T>(op.params.frame_of_ref.reference);
3227
}
33-
case ZIGZAG: {
28+
case ScalarOp::ZIGZAG: {
3429
return (value >> 1) ^ static_cast<T>(-(value & 1));
3530
}
36-
case ALP: {
37-
float f_val = __uint_as_float(static_cast<uint32_t>(op.param));
38-
float e_val = __uint_as_float(static_cast<uint32_t>(op.param >> 32));
39-
float result = static_cast<float>(static_cast<int32_t>(value)) * f_val * e_val;
31+
case ScalarOp::ALP: {
32+
float result = static_cast<float>(static_cast<int32_t>(value)) * op.params.alp.f * op.params.alp.e;
4033
return static_cast<T>(__float_as_uint(result));
4134
}
4235
default: __builtin_unreachable();
@@ -67,13 +60,13 @@ BITUNPACK_LANE(64, uint64_t, int64_t)
6760
template <typename T>
6861
__device__ __forceinline__ void source_fill_op(const T *__restrict input, T *__restrict smem,
6962
uint64_t chunk_start, uint32_t chunk_len,
70-
const DynamicOp &source_op) {
63+
const struct SourceOp &source_op) {
7164
constexpr uint32_t T_BITS = sizeof(T) * 8;
7265
constexpr uint32_t FL_LANES = FL_CHUNK_SIZE / T_BITS;
7366

74-
switch (source_op.op) {
75-
case BITUNPACK: {
76-
const uint32_t bit_width = static_cast<uint32_t>(source_op.param);
67+
switch (source_op.op_code) {
68+
case SourceOp::BITUNPACK: {
69+
const uint32_t bit_width = source_op.params.bitunpack.bit_width;
7770
const uint32_t packed_words_per_chunk = FL_LANES * bit_width;
7871
const uint64_t chunk_idx = chunk_start / FL_CHUNK_SIZE;
7972
const T *packed_chunk = input + chunk_idx * packed_words_per_chunk;
@@ -82,44 +75,45 @@ __device__ __forceinline__ void source_fill_op(const T *__restrict input, T *__r
8275
}
8376
break;
8477
}
85-
default:
86-
for (uint32_t elem_idx = threadIdx.x; elem_idx < chunk_len; elem_idx += blockDim.x) {
87-
smem[elem_idx] = input[chunk_start + elem_idx];
88-
}
89-
break;
78+
default: __builtin_unreachable();
9079
}
9180
}
9281

9382
template <typename T>
9483
__device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict output, uint64_t array_len,
95-
const DynamicOp *__restrict ops, uint8_t num_ops) {
96-
assert(num_ops <= MAX_DECODE_OPS);
97-
84+
const struct DynamicDispatchPlan *__restrict plan) {
9885
constexpr uint32_t ELEMENTS_PER_BLOCK = 2048;
9986
constexpr uint32_t VALUES_PER_LOOP = 32 / sizeof(T);
10087

101-
__shared__ DynamicOp smem_ops[MAX_DECODE_OPS];
88+
__shared__ struct SourceOp smem_source;
89+
__shared__ uint8_t smem_num_scalar_ops;
90+
__shared__ struct ScalarOp smem_scalar_ops[MAX_SCALAR_OPS];
10291
__shared__ T smem_values[FL_CHUNK_SIZE];
10392

104-
// Cache ops in shared memory.
105-
if (threadIdx.x < num_ops) {
106-
smem_ops[threadIdx.x] = ops[threadIdx.x];
93+
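// Cache the plan in shared memory. NOTE(review): the MAX_SCALAR_OPS assert below runs only after this copy, so an oversized plan->num_scalar_ops would already have indexed past smem_scalar_ops; TODO validate the count before copying.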
94+
if (threadIdx.x == 0) {
95+
smem_source = plan->source;
96+
smem_num_scalar_ops = plan->num_scalar_ops;
97+
}
98+
if (threadIdx.x < plan->num_scalar_ops) {
99+
smem_scalar_ops[threadIdx.x] = plan->scalar_ops[threadIdx.x];
107100
}
108101
__syncthreads();
109102

103+
assert(smem_num_scalar_ops <= MAX_SCALAR_OPS);
104+
110105
const uint64_t block_start = static_cast<uint64_t>(blockIdx.x) * ELEMENTS_PER_BLOCK;
111106
const uint64_t block_end = min(block_start + ELEMENTS_PER_BLOCK, array_len);
112107

113108
for (uint64_t chunk_start = block_start; chunk_start < block_end; chunk_start += FL_CHUNK_SIZE) {
114109
const uint32_t chunk_len =
115110
static_cast<uint32_t>(min(static_cast<uint64_t>(FL_CHUNK_SIZE), block_end - chunk_start));
116111

117-
source_fill_op<T>(input, smem_values, chunk_start, chunk_len, smem_ops[0]);
112+
source_fill_op<T>(input, smem_values, chunk_start, chunk_len, smem_source);
118113
__syncthreads();
119114

120115
const uint32_t tile_size = blockDim.x * VALUES_PER_LOOP;
121116
const uint32_t num_full_tiles = chunk_len / tile_size;
122-
const uint8_t scalar_op_start_idx = is_source_op(smem_ops[0].op);
123117

124118
for (uint32_t tile = 0; tile < num_full_tiles; ++tile) {
125119
const uint32_t tile_base = tile * tile_size;
@@ -134,12 +128,12 @@ __device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict o
134128
values[idx] = smem_values[tile_base + idx * blockDim.x + threadIdx.x];
135129
}
136130

137-
for (uint8_t op_idx = scalar_op_start_idx; op_idx < num_ops; ++op_idx) {
138-
const DynamicOp &decode_op = smem_ops[op_idx];
131+
for (uint8_t op_idx = 0; op_idx < smem_num_scalar_ops; ++op_idx) {
132+
const struct ScalarOp &scalar_op = smem_scalar_ops[op_idx];
139133

140134
#pragma unroll
141135
for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) {
142-
values[idx] = apply_scalar_op(values[idx], decode_op);
136+
values[idx] = apply_scalar_op(values[idx], scalar_op);
143137
}
144138
}
145139

@@ -153,8 +147,8 @@ __device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict o
153147
const uint32_t rem_start = num_full_tiles * tile_size;
154148
for (uint32_t elem_idx = rem_start + threadIdx.x; elem_idx < chunk_len; elem_idx += blockDim.x) {
155149
T val = smem_values[elem_idx];
156-
for (uint8_t op_idx = scalar_op_start_idx; op_idx < num_ops; ++op_idx) {
157-
val = apply_scalar_op(val, smem_ops[op_idx]);
150+
for (uint8_t op_idx = 0; op_idx < smem_num_scalar_ops; ++op_idx) {
151+
val = apply_scalar_op(val, smem_scalar_ops[op_idx]);
158152
}
159153
output[chunk_start + elem_idx] = val;
160154
}
@@ -166,8 +160,8 @@ __device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict o
166160
#define GENERATE_DYNAMIC_DISPATCH_KERNEL(suffix, Type) \
167161
extern "C" __global__ void dynamic_dispatch_##suffix(const Type *__restrict input, \
168162
Type *__restrict output, uint64_t array_len, \
169-
const DynamicOp *__restrict ops, uint8_t num_ops) { \
170-
dynamic_dispatch_impl<Type>(input, output, array_len, ops, num_ops); \
163+
const struct DynamicDispatchPlan *__restrict plan) { \
164+
dynamic_dispatch_impl<Type>(input, output, array_len, plan); \
171165
}
172166

173167
FOR_EACH_INTEGER(GENERATE_DYNAMIC_DISPATCH_KERNEL)

0 commit comments

Comments
 (0)