Skip to content

Commit d1a0e6c

Browse files
claude authored and connortsui20 committed
clean up inner_product constant-query fast paths after review
Self-review refactor; no behavior change. All 259 tests still pass, the `similarity_search` bench numbers are within run-to-run noise of the previous commits, and clippy + fmt are clean.

- `try_execute_sorf_constant`: replace the `is::<ExactScalarFn<...>>()` followed by `as_opt::<...>` + `vortex_expect` dance with a single `as_opt` match. Removes the "just checked via is" expect and makes the side-detection read linearly.
- `try_execute_dict_constant`: split into a trivial outer that tries each orientation and a new `try_execute_dict_constant_oriented` that navigates the dict side + const side exactly once. Removes the `is_dict_fsl_extension` peek helper (which duplicated the later `as_opt::<Extension>() / <FixedSizeList> / <Dict>` walk) and the three `vortex_expect("peek guaranteed ...")` reassertions.
- Drop the stale "product table" / "<= 256" references in the execute orchestration comment, `case2_empty_len_zero` doc, and the inner-loop comment that pointed at a below-function explanation that no longer exists.
- Rename the nested test module from `case_1_and_2` to the more descriptive `constant_query_optimizations`.
- Simplify `dict_vector_f32` and the `case2_u16_codes_falls_through` fixture to use `Buffer::copy_from` directly instead of building a `BufferMut`, extending, then freezing. Drops the now-unused `BufferMut` import from the test module.
- Fix `build_sorf_with_dict_child`'s doc comment (claimed 5-tuple, returns 4-tuple) and remove the dead dim=128 shadow prelude in `case1_sorf_lhs_constant_rhs_padded_gt_dim` that was a leftover from initial test drafting.

Signed-off-by: Claude <noreply@anthropic.com>
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent d9a9a1f commit d1a0e6c

1 file changed

Lines changed: 69 additions & 84 deletions

File tree

vortex-tensor/src/scalar_fns/inner_product.rs

Lines changed: 69 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,9 @@ impl ScalarFnVTable for InnerProduct {
184184
return Ok(rewritten);
185185
}
186186

