@@ -46,7 +46,6 @@ use vortex_error::VortexResult;
 use vortex_error::vortex_ensure;
 use vortex_error::vortex_err;

-use crate::encodings::turboquant::MAX_CENTROIDS;
 use crate::matcher::AnyTensor;
 use crate::scalar_fns::l2_denorm::L2Denorm;
 use crate::scalar_fns::sorf_transform::SorfMatrix;
@@ -424,18 +423,16 @@ impl InnerProduct {
         Ok(Some(rewritten))
     }

-    /// Fast path when one side is an extension whose storage is `FSL(Dict(u8, f32))` with
-    /// at most `MAX_CENTROIDS` values, and the other side is a constant-backed tensor
-    /// extension with an F32 element ptype.
+    /// Fast path when one side is an extension whose storage is `FSL(Dict(u8, f32))` and
+    /// the other side is a constant-backed tensor extension with an F32 element ptype.
     ///
     /// Computes each row's inner product as
     /// `out[i] = sum_{j in 0..padded_dim} q[j] * values[codes[i * padded_dim + j] as usize]`
     /// using a direct codebook lookup in the hot loop. An explicit product table
     /// `P[j, k] = q[j] * values[k]` (size `padded_dim * num_centroids * 4B`, ~1 MiB for the
     /// common 1024/256 case) was tried and measured ~10% *slower* on the
     /// `similarity_search` bench because the 1 KiB `values` table stays in L1 across all
-    /// rows, while the 1 MiB product table does not. See the plan file for the math
-    /// justifying both forms.
+    /// rows, while the 1 MiB product table does not.
     ///
     /// Returns `Ok(None)` when the pattern doesn't match; the caller should fall through to
     /// the standard path.
@@ -485,20 +482,17 @@ impl InnerProduct {
         let codes_prim: PrimitiveArray = dict.codes().clone().execute(ctx)?;
         let values_prim: PrimitiveArray = dict.values().clone().execute(ctx)?;

-        // Gate: u8 codes, f32 centroids, and at most 256 centroids.
+        // Gate: u8 codes and f32 centroids.
         if codes_prim.ptype() != PType::U8 {
             // TODO(connor): support wider code widths (u16, u32). TurboQuant only emits u8
             // codes today, so this is the only path we need for now.
             return Ok(None);
         }
         if values_prim.ptype() != PType::F32 {
-            // TODO(connor): the product-table path only supports f32 centroids. SorfTransform
+            // TODO(connor): the direct-lookup path only supports f32 centroids. SorfTransform
             // forces f32 anyway, so this is the only shape we need for now.
             return Ok(None);
         }
-        if values_prim.len() > MAX_CENTROIDS {
-            return Ok(None);
-        }

         let padded_dim = usize::try_from(fsl.list_size()).vortex_expect("fsl list_size fits usize");
@@ -796,6 +790,7 @@ mod tests {
     mod case_1_and_2 {
         use std::sync::LazyLock;

+        use rstest::rstest;
         use vortex_array::ArrayRef;
         use vortex_array::IntoArray;
         use vortex_array::VortexSessionExecute;
@@ -1164,10 +1159,11 @@ mod tests {
             Ok(())
         }

-        /// Case 2: dict with more than 256 values falls through to the standard path but
-        /// still produces the correct result.
+        /// Case 2: dict with `u16` codes (and hence more than 256 values) falls through to
+        /// the standard path but still produces the correct result. The direct-lookup path
+        /// only handles `u8` codes today.
         #[test]
-        fn case2_more_than_256_values_falls_through() -> VortexResult<()> {
+        fn case2_u16_codes_falls_through() -> VortexResult<()> {
             let list_size: u32 = 4;
             let num_rows = 3usize;
             let num_values = 300usize;
@@ -1300,5 +1296,296 @@ mod tests {
             assert_close_f32(&actual, &expected, 1e-3);
             Ok(())
         }
+
+        // ---- Additional correctness / stress tests (all with loose tolerances) ----
+
+        /// A tiny in-place xorshift64 PRNG so these tests don't depend on `rand`. Producing
+        /// deterministic pseudo-random f32 values lets the correctness checks exercise
+        /// realistic data instead of smooth sin/cos patterns.
+        struct XorShift64(u64);
+
+        impl XorShift64 {
+            fn new(seed: u64) -> Self {
+                // Any nonzero seed is fine; xorshift fixed-points at 0.
+                Self(seed.wrapping_add(0x9E37_79B9_7F4A_7C15))
+            }
+
+            fn next_u64(&mut self) -> u64 {
+                let mut x = self.0;
+                x ^= x << 13;
+                x ^= x >> 7;
+                x ^= x << 17;
+                self.0 = x;
+                x
+            }
+
+            /// Uniform f32 in `[-1.0, 1.0)`.
+            fn next_f32(&mut self) -> f32 {
+                // Top 24 bits -> mantissa in [0, 1), then shift to [-1, 1).
+                let bits = (self.next_u64() >> 40) as u32; // 24 bits
+                (bits as f32) / (1u32 << 24) as f32 * 2.0 - 1.0
+            }
+        }
+
+        /// Case 2 stress: u8-coded dict with 200 centroids (formerly blocked by the
+        /// `values.len() <= 256` gate). The direct-lookup path must now handle it.
+        #[test]
+        fn case2_large_u8_codebook_direct_lookup() -> VortexResult<()> {
+            let list_size: u32 = 16;
+            let num_rows = 20usize;
+            let num_centroids = 200usize;
+            assert!(num_centroids > 8 && num_centroids <= 256);
+
+            let mut rng = XorShift64::new(0xDEAD_BEEF);
+            let values: Vec<f32> = (0..num_centroids).map(|_| rng.next_f32()).collect();
+            let codes: Vec<u8> = (0..num_rows * list_size as usize)
+                .map(|_| (rng.next_u64() % num_centroids as u64) as u8)
+                .collect();
+
+            let dict_lhs = dict_vector_f32(list_size, &codes, &values)?;
+            let query: Vec<f32> = (0..list_size).map(|_| rng.next_f32()).collect();
+            let const_rhs = constant_vector_f32(&query, num_rows)?;
+
+            let expected: Vec<f32> = (0..num_rows)
+                .map(|row| {
+                    let mut acc = 0.0f32;
+                    for j in 0..list_size as usize {
+                        let k = codes[row * list_size as usize + j] as usize;
+                        acc += query[j] * values[k];
+                    }
+                    acc
+                })
+                .collect();
+
+            let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?;
+            assert_close_f32(&actual, &expected, 1e-4);
+            Ok(())
+        }
+
+        /// Parameterized sweep over the full `InnerProduct(SorfTransform(Vector[FSL(Dict)]),
+        /// ConstantArray)` tree, exercising the case 1 + case 2 chain for a realistic mix
+        /// of dimensions, row counts, seeds, and number of SORF rounds. Tolerance is
+        /// deliberately loose because the rewrite introduces an f32-domain rotation that
+        /// accumulates a small numerical drift versus a naive decode.
+        #[rstest]
+        #[case::small_no_pad(128, 11, 1, 1)]
+        #[case::small_no_pad_rounds3(128, 23, 1_234, 3)]
+        #[case::small_padded(100, 17, 42, 3)]
+        #[case::mid_padded(200, 13, 2024, 3)]
+        #[case::mid_power_of_two(256, 31, 7, 3)]
+        #[case::larger_padded(300, 9, 99, 3)]
+        #[case::max_rounds(128, 5, 31_415, 5)]
+        fn case1_sorf_random_sweep(
+            #[case] dim: u32,
+            #[case] num_rows: usize,
+            #[case] seed: u64,
+            #[case] num_rounds: u8,
+        ) -> VortexResult<()> {
+            let (sorf, codes, values, padded_dim) =
+                build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;
+
+            // Use a pseudo-random query with both positive and negative entries so the sum
+            // has cancellation.
+            let mut rng = XorShift64::new(seed ^ 0xABCD_1234);
+            let query: Vec<f32> = (0..dim).map(|_| rng.next_f32()).collect();
+            let const_rhs = constant_vector_f32(&query, num_rows)?;
+
+            let decoded = decode_sorf_dict(
+                &codes,
+                &values,
+                padded_dim,
+                dim as usize,
+                num_rows,
+                seed,
+                num_rounds,
+            )?;
+            let expected: Vec<f32> = (0..num_rows)
+                .map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query))
+                .collect();
+
+            // Loose tolerance: the sorf transform works in f32 with a k-round butterfly, so
+            // the rewrite path and the decoded path accumulate slightly different rounding
+            // even though the math is equivalent in exact arithmetic.
+            let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
+            assert_close_f32(&actual, &expected, 1e-2);
+            Ok(())
+        }
+
+        /// Parameterized sweep over plain `Vector[FSL(Dict(u8, f32))]` + constant query,
+        /// without SorfTransform in the mix. This directly exercises case 2 across a
+        /// variety of list sizes, num_rows, and codebook sizes including large ones that
+        /// the old `<= 256` gate would have rejected.
+        #[rstest]
+        #[case::small(4, 7, 8)]
+        #[case::medium(16, 50, 64)]
+        #[case::larger(32, 100, 150)]
+        #[case::very_large_codebook(8, 25, 250)]
+        fn case2_random_sweep(
+            #[case] list_size: u32,
+            #[case] num_rows: usize,
+            #[case] num_centroids: usize,
+        ) -> VortexResult<()> {
+            let mut rng = XorShift64::new((list_size as u64) * 31 + num_rows as u64);
+            let values: Vec<f32> = (0..num_centroids).map(|_| rng.next_f32()).collect();
+            assert!(num_centroids <= 256, "u8 codes cap at 256 centroids");
+            let codes: Vec<u8> = (0..num_rows * list_size as usize)
+                .map(|_| (rng.next_u64() % num_centroids as u64) as u8)
+                .collect();
+
+            let dict_lhs = dict_vector_f32(list_size, &codes, &values)?;
+            let query: Vec<f32> = (0..list_size).map(|_| rng.next_f32()).collect();
+            let const_rhs = constant_vector_f32(&query, num_rows)?;
+
+            let expected: Vec<f32> = (0..num_rows)
+                .map(|row| {
+                    let mut acc = 0.0f32;
+                    for j in 0..list_size as usize {
+                        let k = codes[row * list_size as usize + j] as usize;
+                        acc += query[j] * values[k];
+                    }
+                    acc
+                })
+                .collect();
+
+            // Tight tolerance here because no SorfTransform rotation is involved; the
+            // arithmetic should agree bit-for-bit up to float reassociation.
+            let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?;
+            assert_close_f32(&actual, &expected, 1e-4);
+            Ok(())
+        }
+
+        /// End-to-end regression: for a plausible vector-search configuration (SORF rounds
+        /// = 3, dim = 128, num_rows = 64, u8 codes, 64 centroids), the fast-path result
+        /// must track a fully naive computation within 1e-2.
+        #[test]
+        fn end_to_end_dim128_rows64_bit6_regression() -> VortexResult<()> {
+            let dim: u32 = 128;
+            let num_rows = 64usize;
+            let seed = 0xFACE_F00D;
+            let num_rounds = 3u8;
+
+            // Use 64 centroids (6 bits), a typical TurboQuant configuration.
+            let num_centroids = 64usize;
+            let padded_dim = (dim as usize).next_power_of_two();
+            let mut rng = XorShift64::new(seed);
+            let values: Vec<f32> = (0..num_centroids).map(|_| rng.next_f32()).collect();
+            let codes: Vec<u8> = (0..num_rows * padded_dim)
+                .map(|_| (rng.next_u64() % num_centroids as u64) as u8)
+                .collect();
+
+            let padded_vector = dict_vector_f32(padded_dim as u32, &codes, &values)?;
+            let sorf_options = SorfOptions {
+                seed,
+                num_rounds,
+                dimension: dim,
+                element_ptype: PType::F32,
+            };
+            let sorf =
+                SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array();
+
+            let query: Vec<f32> = (0..dim).map(|_| rng.next_f32()).collect();
+            let const_rhs = constant_vector_f32(&query, num_rows)?;
+
+            let decoded = decode_sorf_dict(
+                &codes,
+                &values,
+                padded_dim,
+                dim as usize,
+                num_rows,
+                seed,
+                num_rounds,
+            )?;
+            let expected: Vec<f32> = (0..num_rows)
+                .map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query))
+                .collect();
+
+            let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
+            assert_close_f32(&actual, &expected, 1e-2);
+
+            // Also verify the max relative error is small. The SORF rotation does not
+            // amplify error, so both measures should be bounded.
+            for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
+                let denom = e.abs().max(1.0);
+                let rel = (a - e).abs() / denom;
+                assert!(
+                    rel < 1e-3,
+                    "row {i}: rel err {rel} too large (a={a}, e={e})"
+                );
+            }
+            Ok(())
+        }
+
+        /// Case 1 + Case 2 end-to-end with varying `num_rounds`. The rotation becomes
+        /// progressively more chaotic as rounds increase, so this catches any off-by-one
+        /// bug in the round-indexing that would not show up in the 3-round default.
+        #[rstest]
+        #[case(1)]
+        #[case(2)]
+        #[case(3)]
+        #[case(4)]
+        #[case(5)]
+        fn case1_various_num_rounds(#[case] num_rounds: u8) -> VortexResult<()> {
+            let dim: u32 = 128;
+            let num_rows = 8usize;
+            let seed = 0x1234_5678;
+
+            let (sorf, codes, values, padded_dim) =
+                build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;
+
+            let mut rng = XorShift64::new(seed ^ (num_rounds as u64));
+            let query: Vec<f32> = (0..dim).map(|_| rng.next_f32()).collect();
+            let const_rhs = constant_vector_f32(&query, num_rows)?;
+
+            let decoded = decode_sorf_dict(
+                &codes,
+                &values,
+                padded_dim,
+                dim as usize,
+                num_rows,
+                seed,
+                num_rounds,
+            )?;
+            let expected: Vec<f32> = (0..num_rows)
+                .map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query))
+                .collect();
+
+            let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
+            assert_close_f32(&actual, &expected, 1e-2);
+            Ok(())
+        }
+
+        /// Swap LHS and RHS on the full tree to prove the side-detection and the scalar
+        /// argument-order handling are symmetric for both cases simultaneously.
+        #[test]
+        fn end_to_end_constant_lhs_sorf_rhs_mirrored() -> VortexResult<()> {
+            let dim: u32 = 256;
+            let num_rows = 12usize;
+            let seed = 0xBEEF_CAFE;
+            let num_rounds = 3u8;
+
+            let (sorf, codes, values, padded_dim) =
+                build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;
+
+            let mut rng = XorShift64::new(seed);
+            let query: Vec<f32> = (0..dim).map(|_| rng.next_f32()).collect();
+            let const_lhs = constant_vector_f32(&query, num_rows)?;
+
+            let decoded = decode_sorf_dict(
+                &codes,
+                &values,
+                padded_dim,
+                dim as usize,
+                num_rows,
+                seed,
+                num_rounds,
+            )?;
+            let expected: Vec<f32> = (0..num_rows)
+                .map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query))
+                .collect();
+
+            let actual = eval_ip_f32(const_lhs, sorf, num_rows)?;
+            assert_close_f32(&actual, &expected, 1e-2);
+            Ok(())
+        }
     }
 }
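
The fast-path doc comment above describes the hot loop, but the kernel body itself is outside the hunks shown here. For reference, below is a minimal sketch of the direct codebook-lookup inner product it describes, written as a standalone function over plain slices; the name `direct_lookup_inner_product` and the slice-based signature are hypothetical and only illustrate the math, not the actual Vortex implementation (which reads `codes`/`values` from the Dict child and the query from the constant-backed tensor).

```rust
/// Sketch only: `out[i] = sum_j q[j] * values[codes[i * padded_dim + j] as usize]`.
fn direct_lookup_inner_product(
    codes: &[u8],     // row-major, `num_rows * padded_dim` entries
    values: &[f32],   // codebook; at most 256 entries for u8 codes (~1 KiB, stays in L1)
    query: &[f32],    // padded query, `padded_dim` entries
    padded_dim: usize,
) -> Vec<f32> {
    let num_rows = codes.len() / padded_dim;
    let mut out = Vec::with_capacity(num_rows);
    for row in 0..num_rows {
        let row_codes = &codes[row * padded_dim..(row + 1) * padded_dim];
        // Hot loop: one codebook lookup and one multiply-add per element, instead of
        // precomputing the `P[j, k] = q[j] * values[k]` product table described above.
        let mut acc = 0.0f32;
        for (j, &code) in row_codes.iter().enumerate() {
            acc += query[j] * values[code as usize];
        }
        out.push(acc);
    }
    out
}
```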