feat: add cosine_distance scalar function #21542

Open
crm26 wants to merge 2 commits into apache:main from crm26:feat/cosine-distance

@crm26 crm26 commented Apr 10, 2026

Summary

  • Adds cosine_distance(array1, array2) / list_cosine_distance — computes cosine distance (1 - cosine similarity) between two numeric arrays
  • Introduces shared vector_math.rs primitives (dot_product_f64, magnitude_f64, convert_to_f64_array) for reuse by follow-on vector functions
  • Returns NULL for zero-magnitude vectors; errors on mismatched lengths
  • Supports List, LargeList, and FixedSizeList with any numeric element type
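
For reference, the semantics listed above can be sketched on plain slices. This is a std-only illustration, not the PR's code: `cosine_distance` here is a hypothetical free function, `None` stands in for SQL NULL, and the `String` error stands in for DataFusion's error type.

```rust
/// Cosine distance between two equal-length vectors:
/// 1 - (a . b) / (|a| * |b|).
///
/// Returns `Ok(None)` (standing in for SQL NULL) when either vector has zero
/// magnitude, and `Err` when the lengths differ.
pub fn cosine_distance(a: &[f64], b: &[f64]) -> Result<Option<f64>, String> {
    if a.len() != b.len() {
        return Err(format!(
            "vectors must have the same length, got {} and {}",
            a.len(),
            b.len()
        ));
    }

    // Accumulate the dot product and both squared norms in one pass.
    let mut dot = 0.0_f64;
    let mut sq_a = 0.0_f64;
    let mut sq_b = 0.0_f64;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    // The distance is undefined for a zero-magnitude vector.
    if sq_a == 0.0 || sq_b == 0.0 {
        return Ok(None);
    }
    Ok(Some(1.0 - dot / (sq_a.sqrt() * sq_b.sqrt())))
}
```

Under this definition, orthogonal vectors such as [1, 0] and [0, 1] are at distance 1, vectors pointing the same way are at distance 0, and opposite vectors are at distance 2.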

Part of #21536 — first in a series of split PRs (replacing #21371).

Test plan

  • Unit tests: identical, orthogonal, opposite, 45-degree, zero-magnitude, mismatched-length, NULL, multi-row
  • sqllogictest: cosine_distance.slt covering all edge cases including empty arrays, LargeList, integer coercion, alias, return type
  • Full slt suite (426/426 pass)
  • cargo clippy, cargo fmt, taplo, prettier, cargo machete — all clean

🤖 Generated with Claude Code

Add cosine_distance (and list_cosine_distance alias) to compute cosine
distance between two numeric arrays. Includes shared vector math
primitives in vector_math.rs for reuse by follow-on functions.

Part of apache#21536.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
github-actions bot added the documentation (Improvements or additions to documentation), sqllogictest (SQL Logic Tests (.slt)), and functions (Changes to functions implementation) labels Apr 10, 2026
Comment on lines +152 to +157
let result = list_array1
    .iter()
    .zip(list_array2.iter())
    .map(|(arr1, arr2)| compute_cosine_distance(arr1, arr2))
    .collect::<Result<Float64Array>>()?;

Contributor

It's probably more efficient to iterate using offsets/values than needing to downcast each ArrayRef value
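
A minimal sketch of the offsets/values idea, assuming the usual Arrow list layout (a flat values buffer plus an offsets buffer with one more entry than there are rows). It uses plain slices instead of the arrow crate, and `list_rows` is a hypothetical helper name.

```rust
/// Yields each row of a list column as a borrowed window into the flat
/// `values` buffer. `offsets` has one more entry than there are rows, and
/// row `i` spans `values[offsets[i]..offsets[i + 1]]`.
pub fn list_rows<'a>(
    offsets: &'a [i32],
    values: &'a [f64],
) -> impl Iterator<Item = &'a [f64]> + 'a {
    offsets
        .windows(2)
        // No per-row allocation or downcast: each row is just a sub-slice.
        .map(move |w| &values[w[0] as usize..w[1] as usize])
}
```

The values buffer is typed once up front, and every row after that is only a pair of indices, which is where the saving over materializing one `ArrayRef` per row would come from.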

let mut value1 = value1;
let mut value2 = value2;

loop {
Contributor

What is this loop for? I don't believe this function is intended to support nested lists?

Comment on lines +216 to +217
let values1 = convert_to_f64_array(&value1)?;
let values2 = convert_to_f64_array(&value2)?;
Contributor

Is this strictly necessary? I thought signature coercion should already ensure it is a float array

let values2 = convert_to_f64_array(&value2)?;

if values1.len() != values2.len() {
    return exec_err!("Both arrays must have the same length");
Contributor

This error probably needs to be more descriptive, especially as "arrays" can be confused with the actual input Arrow Array (when it's meant to refer to the list values inside the Arrow Array)

}

#[cfg(test)]
mod tests {
Contributor

These tests seem to duplicate the SLTs, so we should remove these Rust unit tests

Comment thread datafusion/functions-nested/src/lib.rs Outdated
pub mod sort;
pub mod string;
pub mod utils;
pub mod vector_math;
Contributor

Don't think vector_math needs to be public here

Comment on lines +47 to +67
/// Computes dot product: sum(a[i] * b[i])
pub fn dot_product_f64(a: &Float64Array, b: &Float64Array) -> f64 {
    a.iter()
        .zip(b.iter())
        .map(|(v1, v2)| v1.unwrap_or(0.0) * v2.unwrap_or(0.0))
        .sum()
}

/// Computes sum of squares: sum(a[i]^2)
pub fn sum_of_squares_f64(a: &Float64Array) -> f64 {
    a.iter()
        .map(|v| {
            let val = v.unwrap_or(0.0);
            val * val
        })
        .sum()
}

/// Computes magnitude (L2 norm): sqrt(sum(a[i]^2))
pub fn magnitude_f64(a: &Float64Array) -> f64 {
    sum_of_squares_f64(a).sqrt()
Contributor

These are probably best off inlined for the moment? It might make it more efficient instead of needing to pass in an array slice for every array value

Addresses review comments on apache#21542:
- Iterate list offsets/values directly instead of per-row ArrayRef downcast
- Remove nested-list unwrap loop (function does not support nested lists)
- Drop convert_to_f64_array wrapper (coerce_types already guarantees Float64)
- Remove duplicate Rust unit tests now covered by SLT
- More descriptive error message for mismatched list lengths
- Delete now-unused vector_math module; inline math into sole caller

Adds SLT coverage for NULL-element-in-list behavior previously tested
only in Rust unit tests.

Part of apache#21536.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Contributor

@Jefffrey Jefffrey left a comment

Looking good

Ok(DataType::Float64)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
Contributor

Something to keep in mind here is input lists of different types. Take, for example:

query R
select list_cosine_distance([1.0, 0.0], arrow_cast([0.0, 1.0], 'LargeList(Float16)'));
----
1

This was highlighted by the recent PR for concat:

pub fn new() -> Self {
    Self {
        signature: Signature::user_defined(Volatility::Immutable),
        aliases: vec!["list_cosine_distance".to_string()],
Contributor

list_cosine_distance seems a bit unwieldy, especially considering that the regular name of the function, cosine_distance, doesn't mention array (usually we'd have something like array_remove, and the alias just replaces array with list, e.g. list_remove)

Comment on lines +45 to +46
query error cosine_distance does not support type
select cosine_distance(NULL, [1.0, 2.0]);
Contributor

We could probably still handle this case and return NULL

select cosine_distance(column1, column2) from (values
(make_array(1.0, 0.0), make_array(0.0, 1.0)),
(make_array(1.0, 1.0), make_array(1.0, 1.0)),
(make_array(1.0, 0.0), make_array(-1.0, 0.0))
Contributor

Suggested change
- (make_array(1.0, 0.0), make_array(-1.0, 0.0))
+ (make_array(1.0, 0.0), make_array(-1.0, 0.0)),
+ (make_array(1.0, 0.0), NULL)

Just for coverage

Author

crm26 commented Apr 18, 2026

Thanks for the detailed review, @Jefffrey. Rework pushed in fc3ee90. Walking through each comment:

1. Iterate via offsets/values, not per-row ArrayRef downcast (cosine_distance.rs:157)
Rewrote general_cosine_distance to downcast list_array.values() once to &Float64Array, then slice by value_offsets() per row. The inner for i in 0..len loop reads directly from the contiguous ScalarBuffer<f64> — no per-row downcast, no Option<f64> unwrapping.

2. Nested-list unwrap loop (cosine_distance.rs:178)
Removed entirely. The function is not intended to support nested lists, and coerce_types rejects anything other than List/LargeList/FixedSizeList of a numeric inner type.

3. Redundant null/Float64 check (cosine_distance.rs:217)
Agreed — coerce_types calls coerced_type_with_base_type_only with Float64 as the base type, which guarantees the inner type is Float64 by the time we hit invoke_with_args. Dropped convert_to_f64_array entirely and replaced with a direct as_float64_array downcast.

4. Ambiguous length-mismatch error wording (cosine_distance.rs:220)
Updated to "cosine_distance requires both list inputs to have the same length, got {len1} and {len2}". The message now makes explicit that the lengths are those of the list values in each row, not of the outer array, and it includes the observed values.

5. Duplicate Rust unit tests (cosine_distance.rs:235)
Removed the mod tests block. SLT coverage includes orthogonal, identical, opposite, 45-degree, zero-magnitude, mismatched lengths, LargeList, integer coercion, multi-row, alias, empty arrays, no-args, and return-type checks. Added one new SLT case for NULL-element-in-list to preserve that particular behavior the Rust tests were covering.

6. pub mod vector_math (lib.rs:72)
Moot after item 7: the whole file is deleted, so the declaration is gone.

7. Inline the math instead of a separate module (vector_math.rs:67)
Agreed — with only one caller, the indirection wasn't paying for itself. Deleted vector_math.rs and inlined dot/magnitude into the tight per-row loop in general_cosine_distance.
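
Items 1, 4, and 7 together amount to a loop of roughly the following shape. This is a std-only sketch, not the code from fc3ee90: plain slices stand in for the Arrow offset and value buffers, `cosine_distance_rows` is a hypothetical name, `None` stands in for NULL, and row validity (NULL rows) is ignored for brevity.

```rust
/// Computes one cosine distance per row for two list columns, each given as
/// an offsets buffer plus a flat values buffer (assumed layout, see lead-in).
pub fn cosine_distance_rows(
    offsets1: &[i32],
    values1: &[f64],
    offsets2: &[i32],
    values2: &[f64],
) -> Result<Vec<Option<f64>>, String> {
    if offsets1.len() != offsets2.len() {
        return Err("inputs must have the same number of rows".to_string());
    }

    let mut out = Vec::with_capacity(offsets1.len().saturating_sub(1));
    for (w1, w2) in offsets1.windows(2).zip(offsets2.windows(2)) {
        let (start1, end1) = (w1[0] as usize, w1[1] as usize);
        let (start2, end2) = (w2[0] as usize, w2[1] as usize);
        let (len1, len2) = (end1 - start1, end2 - start2);
        if len1 != len2 {
            // Wording taken from item 4 of the reply above.
            return Err(format!(
                "cosine_distance requires both list inputs to have the same length, got {len1} and {len2}"
            ));
        }

        // Dot product and squared norms are inlined into one tight loop.
        let (mut dot, mut sq1, mut sq2) = (0.0_f64, 0.0_f64, 0.0_f64);
        for i in 0..len1 {
            let (x, y) = (values1[start1 + i], values2[start2 + i]);
            dot += x * y;
            sq1 += x * x;
            sq2 += y * y;
        }

        // Zero-magnitude rows produce NULL rather than a division by zero.
        out.push(if sq1 == 0.0 || sq2 == 0.0 {
            None
        } else {
            Some(1.0 - dot / (sq1.sqrt() * sq2.sqrt()))
        });
    }
    Ok(out)
}
```

Accumulating all three sums in the same pass reads each value once, which is the main reason to prefer this over separate dot-product and magnitude helpers when there is a single caller.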

Full validation matrix (fmt --all, workspace clippy -D warnings, full + sqlite-extended SLT, CLI, doctests, feature-flag spot-checks, extended_tests workspace build, rustdoc, license, typos, machete, generated-doc regen) passed locally before push. Let me know if anything else needs tightening.
