@@ -184,8 +184,9 @@ impl ScalarFnVTable for InnerProduct {
184184 return Ok ( rewritten) ;
185185 }
186186
187- // Reduction case 2: `InnerProduct(Vector[FSL(Dict(u8, f32<=256))], const)` is
188- // computed by precomputing a product table and gather-summing over the codes.
187+ // Reduction case 2: `InnerProduct(Vector[FSL(Dict(u8, f32))], const)` is computed by
188+ // gather-summing `q[j] * values[codes[j] as usize]` per row, reading the codebook
189+ // directly instead of decoding the column into dense vectors.
189190 if let Some ( result) = self . try_execute_dict_constant ( & lhs_ref, & rhs_ref, len, ctx) ? {
190191 return Ok ( result) ;
191192 }
@@ -330,17 +331,14 @@ impl InnerProduct {
330331 ctx : & mut ExecutionCtx ,
331332 ) -> VortexResult < Option < ArrayRef > > {
332333 // Identify which side is the SorfTransform, if any.
333- let ( sorf_ref, const_ref) = if lhs_ref. is :: < ExactScalarFn < SorfTransform > > ( ) {
334- ( lhs_ref, rhs_ref)
335- } else if rhs_ref. is :: < ExactScalarFn < SorfTransform > > ( ) {
336- ( rhs_ref, lhs_ref)
337- } else {
338- return Ok ( None ) ;
339- } ;
340-
341- let sorf_view = sorf_ref
342- . as_opt :: < ExactScalarFn < SorfTransform > > ( )
343- . vortex_expect ( "just checked via `is`" ) ;
334+ let ( sorf_view, const_ref) =
335+ if let Some ( view) = lhs_ref. as_opt :: < ExactScalarFn < SorfTransform > > ( ) {
336+ ( view, rhs_ref)
337+ } else if let Some ( view) = rhs_ref. as_opt :: < ExactScalarFn < SorfTransform > > ( ) {
338+ ( view, lhs_ref)
339+ } else {
340+ return Ok ( None ) ;
341+ } ;
344342
345343 // TODO(connor): pull-through is only sound for F32 because SorfTransform applies an
346344 // `f32 -> element_ptype` cast at the end of its execute. For F16/F64 the rewrite
@@ -443,17 +441,38 @@ impl InnerProduct {
443441 len : usize ,
444442 ctx : & mut ExecutionCtx ,
445443 ) -> VortexResult < Option < ArrayRef > > {
446- // Identify which side is `Extension[FSL[Dict]]`, with a non-canonicalizing peek.
447- let ( dict_ref, const_ref) = if is_dict_fsl_extension ( lhs_ref) {
448- ( lhs_ref, rhs_ref)
449- } else if is_dict_fsl_extension ( rhs_ref) {
450- ( rhs_ref, lhs_ref)
451- } else {
444+ // Try each orientation. The oriented helper navigates each side exactly once, so
445+ // the only redundant work here is the failed navigation of the first side when the
446+ // dict happens to be on the right.
447+ if let Some ( result) = self . try_execute_dict_constant_oriented ( lhs_ref, rhs_ref, len, ctx) ? {
448+ return Ok ( Some ( result) ) ;
449+ }
450+ self . try_execute_dict_constant_oriented ( rhs_ref, lhs_ref, len, ctx)
451+ }
452+
453+ /// Orientation-specific helper for [`Self::try_execute_dict_constant`]. `dict_candidate`
454+ /// is tried as `Extension[FSL[Dict]]`; `const_candidate` is tried as a constant-backed
455+ /// tensor extension. Returns `Ok(None)` if either navigation fails or any gate rejects.
456+ fn try_execute_dict_constant_oriented (
457+ & self ,
458+ dict_candidate : & ArrayRef ,
459+ const_candidate : & ArrayRef ,
460+ len : usize ,
461+ ctx : & mut ExecutionCtx ,
462+ ) -> VortexResult < Option < ArrayRef > > {
463+ // Navigate the dict side.
464+ let Some ( dict_ext) = dict_candidate. as_opt :: < Extension > ( ) else {
465+ return Ok ( None ) ;
466+ } ;
467+ let Some ( fsl) = dict_ext. storage_array ( ) . as_opt :: < FixedSizeList > ( ) else {
468+ return Ok ( None ) ;
469+ } ;
470+ let Some ( dict) = fsl. elements ( ) . as_opt :: < Dict > ( ) else {
452471 return Ok ( None ) ;
453472 } ;
454473
455- // Detect the constant-backed extension on the other side .
456- let Some ( const_ext) = const_ref . as_opt :: < Extension > ( ) else {
474+ // Navigate the constant side and require its scalar be non-null .
475+ let Some ( const_ext) = const_candidate . as_opt :: < Extension > ( ) else {
457476 return Ok ( None ) ;
458477 } ;
459478 let const_storage = const_ext. storage_array ( ) ;
@@ -464,21 +483,8 @@ impl InnerProduct {
464483 return Ok ( None ) ;
465484 }
466485
467- // Navigate into the dict side. We already verified the shape via the peek.
468- let dict_ext = dict_ref
469- . as_opt :: < Extension > ( )
470- . vortex_expect ( "peek guaranteed Extension" ) ;
471- let fsl = dict_ext
472- . storage_array ( )
473- . as_opt :: < FixedSizeList > ( )
474- . vortex_expect ( "peek guaranteed FixedSizeList" ) ;
475- let dict = fsl
476- . elements ( )
477- . as_opt :: < Dict > ( )
478- . vortex_expect ( "peek guaranteed Dict" ) ;
479-
480486 // Canonicalize codes and values. Codes may be e.g. BitPacked; executing is cheaper
481- // than falling through to the naive path (which would also canonicalize).
487+ // than falling through to the standard path (which would also canonicalize).
482488 let codes_prim: PrimitiveArray = dict. codes ( ) . clone ( ) . execute ( ctx) ?;
483489 let values_prim: PrimitiveArray = dict. values ( ) . clone ( ) . execute ( ctx) ?;
484490
@@ -505,9 +511,11 @@ impl InnerProduct {
505511
506512 // Combine the input validities up front; the per-row arithmetic may write garbage
507513 // into null rows but the validity mask hides it (matching the standard path).
508- let validity = dict_ref. validity ( ) ?. and ( const_ref. validity ( ) ?) ?;
514+ let validity = dict_candidate
515+ . validity ( ) ?
516+ . and ( const_candidate. validity ( ) ?) ?;
509517
510- // Fast path for the empty case: skip the product-table allocation entirely .
518+ // Fast path for the empty case: skip allocating and touching the codes buffer .
511519 if len == 0 {
512520 let empty = PrimitiveArray :: empty :: < f32 > ( validity. nullability ( ) ) ;
513521 return Ok ( Some ( empty. into_array ( ) ) ) ;
@@ -519,8 +527,8 @@ impl InnerProduct {
519527 let values: & [ f32 ] = values_prim. as_slice :: < f32 > ( ) ;
520528 debug_assert_eq ! ( codes. len( ) , len * padded_dim) ;
521529
522- // Direct codebook lookup: `acc += q[j] * values[codes[j]]` . See the benchmark
523- // comment below for why this beats an explicit product table here.
530+ // Direct codebook lookup in the hot loop . See the function doc comment for why this
531+ // beats an explicit product table here.
524532 let mut out = BufferMut :: < f32 > :: with_capacity ( len) ;
525533 for row in 0 ..len {
526534 let row_codes = & codes[ row * padded_dim..( row + 1 ) * padded_dim] ;
@@ -538,19 +546,6 @@ impl InnerProduct {
538546 }
539547}
540548
541- /// Non-canonicalizing shape peek: returns `true` iff `arr` is an `Extension` whose storage
542- /// is a `FixedSizeList` whose elements are a `Dict`. Used to pick the dict side of an
543- /// `InnerProduct` without doing any execution work.
544- fn is_dict_fsl_extension ( arr : & ArrayRef ) -> bool {
545- let Some ( ext) = arr. as_opt :: < Extension > ( ) else {
546- return false ;
547- } ;
548- let Some ( fsl) = ext. storage_array ( ) . as_opt :: < FixedSizeList > ( ) else {
549- return false ;
550- } ;
551- fsl. elements ( ) . as_opt :: < Dict > ( ) . is_some ( )
552- }
553-
554549/// Computes the inner product (dot product) of two equal-length float slices.
555550///
556551/// Returns `sum(a_i * b_i)`.
@@ -781,13 +776,13 @@ mod tests {
781776 Ok ( ( ) )
782777 }
783778
784- // ---- Case 1 ( SorfTransform + Constant) and Case 2 ( Dict + Constant) tests ----
779+ // ---- Tests for the ` SorfTransform + constant` and ` Dict + constant` fast paths ----
785780
786781 #[ allow(
787782 clippy:: cast_possible_truncation,
788783 reason = "tests build small fixtures with deterministic in-range indices"
789784 ) ]
790- mod case_1_and_2 {
785+ mod constant_query_optimizations {
791786 use std:: sync:: LazyLock ;
792787
793788 use rstest:: rstest;
@@ -809,7 +804,6 @@ mod tests {
809804 use vortex_array:: session:: ArraySession ;
810805 use vortex_array:: validity:: Validity ;
811806 use vortex_buffer:: Buffer ;
812- use vortex_buffer:: BufferMut ;
813807 use vortex_error:: VortexResult ;
814808 use vortex_session:: VortexSession ;
815809
@@ -852,16 +846,12 @@ mod tests {
852846 /// TurboQuant produces as the SorfTransform child.
853847 fn dict_vector_f32 ( list_size : u32 , codes : & [ u8 ] , values : & [ f32 ] ) -> VortexResult < ArrayRef > {
854848 let num_rows = codes. len ( ) / list_size as usize ;
855- let codes_arr = {
856- let mut buf = BufferMut :: < u8 > :: with_capacity ( codes. len ( ) ) ;
857- buf. extend_from_slice ( codes) ;
858- PrimitiveArray :: new :: < u8 > ( buf. freeze ( ) , Validity :: NonNullable ) . into_array ( )
859- } ;
860- let values_arr = {
861- let mut buf = BufferMut :: < f32 > :: with_capacity ( values. len ( ) ) ;
862- buf. extend_from_slice ( values) ;
863- PrimitiveArray :: new :: < f32 > ( buf. freeze ( ) , Validity :: NonNullable ) . into_array ( )
864- } ;
849+ let codes_arr =
850+ PrimitiveArray :: new :: < u8 > ( Buffer :: copy_from ( codes) , Validity :: NonNullable )
851+ . into_array ( ) ;
852+ let values_arr =
853+ PrimitiveArray :: new :: < f32 > ( Buffer :: copy_from ( values) , Validity :: NonNullable )
854+ . into_array ( ) ;
865855 let dict = DictArray :: try_new ( codes_arr, values_arr) ?;
866856 let fsl = FixedSizeListArray :: try_new (
867857 dict. into_array ( ) ,
@@ -896,7 +886,7 @@ mod tests {
896886
897887 /// Build a SorfTransform ScalarFnArray whose child is a `Vector<padded_dim, f32>`
898888 /// wrapping `FSL(Dict(codes, values))`. Returns `(sorf_array, codes, values,
899- /// padded_dim, dim )`.
889+ /// padded_dim)`.
900890 fn build_sorf_with_dict_child (
901891 dim : u32 ,
902892 num_rows : usize ,
@@ -955,19 +945,16 @@ mod tests {
955945
956946 // ---- Case 1: SorfTransform + Constant pull-through ----
957947
958- /// Case 1: SorfTransform on LHS, constant query on RHS, `dim < padded_dim`.
948+ /// Case 1: SorfTransform on LHS, constant query on RHS, with `dim < padded_dim`
949+ /// so the zero-padding branch is exercised.
959950 #[ test]
960951 fn case1_sorf_lhs_constant_rhs_padded_gt_dim ( ) -> VortexResult < ( ) > {
961- let dim: u32 = 128 ;
962- let padded_dim = ( dim as usize ) . next_power_of_two ( ) ;
963- assert_eq ! ( padded_dim, 128 ) ; // degenerate case for this dim, handled by next test
964- // Bump dim to exercise padding explicitly.
965952 let dim: u32 = 100 ;
966- let padded_dim = ( dim as usize ) . next_power_of_two ( ) ;
967- assert ! ( padded_dim > dim as usize ) ;
968953 let num_rows = 7usize ;
969954 let seed = 42u64 ;
970955 let num_rounds = 3u8 ;
956+ let padded_dim = ( dim as usize ) . next_power_of_two ( ) ;
957+ assert ! ( padded_dim > dim as usize , "test must exercise padding" ) ;
971958
972959 let ( sorf_lhs, codes, values, padded_dim_computed) =
973960 build_sorf_with_dict_child ( dim, num_rows, seed, num_rounds) ?;
@@ -1095,7 +1082,7 @@ mod tests {
10951082 Ok ( ( ) )
10961083 }
10971084
1098- // ---- Case 2: Dict + Constant product-table path ----
1085+ // ---- Case 2: Dict + Constant direct-lookup path ----
10991086
11001087 /// Case 2: Vector[FSL[Dict(u8, f32)]] on LHS, constant query on RHS.
11011088 #[ test]
@@ -1169,18 +1156,16 @@ mod tests {
11691156 let num_values = 300usize ;
11701157 let values: Vec < f32 > = ( 0 ..num_values) . map ( |i| i as f32 * 0.01 ) . collect ( ) ;
11711158 // Codes must be u16 because 300 > 255. dict_vector_f32 only supports u8 so we
1172- // need a u16-coded dict here.
1173- let mut codes_u16_buf = BufferMut :: < u16 > :: with_capacity ( num_rows * 4 ) ;
1174- for i in 0 ..( num_rows * 4 ) {
1175- codes_u16_buf. push ( ( i % num_values) as u16 ) ;
1176- }
1159+ // build the dict by hand here.
1160+ let codes_u16: Vec < u16 > = ( 0 ..( num_rows * 4 ) )
1161+ . map ( |i| ( i % num_values) as u16 )
1162+ . collect ( ) ;
11771163 let codes_arr =
1178- PrimitiveArray :: new :: < u16 > ( codes_u16_buf . freeze ( ) , Validity :: NonNullable )
1164+ PrimitiveArray :: new :: < u16 > ( Buffer :: copy_from ( codes_u16 ) , Validity :: NonNullable )
11791165 . into_array ( ) ;
1180- let mut values_buf = BufferMut :: < f32 > :: with_capacity ( values. len ( ) ) ;
1181- values_buf. extend_from_slice ( & values) ;
11821166 let values_arr =
1183- PrimitiveArray :: new :: < f32 > ( values_buf. freeze ( ) , Validity :: NonNullable ) . into_array ( ) ;
1167+ PrimitiveArray :: new :: < f32 > ( Buffer :: copy_from ( & values) , Validity :: NonNullable )
1168+ . into_array ( ) ;
11841169 let dict = DictArray :: try_new ( codes_arr, values_arr) ?;
11851170 let fsl = FixedSizeListArray :: try_new (
11861171 dict. into_array ( ) ,
@@ -1241,7 +1226,7 @@ mod tests {
12411226 }
12421227
12431228 /// Case 2: empty `len == 0` fast path returns an empty primitive array without
1244- /// allocating a product table .
1229+ /// touching the codes buffer .
12451230 #[ test]
12461231 fn case2_empty_len_zero ( ) -> VortexResult < ( ) > {
12471232 let list_size: u32 = 4 ;
0 commit comments