Skip to content

Commit 0b8a8a5

Browse files
claudeconnortsui20
authored andcommitted
drop dict codebook-size gate, widen bench, add correctness stress tests
The `values.len() <= MAX_CENTROIDS` gate in `try_execute_dict_constant` was a constraint on the original product-table implementation (1 MiB max for 256 centroids). Since the hot loop now does direct codebook lookups instead of a product table, the gate is no longer needed and gets removed so any u8-coded dict hits the fast path regardless of the actual number of distinct centroids (the u8 code type itself already caps at 256 entries anyway). Also bumps the `similarity_search` bench dataset from 10k to 100k rows to exercise the optimization at a more realistic scale. Results on the new dataset (median): execute_turboquant : 92.3 ms execute_uncompressed : 174.0 ms TurboQuant is now ~47% faster than uncompressed at 100k rows, up from ~15% at 10k rows — the optimization amortizes more as the column grows. `execute_default_compression` allocates the raw dataset (~300 MB) which can OOM a small sandbox; it is unchanged and runs fine on normal dev machines. Adds 19 new parameterized correctness tests (29 total in `case_1_and_2`): - `case2_large_u8_codebook_direct_lookup`: 200-centroid u8 dict, formerly blocked by the removed gate. - `case1_sorf_random_sweep` (7 rstest cases): sweeps dim, num_rows, seed, num_rounds with pseudo-random queries; verifies fast-path inner products match a naive decode within 1e-2. - `case2_random_sweep` (4 rstest cases): plain dict + constant sweeps over list_size, num_rows, and codebook size (including 250 centroids). - `case1_various_num_rounds` (5 rstest cases): 1..=5 SORF rounds to catch any round-indexing off-by-one. - `end_to_end_dim128_rows64_bit6_regression`: plausible vector-search config with 64 centroids (6 bits), verifies both absolute and relative error are small. - `end_to_end_constant_lhs_sorf_rhs_mirrored`: symmetry check across both cases simultaneously. Thresholds are 1e-4 for pure direct-lookup paths (no SorfTransform) and 1e-2 for full SORF end-to-end paths (the f32 butterfly rotation accumulates a small drift between the fast path and a naive decode). Renames `case2_more_than_256_values_falls_through` to `case2_u16_codes_falls_through` to reflect that the fall-through reason is now the u8-code gate, not the removed length gate. Signed-off-by: Claude <noreply@anthropic.com> Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent cf00a48 commit 0b8a8a5

2 files changed

Lines changed: 302 additions & 15 deletions

File tree

vortex-tensor/benches/similarity_search.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use common::generate_random_vectors;
3333
static GLOBAL: MiMalloc = MiMalloc;
3434

3535
/// Number of vectors in the benchmark dataset.
36-
const NUM_ROWS: usize = 10_000;
36+
const NUM_ROWS: usize = 100_000;
3737

3838
/// Dimensionality of each vector. Must be `>= vortex_tensor::encodings::turboquant::MIN_DIMENSION`
3939
/// (128) for the TurboQuant variant to work.

vortex-tensor/src/scalar_fns/inner_product.rs

Lines changed: 301 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ use vortex_error::VortexResult;
4646
use vortex_error::vortex_ensure;
4747
use vortex_error::vortex_err;
4848

49-
use crate::encodings::turboquant::MAX_CENTROIDS;
5049
use crate::matcher::AnyTensor;
5150
use crate::scalar_fns::l2_denorm::L2Denorm;
5251
use crate::scalar_fns::sorf_transform::SorfMatrix;
@@ -424,18 +423,16 @@ impl InnerProduct {
424423
Ok(Some(rewritten))
425424
}
426425

