@@ -2,8 +2,8 @@
 // SPDX-FileCopyrightText: Copyright the Vortex contributors

 // Dynamic dispatch kernel: decodes an array by applying a sequence of operations
-// in a single kernel launch. The first op may optionally be a "source" op, e.g. bitunpack.
-// Subsequent transform ops are applied element-wise in registers.
+// in a single kernel launch. The source op fills shared memory (e.g. bitunpack),
+// then scalar ops are applied element-wise in registers (e.g. FoR, zigzag, ALP).

 #include <assert.h>
 #include <cuda.h>
@@ -17,32 +17,8 @@
 #include "dynamic_dispatch.h"
 #include "types.cuh"

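 // For reference, a rough sketch of the plan types this kernel consumes. The
 // real definitions live in dynamic_dispatch.h; the shapes below are inferred
 // from the field accesses in this file and are an assumption, not the actual API:
 //
 //   struct SourceOp {
 //     enum OpCode { BITUNPACK } op_code;
 //     union {
 //       struct { uint32_t bit_width; } bitunpack;
 //     } params;
 //   };
 //
 //   struct ScalarOp {
 //     enum OpCode { FOR, ZIGZAG, ALP } op_code;
 //     union {
 //       struct { uint64_t reference; } frame_of_ref;  // FoR bias added back on decode
 //       struct { float f, e; } alp;                   // ALP factor / exponent multipliers
 //     } params;
 //   };
 //
 //   struct DynamicDispatchPlan {
 //     struct SourceOp source;
 //     uint8_t num_scalar_ops;
 //     struct ScalarOp scalar_ops[/* fixed capacity */];
 //   };
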
-constexpr uint8_t MAX_DECODE_OPS = 8;
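 // FastLanes bit-packing operates on fixed 1024-value chunks.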
 constexpr uint32_t FL_CHUNK_SIZE = 1024;

-__device__ __forceinline__ bool is_source_op(enum DynamicOpCode op) {
-  return op == BITUNPACK;
-}
-
-template <typename T>
-__device__ __forceinline__ T apply_scalar_op(T value, const DynamicOp &op) {
-  switch (op.op) {
-  case FOR: {
-    return value + static_cast<T>(op.param);
-  }
-  case ZIGZAG: {
-    return (value >> 1) ^ static_cast<T>(-(value & 1));
-  }
-  case ALP: {
-    float f_val = __uint_as_float(static_cast<uint32_t>(op.param));
-    float e_val = __uint_as_float(static_cast<uint32_t>(op.param >> 32));
-    float result = static_cast<float>(static_cast<int32_t>(value)) * f_val * e_val;
-    return static_cast<T>(__float_as_uint(result));
-  }
-  default: __builtin_unreachable();
-  }
-}
-
 template <typename T>
 __device__ __forceinline__ void bitunpack_lane_to_smem(const T *__restrict packed_chunk, T *__restrict smem,
                                                        unsigned int lane, uint32_t bit_width);
@@ -65,15 +41,15 @@ BITUNPACK_LANE(64, uint64_t, uint64_t)
 BITUNPACK_LANE(64, uint64_t, int64_t)

 template <typename T>
-__device__ __forceinline__ void source_fill_op(const T *__restrict input, T *__restrict smem,
+__device__ __forceinline__ void dynamic_source_op(const T *__restrict input, T *__restrict smem,
                                                uint64_t chunk_start, uint32_t chunk_len,
-                                               const DynamicOp &source_op) {
+                                               const struct SourceOp &source_op) {
   constexpr uint32_t T_BITS = sizeof(T) * 8;
   constexpr uint32_t FL_LANES = FL_CHUNK_SIZE / T_BITS;

-  switch (source_op.op) {
-  case BITUNPACK: {
-    const uint32_t bit_width = static_cast<uint32_t>(source_op.param);
+  switch (source_op.op_code) {
+  case SourceOp::BITUNPACK: {
+    const uint32_t bit_width = source_op.params.bitunpack.bit_width;
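+    // Each chunk is lane-interleaved FastLanes data: e.g. for T = uint32_t and
+    // bit_width = 3, one 1024-value chunk spans FL_LANES * bit_width = 32 * 3
+    // = 96 packed words.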
     const uint32_t packed_words_per_chunk = FL_LANES * bit_width;
     const uint64_t chunk_idx = chunk_start / FL_CHUNK_SIZE;
     const T *packed_chunk = input + chunk_idx * packed_words_per_chunk;
@@ -82,44 +58,50 @@ __device__ __forceinline__ void source_fill_op(const T *__restrict input, T *__r
     }
     break;
   }
-  default:
-    for (uint32_t elem_idx = threadIdx.x; elem_idx < chunk_len; elem_idx += blockDim.x) {
-      smem[elem_idx] = input[chunk_start + elem_idx];
-    }
-    break;
+  default: __builtin_unreachable();
   }
 }

 template <typename T>
-__device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict output, uint64_t array_len,
-                                      const DynamicOp *__restrict ops, uint8_t num_ops) {
-  assert(num_ops <= MAX_DECODE_OPS);
+__device__ __forceinline__ T dynamic_scalar_op(T value, const struct ScalarOp &op) {
+  switch (op.op_code) {
+  case ScalarOp::FOR: {
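+    // Frame-of-reference decode: add the block's reference value back in.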
+    return value + static_cast<T>(op.params.frame_of_ref.reference);
+  }
+  case ScalarOp::ZIGZAG: {
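+    // Zigzag decode maps 0, 1, 2, 3, ... back to 0, -1, 1, -2, ...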
+    return (value >> 1) ^ static_cast<T>(-(value & 1));
+  }
+  case ScalarOp::ALP: {
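+    // ALP decode: scale the stored integer by the precomputed factor and
+    // exponent terms, then carry the reconstructed float through the integer
+    // pipeline as its raw bit pattern.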
+    float result = static_cast<float>(static_cast<int32_t>(value)) * op.params.alp.f * op.params.alp.e;
+    return static_cast<T>(__float_as_uint(result));
+  }
+  default: __builtin_unreachable();
+  }
+}

+template <typename T>
+__device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict output, uint64_t array_len,
+                                      const struct DynamicDispatchPlan *__restrict plan) {
   constexpr uint32_t ELEMENTS_PER_BLOCK = 2048;
   constexpr uint32_t VALUES_PER_LOOP = 32 / sizeof(T);

-  __shared__ DynamicOp smem_ops[MAX_DECODE_OPS];
+  __shared__ struct DynamicDispatchPlan smem_plan;
   __shared__ T smem_values[FL_CHUNK_SIZE];

-  // Cache ops in shared memory.
-  if (threadIdx.x < num_ops) {
-    smem_ops[threadIdx.x] = ops[threadIdx.x];
-  }
+  // Cache the plan in shared memory.
+  if (threadIdx.x == 0) smem_plan = *plan;
   __syncthreads();

   const uint64_t block_start = static_cast<uint64_t>(blockIdx.x) * ELEMENTS_PER_BLOCK;
   const uint64_t block_end = min(block_start + ELEMENTS_PER_BLOCK, array_len);

   for (uint64_t chunk_start = block_start; chunk_start < block_end; chunk_start += FL_CHUNK_SIZE) {
-    const uint32_t chunk_len =
-        static_cast<uint32_t>(min(static_cast<uint64_t>(FL_CHUNK_SIZE), block_end - chunk_start));
-
-    source_fill_op<T>(input, smem_values, chunk_start, chunk_len, smem_ops[0]);
+    const uint32_t chunk_len = min(FL_CHUNK_SIZE, static_cast<uint32_t>(block_end - chunk_start));
+    dynamic_source_op<T>(input, smem_values, chunk_start, chunk_len, smem_plan.source);
     __syncthreads();

     const uint32_t tile_size = blockDim.x * VALUES_PER_LOOP;
     const uint32_t num_full_tiles = chunk_len / tile_size;
-    const uint8_t scalar_op_start_idx = is_source_op(smem_ops[0].op);

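     // Full tiles: each thread stages VALUES_PER_LOOP values from shared memory
     // into registers (strided by blockDim.x), applies every scalar op in
     // registers, and then writes the results back out.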
     for (uint32_t tile = 0; tile < num_full_tiles; ++tile) {
       const uint32_t tile_base = tile * tile_size;
@@ -134,12 +116,12 @@ __device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict o
         values[idx] = smem_values[tile_base + idx * blockDim.x + threadIdx.x];
       }

-      for (uint8_t op_idx = scalar_op_start_idx; op_idx < num_ops; ++op_idx) {
-        const DynamicOp &decode_op = smem_ops[op_idx];
+      for (uint8_t op_idx = 0; op_idx < smem_plan.num_scalar_ops; ++op_idx) {
+        const struct ScalarOp &scalar_op = smem_plan.scalar_ops[op_idx];

 #pragma unroll
         for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) {
-          values[idx] = apply_scalar_op(values[idx], decode_op);
+          values[idx] = dynamic_scalar_op(values[idx], scalar_op);
         }
       }

@@ -153,8 +135,8 @@ __device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict o
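     // Remainder: elements past the last full tile are decoded one per thread.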
     const uint32_t rem_start = num_full_tiles * tile_size;
     for (uint32_t elem_idx = rem_start + threadIdx.x; elem_idx < chunk_len; elem_idx += blockDim.x) {
       T val = smem_values[elem_idx];
-      for (uint8_t op_idx = scalar_op_start_idx; op_idx < num_ops; ++op_idx) {
-        val = apply_scalar_op(val, smem_ops[op_idx]);
+      for (uint8_t op_idx = 0; op_idx < smem_plan.num_scalar_ops; ++op_idx) {
+        val = dynamic_scalar_op(val, smem_plan.scalar_ops[op_idx]);
       }
       output[chunk_start + elem_idx] = val;
     }
@@ -166,8 +148,8 @@ __device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict o
 #define GENERATE_DYNAMIC_DISPATCH_KERNEL(suffix, Type)                                              \
   extern "C" __global__ void dynamic_dispatch_##suffix(const Type *__restrict input,                \
                                                        Type *__restrict output, uint64_t array_len, \
-                                                       const DynamicOp *__restrict ops, uint8_t num_ops) { \
-    dynamic_dispatch_impl<Type>(input, output, array_len, ops, num_ops);                            \
+                                                       const struct DynamicDispatchPlan *__restrict plan) { \
+    dynamic_dispatch_impl<Type>(input, output, array_len, plan);                                    \
   }

 FOR_EACH_INTEGER(GENERATE_DYNAMIC_DISPATCH_KERNEL)
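
 // A hypothetical host-side launch, assuming the plan has already been copied
 // to device memory and that FOR_EACH_INTEGER emits a u32 variant (the names
 // below are illustrative only, not part of this file):
 //
 //   const uint32_t threads = 128;
 //   const uint32_t blocks = static_cast<uint32_t>((array_len + 2047) / 2048);  // ELEMENTS_PER_BLOCK
 //   dynamic_dispatch_u32<<<blocks, threads>>>(d_input, d_output, array_len, d_plan);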