187-
// Reduction case 2: `InnerProduct(Vector[FSL(Dict(u8, f32<=256))], const)` is
188-
// computed by precomputing a product table and gather-summing over the codes.
187+
// Reduction case 2: `InnerProduct(Vector[FSL(Dict(u8, f32))], const)` is computed by
188+
// gather-summing `q[j] * values[codes[j] as usize]` per row, reading the codebook
189+
// directly instead of decoding the column into dense vectors.
189190
if let Some(result) = self.try_execute_dict_constant(&lhs_ref, &rhs_ref, len, ctx)? {
190191
return Ok(result);
191192
}
@@ -330,17 +331,14 @@ impl InnerProduct {
330331
ctx: &mut ExecutionCtx,
331332
) -> VortexResult<Option<ArrayRef>> {
332333
// Identify which side is the SorfTransform, if any.
333-
let (sorf_ref, const_ref) = if lhs_ref.is::<ExactScalarFn<SorfTransform>>() {
334-
(lhs_ref, rhs_ref)
335-
} else if rhs_ref.is::<ExactScalarFn<SorfTransform>>() {
336-
(rhs_ref, lhs_ref)
337-
} else {
338-
return Ok(None);
339-
};
340-
341-
let sorf_view = sorf_ref
342-
.as_opt::<ExactScalarFn<SorfTransform>>()
343-
.vortex_expect("just checked via `is`");
334+
let (sorf_view, const_ref) =
335+
if let Some(view) = lhs_ref.as_opt::<ExactScalarFn<SorfTransform>>() {
336+
(view, rhs_ref)
337+
} else if let Some(view) = rhs_ref.as_opt::<ExactScalarFn<SorfTransform>>() {
338+
(view, lhs_ref)
339+
} else {
340+
return Ok(None);
341+
};
344342

345343
// TODO(connor): pull-through is only sound for F32 because SorfTransform applies an
346344
// `f32 -> element_ptype` cast at the end of its execute. For F16/F64 the rewrite
@@ -443,17 +441,38 @@ impl InnerProduct {
443441
len: usize,
444442
ctx: &mut ExecutionCtx,
445443
) -> VortexResult<Option<ArrayRef>> {
446-
// Identify which side is `Extension[FSL[Dict]]`, with a non-canonicalizing peek.
447-
let (dict_ref, const_ref) = if is_dict_fsl_extension(lhs_ref) {
448-
(lhs_ref, rhs_ref)
449-
} else if is_dict_fsl_extension(rhs_ref) {
450-
(rhs_ref, lhs_ref)
451-
} else {
444+
// Try each orientation. The oriented helper navigates each side exactly once, so
445+
// the only redundant work here is the failed navigation of the first side when the
446+
// dict happens to be on the right.
447+
if let Some(result) = self.try_execute_dict_constant_oriented(lhs_ref, rhs_ref, len, ctx)? {
448+
return Ok(Some(result));
449+
}
450+
self.try_execute_dict_constant_oriented(rhs_ref, lhs_ref, len, ctx)
451+
}
452+
453+
/// Orientation-specific helper for [`Self::try_execute_dict_constant`]. `dict_candidate`
454+
/// is tried as `Extension[FSL[Dict]]`; `const_candidate` is tried as a constant-backed
455+
/// tensor extension. Returns `Ok(None)` if either navigation fails or any gate rejects.
456+
fn try_execute_dict_constant_oriented(
457+
&self,
458+
dict_candidate: &ArrayRef,
459+
const_candidate: &ArrayRef,
460+
len: usize,
461+
ctx: &mut ExecutionCtx,
462+
) -> VortexResult<Option<ArrayRef>> {
463+
// Navigate the dict side.
464+
let Some(dict_ext) = dict_candidate.as_opt::<Extension>() else {
465+
return Ok(None);
466+
};
467+
let Some(fsl) = dict_ext.storage_array().as_opt::<FixedSizeList>() else {
468+
return Ok(None);
469+
};
470+
let Some(dict) = fsl.elements().as_opt::<Dict>() else {
452471
return Ok(None);
453472
};
454473

455-
// Detect the constant-backed extension on the other side.
456-
let Some(const_ext) = const_ref.as_opt::<Extension>() else {
474+
// Navigate the constant side and require its scalar be non-null.
475+
let Some(const_ext) = const_candidate.as_opt::<Extension>() else {
457476
return Ok(None);
458477
};
459478
let const_storage = const_ext.storage_array();
@@ -464,21 +483,8 @@ impl InnerProduct {
464483
return Ok(None);
465484
}
466485

467-
// Navigate into the dict side. We already verified the shape via the peek.
468-
let dict_ext = dict_ref
469-
.as_opt::<Extension>()
470-
.vortex_expect("peek guaranteed Extension");
471-
let fsl = dict_ext
472-
.storage_array()
473-
.as_opt::<FixedSizeList>()
474-
.vortex_expect("peek guaranteed FixedSizeList");
475-
let dict = fsl
476-
.elements()
477-
.as_opt::<Dict>()
478-
.vortex_expect("peek guaranteed Dict");
479-
480486
// Canonicalize codes and values. Codes may be e.g. BitPacked; executing is cheaper
481-
// than falling through to the naive path (which would also canonicalize).
487+
// than falling through to the standard path (which would also canonicalize).
482488
let codes_prim: PrimitiveArray = dict.codes().clone().execute(ctx)?;
483489
let values_prim: PrimitiveArray = dict.values().clone().execute(ctx)?;
484490

@@ -505,9 +511,11 @@ impl InnerProduct {
505511

506512
// Combine the input validities up front; the per-row arithmetic may write garbage
507513
// into null rows but the validity mask hides it (matching the standard path).
508-
let validity = dict_ref.validity()?.and(const_ref.validity()?)?;
514+
let validity = dict_candidate
515+
.validity()?
516+
.and(const_candidate.validity()?)?;
509517

510-
// Fast path for the empty case: skip the product-table allocation entirely.
518+
// Fast path for the empty case: skip allocating and touching the codes buffer.
511519
if len == 0 {
512520
let empty = PrimitiveArray::empty::<f32>(validity.nullability());
513521
return Ok(Some(empty.into_array()));
@@ -519,8 +527,8 @@ impl InnerProduct {
519527
let values: &[f32] = values_prim.as_slice::<f32>();
520528
debug_assert_eq!(codes.len(), len * padded_dim);
521529

522-
// Direct codebook lookup: `acc += q[j] * values[codes[j]]`. See the benchmark
523-
// comment below for why this beats an explicit product table here.
530+
// Direct codebook lookup in the hot loop. See the function doc comment for why this
531+
// beats an explicit product table here.
524532
let mut out = BufferMut::<f32>::with_capacity(len);
525533
for row in 0..len {
526534
let row_codes = &codes[row * padded_dim..(row + 1) * padded_dim];
@@ -538,19 +546,6 @@ impl InnerProduct {
538546
}
539547
}
540548

541-
/// Non-canonicalizing shape peek: returns `true` iff `arr` is an `Extension` whose storage
542-
/// is a `FixedSizeList` whose elements are a `Dict`. Used to pick the dict side of an
543-
/// `InnerProduct` without doing any execution work.
544-
fn is_dict_fsl_extension(arr: &ArrayRef) -> bool {
545-
let Some(ext) = arr.as_opt::<Extension>() else {
546-
return false;
547-
};
548-
let Some(fsl) = ext.storage_array().as_opt::<FixedSizeList>() else {
549-
return false;
550-
};
551-
fsl.elements().as_opt::<Dict>().is_some()
552-
}
553-
554549
/// Computes the inner product (dot product) of two equal-length float slices.
555550
///
556551
/// Returns `sum(a_i * b_i)`.
@@ -781,13 +776,13 @@ mod tests {
781776
Ok(())
782777
}
783778

784-
// ---- Case 1 (SorfTransform + Constant) and Case 2 (Dict + Constant) tests ----
779+
// ---- Tests for the `SorfTransform + constant` and `Dict + constant` fast paths ----
785780

786781
#[allow(
787782
clippy::cast_possible_truncation,
788783
reason = "tests build small fixtures with deterministic in-range indices"
789784
)]
790-
mod case_1_and_2 {
785+
mod constant_query_optimizations {
791786
use std::sync::LazyLock;
792787

793788
use rstest::rstest;
@@ -809,7 +804,6 @@ mod tests {
809804
use vortex_array::session::ArraySession;
810805
use vortex_array::validity::Validity;
811806
use vortex_buffer::Buffer;
812-
use vortex_buffer::BufferMut;
813807
use vortex_error::VortexResult;
814808
use vortex_session::VortexSession;
815809

@@ -852,16 +846,12 @@ mod tests {
852846
/// TurboQuant produces as the SorfTransform child.
853847
fn dict_vector_f32(list_size: u32, codes: &[u8], values: &[f32]) -> VortexResult<ArrayRef> {
854848
let num_rows = codes.len() / list_size as usize;
855-
let codes_arr = {
856-
let mut buf = BufferMut::<u8>::with_capacity(codes.len());
857-
buf.extend_from_slice(codes);
858-
PrimitiveArray::new::<u8>(buf.freeze(), Validity::NonNullable).into_array()
859-
};
860-
let values_arr = {
861-
let mut buf = BufferMut::<f32>::with_capacity(values.len());
862-
buf.extend_from_slice(values);
863-
PrimitiveArray::new::<f32>(buf.freeze(), Validity::NonNullable).into_array()
864-
};
849+
let codes_arr =
850+
PrimitiveArray::new::<u8>(Buffer::copy_from(codes), Validity::NonNullable)
851+
.into_array();
852+
let values_arr =
853+
PrimitiveArray::new::<f32>(Buffer::copy_from(values), Validity::NonNullable)
854+
.into_array();
865855
let dict = DictArray::try_new(codes_arr, values_arr)?;
866856
let fsl = FixedSizeListArray::try_new(
867857
dict.into_array(),
@@ -896,7 +886,7 @@ mod tests {
896886

897887
/// Build a SorfTransform ScalarFnArray whose child is a `Vector<padded_dim, f32>`
898888
/// wrapping `FSL(Dict(codes, values))`. Returns `(sorf_array, codes, values,
899-
/// padded_dim, dim)`.
889+
/// padded_dim)`.
900890
fn build_sorf_with_dict_child(
901891
dim: u32,
902892
num_rows: usize,
@@ -955,19 +945,16 @@ mod tests {
955945

956946
// ---- Case 1: SorfTransform + Constant pull-through ----
957947

958-
/// Case 1: SorfTransform on LHS, constant query on RHS, `dim < padded_dim`.
948+
/// Case 1: SorfTransform on LHS, constant query on RHS, with `dim < padded_dim`
949+
/// so the zero-padding branch is exercised.
959950
#[test]
960951
fn case1_sorf_lhs_constant_rhs_padded_gt_dim() -> VortexResult<()> {
961-
let dim: u32 = 128;
962-
let padded_dim = (dim as usize).next_power_of_two();
963-
assert_eq!(padded_dim, 128); // degenerate case for this dim, handled by next test
964-
// Bump dim to exercise padding explicitly.
965952
let dim: u32 = 100;
966-
let padded_dim = (dim as usize).next_power_of_two();
967-
assert!(padded_dim > dim as usize);
968953
let num_rows = 7usize;
969954
let seed = 42u64;
970955
let num_rounds = 3u8;
956+
let padded_dim = (dim as usize).next_power_of_two();
957+
assert!(padded_dim > dim as usize, "test must exercise padding");
971958

972959
let (sorf_lhs, codes, values, padded_dim_computed) =
973960
build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;
@@ -1095,7 +1082,7 @@ mod tests {
10951082
Ok(())
10961083
}
10971084

1098-
// ---- Case 2: Dict + Constant product-table path ----
1085+
// ---- Case 2: Dict + Constant direct-lookup path ----
10991086

11001087
/// Case 2: Vector[FSL[Dict(u8, f32)]] on LHS, constant query on RHS.
11011088
#[test]
@@ -1169,18 +1156,16 @@ mod tests {
11691156
let num_values = 300usize;
11701157
let values: Vec<f32> = (0..num_values).map(|i| i as f32 * 0.01).collect();
11711158
// Codes must be u16 because 300 > 255. dict_vector_f32 only supports u8 so we
1172-
// need a u16-coded dict here.
1173-
let mut codes_u16_buf = BufferMut::<u16>::with_capacity(num_rows * 4);
1174-
for i in 0..(num_rows * 4) {
1175-
codes_u16_buf.push((i % num_values) as u16);
1176-
}
1159+
// build the dict by hand here.
1160+
let codes_u16: Vec<u16> = (0..(num_rows * 4))
1161+
.map(|i| (i % num_values) as u16)
1162+
.collect();
11771163
let codes_arr =
1178-
PrimitiveArray::new::<u16>(codes_u16_buf.freeze(), Validity::NonNullable)
1164+
PrimitiveArray::new::<u16>(Buffer::copy_from(codes_u16), Validity::NonNullable)
11791165
.into_array();
1180-
let mut values_buf = BufferMut::<f32>::with_capacity(values.len());
1181-
values_buf.extend_from_slice(&values);
11821166
let values_arr =
1183-
PrimitiveArray::new::<f32>(values_buf.freeze(), Validity::NonNullable).into_array();
1167+
PrimitiveArray::new::<f32>(Buffer::copy_from(&values), Validity::NonNullable)
1168+
.into_array();
11841169
let dict = DictArray::try_new(codes_arr, values_arr)?;
11851170
let fsl = FixedSizeListArray::try_new(
11861171
dict.into_array(),
@@ -1241,7 +1226,7 @@ mod tests {
12411226
}
12421227

12431228
/// Case 2: empty `len == 0` fast path returns an empty primitive array without
1244-
/// allocating a product table.
1229+
/// touching the codes buffer.
12451230
#[test]
12461231
fn case2_empty_len_zero() -> VortexResult<()> {
12471232
let list_size: u32 = 4;

0 commit comments

Comments
 (0)