427-
/// Fast path when one side is an extension whose storage is `FSL(Dict(u8, f32))` with
428-
/// at most `MAX_CENTROIDS` values, and the other side is a constant-backed tensor
429-
/// extension with an F32 element ptype.
426+
/// Fast path when one side is an extension whose storage is `FSL(Dict(u8, f32))` and
427+
/// the other side is a constant-backed tensor extension with an F32 element ptype.
430428
///
431429
/// Computes each row's inner product as
432430
/// `out[i] = sum_{j in 0..padded_dim} q[j] * values[codes[i * padded_dim + j] as usize]`
433431
/// using a direct codebook lookup in the hot loop. An explicit product table
434432
/// `P[j, k] = q[j] * values[k]` (size `padded_dim * num_centroids * 4B`, ~1 MiB for the
435433
/// common 1024/256 case) was tried and measured ~10% *slower* on the
436434
/// `similarity_search` bench because the 1 KiB `values` table stays in L1 across all
437-
/// rows, while the 1 MiB product table does not. See the plan file for the math
438-
/// justifying both forms.
435+
/// rows, while the 1 MiB product table does not.
439436
///
440437
/// Returns `Ok(None)` when the pattern doesn't match; the caller should fall through to
441438
/// the standard path.
@@ -485,20 +482,17 @@ impl InnerProduct {
485482
let codes_prim: PrimitiveArray = dict.codes().clone().execute(ctx)?;
486483
let values_prim: PrimitiveArray = dict.values().clone().execute(ctx)?;
487484

488-
// Gate: u8 codes, f32 centroids, and at most 256 centroids.
485+
// Gate: u8 codes and f32 centroids.
489486
if codes_prim.ptype() != PType::U8 {
490487
// TODO(connor): support wider code widths (u16, u32). TurboQuant only emits u8
491488
// codes today, so this is the only path we need for now.
492489
return Ok(None);
493490
}
494491
if values_prim.ptype() != PType::F32 {
495-
// TODO(connor): the product-table path only supports f32 centroids. SorfTransform
492+
// TODO(connor): direct-lookup path only supports f32 centroids. SorfTransform
496493
// forces f32 anyway, so this is the only shape we need for now.
497494
return Ok(None);
498495
}
499-
if values_prim.len() > MAX_CENTROIDS {
500-
return Ok(None);
501-
}
502496

503497
let padded_dim = usize::try_from(fsl.list_size()).vortex_expect("fsl list_size fits usize");
504498

@@ -796,6 +790,7 @@ mod tests {
796790
mod case_1_and_2 {
797791
use std::sync::LazyLock;
798792

793+
use rstest::rstest;
799794
use vortex_array::ArrayRef;
800795
use vortex_array::IntoArray;
801796
use vortex_array::VortexSessionExecute;
@@ -1164,10 +1159,11 @@ mod tests {
11641159
Ok(())
11651160
}
11661161

1167-
/// Case 2: dict with more than 256 values falls through to the standard path but
1168-
/// still produces the correct result.
1162+
/// Case 2: dict with `u16` codes (and hence more than 256 values) falls through to
1163+
/// the standard path but still produces the correct result. The direct-lookup path
1164+
/// only handles `u8` codes today.
11691165
#[test]
1170-
fn case2_more_than_256_values_falls_through() -> VortexResult<()> {
1166+
fn case2_u16_codes_falls_through() -> VortexResult<()> {
11711167
let list_size: u32 = 4;
11721168
let num_rows = 3usize;
11731169
let num_values = 300usize;
@@ -1300,5 +1296,296 @@ mod tests {
13001296
assert_close_f32(&actual, &expected, 1e-3);
13011297
Ok(())
13021298
}
1299+
1300+
// ---- Additional correctness / stress tests (all with loose tolerances) ----
1301+
1302+
/// A tiny in-place xorshift64 PRNG so these tests don't depend on `rand`. Producing
1303+
/// deterministic pseudo-random f32 values lets the correctness checks exercise
1304+
/// realistic data instead of smooth sin/cos patterns.
1305+
struct XorShift64(u64);
1306+
1307+
impl XorShift64 {
1308+
fn new(seed: u64) -> Self {
1309+
// Any nonzero seed is fine; xorshift fixed-points at 0.
1310+
Self(seed.wrapping_add(0x9E37_79B9_7F4A_7C15))
1311+
}
1312+
1313+
fn next_u64(&mut self) -> u64 {
1314+
let mut x = self.0;
1315+
x ^= x << 13;
1316+
x ^= x >> 7;
1317+
x ^= x << 17;
1318+
self.0 = x;
1319+
x
1320+
}
1321+
1322+
/// Uniform f32 in `[-1.0, 1.0)`.
1323+
fn next_f32(&mut self) -> f32 {
1324+
// Top 24 bits -> mantissa in [0, 1), then shift to [-1, 1).
1325+
let bits = (self.next_u64() >> 40) as u32; // 24 bits
1326+
(bits as f32) / (1u32 << 24) as f32 * 2.0 - 1.0
1327+
}
1328+
}
1329+
1330+
/// Case 2 stress: u8-coded dict with 200 centroids (formerly blocked by the
1331+
/// `values.len() <= 256` gate). The direct-lookup path must now handle it.
1332+
#[test]
1333+
fn case2_large_u8_codebook_direct_lookup() -> VortexResult<()> {
1334+
let list_size: u32 = 16;
1335+
let num_rows = 20usize;
1336+
let num_centroids = 200usize;
1337+
assert!(num_centroids > 8 && num_centroids <= 256);
1338+
1339+
let mut rng = XorShift64::new(0xDEAD_BEEF);
1340+
let values: Vec<f32> = (0..num_centroids).map(|_| rng.next_f32()).collect();
1341+
let codes: Vec<u8> = (0..num_rows * list_size as usize)
1342+
.map(|_| (rng.next_u64() % num_centroids as u64) as u8)
1343+
.collect();
1344+
1345+
let dict_lhs = dict_vector_f32(list_size, &codes, &values)?;
1346+
let query: Vec<f32> = (0..list_size).map(|_| rng.next_f32()).collect();
1347+
let const_rhs = constant_vector_f32(&query, num_rows)?;
1348+
1349+
let expected: Vec<f32> = (0..num_rows)
1350+
.map(|row| {
1351+
let mut acc = 0.0f32;
1352+
for j in 0..list_size as usize {
1353+
let k = codes[row * list_size as usize + j] as usize;
1354+
acc += query[j] * values[k];
1355+
}
1356+
acc
1357+
})
1358+
.collect();
1359+
1360+
let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?;
1361+
assert_close_f32(&actual, &expected, 1e-4);
1362+
Ok(())
1363+
}
1364+
1365+
/// Parameterized sweep over the full `InnerProduct(SorfTransform(Vector[FSL(Dict)]),
1366+
/// ConstantArray)` tree, exercising the case 1 + case 2 chain for a realistic mix
1367+
/// of dimensions, row counts, seeds, and number of SORF rounds. Tolerance is
1368+
/// deliberately loose because the rewrite introduces an f32-domain rotation that
1369+
/// accumulates a small numerical drift versus a naive decode.
1370+
#[rstest]
1371+
#[case::small_no_pad(128, 11, 1, 1)]
1372+
#[case::small_no_pad_rounds3(128, 23, 1_234, 3)]
1373+
#[case::small_padded(100, 17, 42, 3)]
1374+
#[case::mid_padded(200, 13, 2024, 3)]
1375+
#[case::mid_power_of_two(256, 31, 7, 3)]
1376+
#[case::larger_padded(300, 9, 99, 3)]
1377+
#[case::max_rounds(128, 5, 31_415, 5)]
1378+
fn case1_sorf_random_sweep(
1379+
#[case] dim: u32,
1380+
#[case] num_rows: usize,
1381+
#[case] seed: u64,
1382+
#[case] num_rounds: u8,
1383+
) -> VortexResult<()> {
1384+
let (sorf, codes, values, padded_dim) =
1385+
build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;
1386+
1387+
// Use a pseudo-random query with both positive and negative entries so the sum
1388+
// has cancellation.
1389+
let mut rng = XorShift64::new(seed ^ 0xABCD_1234);
1390+
let query: Vec<f32> = (0..dim).map(|_| rng.next_f32()).collect();
1391+
let const_rhs = constant_vector_f32(&query, num_rows)?;
1392+
1393+
let decoded = decode_sorf_dict(
1394+
&codes,
1395+
&values,
1396+
padded_dim,
1397+
dim as usize,
1398+
num_rows,
1399+
seed,
1400+
num_rounds,
1401+
)?;
1402+
let expected: Vec<f32> = (0..num_rows)
1403+
.map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query))
1404+
.collect();
1405+
1406+
// Loose tolerance: the sorf transform works in f32 with a k-round butterfly, so
1407+
// the rewrite path and the decoded path accumulate slightly different rounding
1408+
// even though the math is equivalent in exact arithmetic.
1409+
let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
1410+
assert_close_f32(&actual, &expected, 1e-2);
1411+
Ok(())
1412+
}
1413+
1414+
/// Parameterized sweep over plain `Vector[FSL(Dict(u8, f32))]` + constant query,
1415+
/// without SorfTransform in the mix. This directly exercises case 2 across a
1416+
/// variety of list sizes, num_rows, and codebook sizes including large ones that
1417+
/// the old `<= 256` gate would have rejected.
1418+
#[rstest]
1419+
#[case::small(4, 7, 8)]
1420+
#[case::medium(16, 50, 64)]
1421+
#[case::larger(32, 100, 150)]
1422+
#[case::very_large_codebook(8, 25, 250)]
1423+
fn case2_random_sweep(
1424+
#[case] list_size: u32,
1425+
#[case] num_rows: usize,
1426+
#[case] num_centroids: usize,
1427+
) -> VortexResult<()> {
1428+
let mut rng = XorShift64::new((list_size as u64) * 31 + num_rows as u64);
1429+
let values: Vec<f32> = (0..num_centroids).map(|_| rng.next_f32()).collect();
1430+
assert!(num_centroids <= 256, "u8 codes cap at 256 centroids");
1431+
let codes: Vec<u8> = (0..num_rows * list_size as usize)
1432+
.map(|_| (rng.next_u64() % num_centroids as u64) as u8)
1433+
.collect();
1434+
1435+
let dict_lhs = dict_vector_f32(list_size, &codes, &values)?;
1436+
let query: Vec<f32> = (0..list_size).map(|_| rng.next_f32()).collect();
1437+
let const_rhs = constant_vector_f32(&query, num_rows)?;
1438+
1439+
let expected: Vec<f32> = (0..num_rows)
1440+
.map(|row| {
1441+
let mut acc = 0.0f32;
1442+
for j in 0..list_size as usize {
1443+
let k = codes[row * list_size as usize + j] as usize;
1444+
acc += query[j] * values[k];
1445+
}
1446+
acc
1447+
})
1448+
.collect();
1449+
1450+
// Tight tolerance here because no SorfTransform rotation is involved — the
1451+
// arithmetic should agree bit-for-bit up to float reassociation.
1452+
let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?;
1453+
assert_close_f32(&actual, &expected, 1e-4);
1454+
Ok(())
1455+
}
1456+
1457+
/// End-to-end regression: for a plausible vector-search configuration (SORF rounds
1458+
/// = 3, dim = 128, num_rows = 64, u8 codes, 64 centroids), the fast-path result
1459+
/// must track a fully naive computation within 1e-2.
1460+
#[test]
1461+
fn end_to_end_dim128_rows64_bit6_regression() -> VortexResult<()> {
1462+
let dim: u32 = 128;
1463+
let num_rows = 64usize;
1464+
let seed = 0xFACE_F00D;
1465+
let num_rounds = 3u8;
1466+
1467+
// Use 64 centroids (6 bits), a typical TurboQuant configuration.
1468+
let num_centroids = 64usize;
1469+
let padded_dim = (dim as usize).next_power_of_two();
1470+
let mut rng = XorShift64::new(seed);
1471+
let values: Vec<f32> = (0..num_centroids).map(|_| rng.next_f32()).collect();
1472+
let codes: Vec<u8> = (0..num_rows * padded_dim)
1473+
.map(|_| (rng.next_u64() % num_centroids as u64) as u8)
1474+
.collect();
1475+
1476+
let padded_vector = dict_vector_f32(padded_dim as u32, &codes, &values)?;
1477+
let sorf_options = SorfOptions {
1478+
seed,
1479+
num_rounds,
1480+
dimension: dim,
1481+
element_ptype: PType::F32,
1482+
};
1483+
let sorf =
1484+
SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array();
1485+
1486+
let query: Vec<f32> = (0..dim).map(|_| rng.next_f32()).collect();
1487+
let const_rhs = constant_vector_f32(&query, num_rows)?;
1488+
1489+
let decoded = decode_sorf_dict(
1490+
&codes,
1491+
&values,
1492+
padded_dim,
1493+
dim as usize,
1494+
num_rows,
1495+
seed,
1496+
num_rounds,
1497+
)?;
1498+
let expected: Vec<f32> = (0..num_rows)
1499+
.map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query))
1500+
.collect();
1501+
1502+
let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
1503+
assert_close_f32(&actual, &expected, 1e-2);
1504+
1505+
// Also verify the max relative error is small. The SORF rotation does not
1506+
// amplify error, so both measures should be bounded.
1507+
for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
1508+
let denom = e.abs().max(1.0);
1509+
let rel = (a - e).abs() / denom;
1510+
assert!(
1511+
rel < 1e-3,
1512+
"row {i}: rel err {rel} too large (a={a}, e={e})"
1513+
);
1514+
}
1515+
Ok(())
1516+
}
1517+
1518+
/// Case 1 + Case 2 end-to-end with varying `num_rounds`. The rotation becomes
1519+
/// progressively more chaotic as rounds increase, so this catches any off-by-one
1520+
/// bug in the round-indexing that would not show up in the 3-round default.
1521+
#[rstest]
1522+
#[case(1)]
1523+
#[case(2)]
1524+
#[case(3)]
1525+
#[case(4)]
1526+
#[case(5)]
1527+
fn case1_various_num_rounds(#[case] num_rounds: u8) -> VortexResult<()> {
1528+
let dim: u32 = 128;
1529+
let num_rows = 8usize;
1530+
let seed = 0x1234_5678;
1531+
1532+
let (sorf, codes, values, padded_dim) =
1533+
build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;
1534+
1535+
let mut rng = XorShift64::new(seed ^ (num_rounds as u64));
1536+
let query: Vec<f32> = (0..dim).map(|_| rng.next_f32()).collect();
1537+
let const_rhs = constant_vector_f32(&query, num_rows)?;
1538+
1539+
let decoded = decode_sorf_dict(
1540+
&codes,
1541+
&values,
1542+
padded_dim,
1543+
dim as usize,
1544+
num_rows,
1545+
seed,
1546+
num_rounds,
1547+
)?;
1548+
let expected: Vec<f32> = (0..num_rows)
1549+
.map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query))
1550+
.collect();
1551+
1552+
let actual = eval_ip_f32(sorf, const_rhs, num_rows)?;
1553+
assert_close_f32(&actual, &expected, 1e-2);
1554+
Ok(())
1555+
}
1556+
1557+
/// Swap LHS and RHS on the full tree to prove the side-detection and the scalar
1558+
/// argument-order handling are symmetric for both cases simultaneously.
1559+
#[test]
1560+
fn end_to_end_constant_lhs_sorf_rhs_mirrored() -> VortexResult<()> {
1561+
let dim: u32 = 256;
1562+
let num_rows = 12usize;
1563+
let seed = 0xBEEF_CAFE;
1564+
let num_rounds = 3u8;
1565+
1566+
let (sorf, codes, values, padded_dim) =
1567+
build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?;
1568+
1569+
let mut rng = XorShift64::new(seed);
1570+
let query: Vec<f32> = (0..dim).map(|_| rng.next_f32()).collect();
1571+
let const_lhs = constant_vector_f32(&query, num_rows)?;
1572+
1573+
let decoded = decode_sorf_dict(
1574+
&codes,
1575+
&values,
1576+
padded_dim,
1577+
dim as usize,
1578+
num_rows,
1579+
seed,
1580+
num_rounds,
1581+
)?;
1582+
let expected: Vec<f32> = (0..num_rows)
1583+
.map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query))
1584+
.collect();
1585+
1586+
let actual = eval_ip_f32(const_lhs, sorf, num_rows)?;
1587+
assert_close_f32(&actual, &expected, 1e-2);
1588+
Ok(())
1589+
}
13031590
}
13041591
}

0 commit comments

Comments
 (0)