Skip to content

Commit 1a28a3c

Browse files
committed
feat: make dispatch plan dynamically sized
Replace the fixed-size DynamicDispatchPlan struct with a variable-length packed byte buffer, removing the MAX_STAGES and MAX_SCALAR_OPS limits. Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent fa6931a commit 1a28a3c

File tree

7 files changed

+317
-187
lines changed

7 files changed

+317
-187
lines changed

vortex-cuda/benches/dynamic_dispatch_cuda.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ fn run_timed(
5959
cuda_ctx: &mut CudaExecutionCtx,
6060
array_len: usize,
6161
output_buf: &CudaDeviceBuffer,
62-
device_plan: &Arc<cudarc::driver::CudaSlice<CudaDispatchPlan>>,
62+
device_plan: &Arc<cudarc::driver::CudaSlice<u8>>,
6363
shared_mem_bytes: u32,
6464
) -> VortexResult<Duration> {
6565
let cuda_function = cuda_ctx.load_function("dynamic_dispatch", &[PType::U32])?;
@@ -115,8 +115,7 @@ struct BenchRunner {
115115
_plan: CudaDispatchPlan,
116116
smem_bytes: u32,
117117
len: usize,
118-
// Keep alive
119-
device_plan: Arc<cudarc::driver::CudaSlice<CudaDispatchPlan>>,
118+
device_plan: Arc<cudarc::driver::CudaSlice<u8>>,
120119
output_buf: CudaDeviceBuffer,
121120
_plan_buffers: Vec<vortex::array::buffer::BufferHandle>,
122121
}
@@ -134,7 +133,7 @@ impl BenchRunner {
134133
let device_plan = Arc::new(
135134
cuda_ctx
136135
.stream()
137-
.clone_htod(std::slice::from_ref(&dispatch_plan))
136+
.clone_htod(dispatch_plan.as_bytes())
138137
.expect("htod plan"),
139138
);
140139

vortex-cuda/build.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,6 @@ fn nvcc_compile_ptx(
183183
}
184184

185185
/// Generate bindings for the dynamic dispatch shared header.
186-
///
187-
/// `DynamicDispatchPlan` and related types are shared between CUDA kernels
188-
/// and Rust host code.
189186
fn generate_dynamic_dispatch_bindings(kernels_src: &Path, out_dir: &Path) {
190187
let header = kernels_src.join("dynamic_dispatch.h");
191188
println!("cargo:rerun-if-changed={}", header.display());

vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ __device__ void apply_scalar_ops(const T *__restrict smem_input,
262262
/// final results to `write_dest` at `write_offset`. Input stages write
263263
/// back to smem; the output stage writes to global memory.
264264
template <typename T, StorePolicy S>
265-
__device__ void execute_stage(const struct Stage &stage,
265+
__device__ void execute_stage(const Stage &stage,
266266
T *__restrict smem_base,
267267
uint64_t chunk_start,
268268
uint32_t chunk_len,
@@ -293,7 +293,7 @@ __device__ void execute_stage(const struct Stage &stage,
293293
/// Each tile decodes exactly one FL block == SMEM_TILE_SIZE elements into
294294
/// shared memory. In case BITUNPACK is sliced, we need to account for the
295295
/// sub-byte element offset.
296-
__device__ inline uint32_t output_tile_len(const struct Stage &stage, uint32_t block_len, uint32_t tile_off) {
296+
__device__ inline uint32_t output_tile_len(const Stage &stage, uint32_t block_len, uint32_t tile_off) {
297297
const uint32_t element_offset = (tile_off == 0 && stage.source.op_code == SourceOp::BITUNPACK)
298298
? stage.source.params.bitunpack.element_offset
299299
: 0;
@@ -302,42 +302,28 @@ __device__ inline uint32_t output_tile_len(const struct Stage &stage, uint32_t b
302302

303303
/// Entry point of the dynamic dispatch kernel.
304304
///
305-
/// Executes the plan's stages in order:
306-
/// 1. Input stages populate shared memory with intermediate data
307-
/// for the output stage to reference.
308-
/// 2. The output stage decodes the root array and writes directly to
309-
/// global memory.
310-
///
311-
/// @param output Global memory output buffer
312-
/// @param array_len Total number of elements to produce
313-
/// @param plan Device pointer to the dispatch plan
305+
/// @param output Output buffer
306+
/// @param array_len Total number of elements to produce
307+
/// @param packed_plan Pointer to the packed plan byte buffer
314308
template <typename T>
315-
__device__ void dynamic_dispatch(T *__restrict output,
316-
uint64_t array_len,
317-
const struct DynamicDispatchPlan *__restrict plan) {
309+
__device__ void
310+
dynamic_dispatch(T *__restrict output, uint64_t array_len, const uint8_t *__restrict packed_plan) {
318311

319-
// Dynamically-sized shared memory: The host computes the exact byte count
320-
// needed to hold all stage outputs that must coexist simultaneously, and
321-
// passes the count at kernel launch (see DynamicDispatchPlan::shared_mem_bytes).
322312
extern __shared__ char smem_bytes[];
323313
T *smem_base = reinterpret_cast<T *>(smem_bytes);
324314

325-
__shared__ struct DynamicDispatchPlan smem_plan;
326-
if (threadIdx.x == 0) {
327-
smem_plan = *plan;
328-
}
329-
__syncthreads();
330-
331-
const uint8_t last = smem_plan.num_stages - 1;
315+
const auto *header = reinterpret_cast<const struct PlanHeader *>(packed_plan);
316+
const uint8_t *stage_cursor = packed_plan + sizeof(struct PlanHeader);
317+
const uint8_t last = header->num_stages - 1;
332318

333319
// Input stages: Decode inputs into smem regions.
334-
for (uint8_t i = 0; i < last; ++i) {
335-
const struct Stage &stage = smem_plan.stages[i];
320+
for (uint8_t idx = 0; idx < last; ++idx) {
321+
Stage stage = parse_stage(stage_cursor);
336322
T *smem_output = &smem_base[stage.smem_offset];
337323
execute_stage<T, StorePolicy::WRITEBACK>(stage, smem_base, 0, stage.len, smem_output, 0);
338324
}
339325

340-
const struct Stage &output_stage = smem_plan.stages[last];
326+
Stage output_stage = parse_stage(stage_cursor);
341327
const uint64_t block_start = static_cast<uint64_t>(blockIdx.x) * ELEMENTS_PER_BLOCK;
342328
const uint64_t block_end = min(block_start + ELEMENTS_PER_BLOCK, array_len);
343329
const uint32_t block_len = static_cast<uint32_t>(block_end - block_start);
@@ -356,11 +342,10 @@ __device__ void dynamic_dispatch(T *__restrict output,
356342

357343
/// Generates a dynamic dispatch kernel entry point for each unsigned integer type.
358344
#define GENERATE_DYNAMIC_DISPATCH_KERNEL(suffix, Type) \
359-
extern "C" __global__ void dynamic_dispatch_##suffix( \
360-
Type *__restrict output, \
361-
uint64_t array_len, \
362-
const struct DynamicDispatchPlan *__restrict plan) { \
363-
dynamic_dispatch<Type>(output, array_len, plan); \
345+
extern "C" __global__ void dynamic_dispatch_##suffix(Type *__restrict output, \
346+
uint64_t array_len, \
347+
const uint8_t *__restrict packed_plan) { \
348+
dynamic_dispatch<Type>(output, array_len, packed_plan); \
364349
}
365350

366351
FOR_EACH_UNSIGNED_INT(GENERATE_DYNAMIC_DISPATCH_KERNEL)

vortex-cuda/kernels/src/dynamic_dispatch.h

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,15 @@
66
/// The plan builder walks an encoding tree and emits a linear sequence of
77
/// stages. The kernel executes stages in order within a single launch.
88
///
9-
/// Shared memory: The plan builder bump-allocates shared memory regions for
10-
/// each input stage's output. The output stage (last) is placed after all
11-
/// input stages. Since all regions must coexist for the output stage to
12-
/// reference, the total shared memory is the end of whichever region extends
13-
/// furthest, in elements, times `sizeof(T)`.
9+
/// ## Stage plan
1410
///
15-
/// Example: RunEnd(ends=FoR(BitPacked), values=FoR(BitPacked)) with 100 runs
11+
/// The plan is packed as a variable-length byte buffer.
1612
///
17-
/// Stage 0 (input): BITUNPACK(7) → FoR(0) → smem[0..100) // run ends
18-
/// Stage 1 (input): BITUNPACK(10) → FoR(50) → smem[100..200) // run values
19-
/// Stage 2 (output): RUNEND(ends=0, values=100) → smem[200..1224) // resolved
20-
///
21-
/// shared_mem_bytes = (200 + 1024) * sizeof(T)
13+
/// Layout (contiguous bytes):
14+
/// [PlanHeader]
15+
/// [PackedStage 0][ScalarOp × N0]
16+
/// [PackedStage 1][ScalarOp × N1]
17+
/// ...
2218

2319
#pragma once
2420

@@ -27,12 +23,7 @@
2723
/// Elements processed per CUDA block.
2824
#define ELEMENTS_PER_BLOCK 2048
2925

30-
/// Shared memory tile size for the output stage. Each block decompresses
31-
/// ELEMENTS_PER_BLOCK elements but only holds SMEM_TILE_SIZE in smem at a
32-
/// time — each tile is written to global memory before the next is decoded
33-
/// into the same region. Input stages cannot tile because their outputs must
34-
/// remain accessible for random access (e.g., dictionary lookup, run-end
35-
/// binary search). Smaller tiles reduce smem per block, improving occupancy.
26+
/// Each tile is flushed to global before the next is decoded.
3627
#define SMEM_TILE_SIZE 1024
3728

3829
#ifdef __cplusplus
@@ -41,14 +32,13 @@ extern "C" {
4132

4233
/// Parameters for source ops, which decode data into a stage's shared memory region.
4334
union SourceParams {
44-
/// Unpack bit-packed data using FastLanes layout.
35+
/// Unpack FastLanes bit-packed data.
4536
struct BitunpackParams {
4637
uint8_t bit_width;
4738
uint32_t element_offset; // Sub-byte offset
4839
} bitunpack;
4940

50-
/// Copy elements verbatim from global memory to shared memory.
51-
/// The input pointer is pre-adjusted on the host to account for slicing.
41+
/// Copy from global to shared memory.
5242
struct LoadParams {
5343
uint8_t _placeholder;
5444
} load;
@@ -58,7 +48,7 @@ union SourceParams {
5848
uint32_t ends_smem_offset; // element offset to decoded ends in smem
5949
uint32_t values_smem_offset; // element offset to decoded values in smem
6050
uint64_t num_runs;
61-
uint64_t offset;
51+
uint64_t offset; // slice offset into the run-end encoded array
6252
} runend;
6353

6454
/// Generate a linear sequence: `value[i] = base + i * multiplier`.
@@ -96,38 +86,62 @@ struct ScalarOp {
9686
union ScalarParams params;
9787
};
9888

99-
#define MAX_SCALAR_OPS 4
100-
101-
/// A single stage in the dispatch plan.
102-
///
103-
/// Each stage is a pipeline (source + scalar ops) that writes decoded data
104-
/// into a shared memory region at `smem_offset`. Input stage outputs persist
105-
/// in smem so the output stage can reference them (via DICT or RUNEND offsets).
106-
struct Stage {
89+
/// Packed stage header, followed by `num_scalar_ops` inline ScalarOps.
90+
struct PackedStage {
10791
uint64_t input_ptr; // global memory pointer to this stage's encoded input
10892
uint32_t smem_offset; // element offset within dynamic shared memory for output
10993
uint32_t len; // number of elements this stage produces
11094

11195
struct SourceOp source;
11296
uint8_t num_scalar_ops;
113-
struct ScalarOp scalar_ops[MAX_SCALAR_OPS];
11497
};
11598

116-
#define MAX_STAGES 4
117-
118-
/// Dispatch plan: a sequence of stages.
119-
///
120-
/// The plan builder walks the encoding tree recursively, emitting an input
121-
/// stage each time it encounters a child array that needs to live in shared
122-
/// memory (e.g., dictionary values, run-end endpoints). Shared memory
123-
/// offsets are assigned with a simple bump allocator.
124-
///
125-
/// The last stage is the output pipeline which directly writes to global memory.
126-
struct DynamicDispatchPlan {
99+
/// Header for the packed plan byte buffer.
100+
struct __attribute__((aligned(8))) PlanHeader {
127101
uint8_t num_stages;
128-
struct Stage stages[MAX_STAGES];
102+
uint16_t plan_size_bytes; // total size of the packed plan including this header
129103
};
130104

131105
#ifdef __cplusplus
132106
}
133107
#endif
108+
109+
#ifdef __cplusplus
110+
111+
/// Stage parsed from the packed plan byte buffer.
///
/// Input stages decode data (e.g. dict values, run-end endpoints) into a
/// shared memory region for the output stage to reference. The output stage
/// decodes the root encoding and writes to global memory.
struct Stage {
    uint64_t input_ptr;                // encoded input in global memory
    uint32_t smem_offset;              // output offset in shared memory (elements)
    uint32_t len;                      // elements produced
    struct SourceOp source;            // source decode op
    uint8_t num_scalar_ops;            // number of scalar ops
    const struct ScalarOp *scalar_ops; // scalar decode ops, pointing into the packed plan buffer
};
124+
125+
/// Parse a single stage from the packed plan byte buffer and advance the cursor.
///
/// @param cursor Pointer into the packed plan buffer, pointing at a PackedStage.
///               On return, advanced past this stage's ScalarOps.
/// @return A Stage referencing data within the packed plan buffer.
__device__ inline Stage parse_stage(const uint8_t *&cursor) {
    // NOTE(review): assumes each PackedStage / ScalarOp record starts at an
    // address suitably aligned for its uint64_t members — confirm the host
    // packer pads records accordingly before relying on these casts.
    const auto *header = reinterpret_cast<const struct PackedStage *>(cursor);
    const auto *op_list =
        reinterpret_cast<const struct ScalarOp *>(cursor + sizeof(struct PackedStage));

    // Advance past this stage's header and its trailing inline ScalarOps.
    cursor += sizeof(struct PackedStage) +
              header->num_scalar_ops * sizeof(struct ScalarOp);

    Stage stage;
    stage.input_ptr = header->input_ptr;
    stage.smem_offset = header->smem_offset;
    stage.len = header->len;
    stage.source = header->source;
    stage.num_scalar_ops = header->num_scalar_ops;
    stage.scalar_ops = op_list;
    return stage;
}
146+
147+
#endif

0 commit comments

Comments (0)