Skip to content

Commit 821884e

Browse files
authored
Use Scalar directly for constant array metadata inlining (#6439)
Changes the metadata for `ConstantArray` to `Option<Scalar>`. #6363 added functionality for inlining the constant scalar if it was small enough in the array metadata. This change adds to that by using a scalar directly for this. It is optional since we do not always want to do this. Tests incoming! Also, I think it is worth considering if we should _not_ write the scalars to the buffers if we make this optimization. As is, this is doing double work. --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 0c1ab25 commit 821884e

6 files changed

Lines changed: 140 additions & 55 deletions

File tree

encodings/fastlanes/src/for/vtable/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ impl VTable for FoRVTable {
6969
}
7070

7171
fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
72+
// Note that we **only** serialize the optional scalar value (not including the dtype).
7273
Ok(Some(ScalarValue::to_proto_bytes(metadata.value())))
7374
}
7475

vortex-array/src/arrays/constant/array.rs

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,6 @@ use vortex_scalar::Scalar;
55

66
use crate::stats::ArrayStats;
77

8-
/// Protobuf-encoded metadata for [`ConstantArray`].
9-
///
10-
/// When the serialized scalar value is small enough (see `CONSTANT_INLINE_THRESHOLD`),
11-
/// it is inlined directly in the metadata to avoid a device-to-host copy on GPU.
12-
#[derive(Clone, prost::Message)]
13-
pub struct ConstantMetadata {
14-
#[prost(optional, bytes, tag = "1")]
15-
pub(super) scalar_value: Option<Vec<u8>>,
16-
}
17-
188
#[derive(Clone, Debug)]
199
pub struct ConstantArray {
2010
pub(super) scalar: Scalar,
@@ -47,21 +37,69 @@ impl ConstantArray {
4737

4838
#[cfg(test)]
4939
mod tests {
50-
use vortex_scalar::ScalarValue;
40+
use rstest::rstest;
41+
use vortex_dtype::Nullability;
42+
use vortex_error::VortexResult;
43+
use vortex_scalar::Scalar;
44+
use vortex_session::VortexSession;
5145

52-
use super::ConstantMetadata;
53-
use crate::ProstMetadata;
54-
use crate::test_harness::check_metadata;
46+
use crate::arrays::ConstantArray;
47+
use crate::arrays::constant::vtable::CONSTANT_INLINE_THRESHOLD;
48+
use crate::arrays::constant::vtable::ConstantVTable;
49+
use crate::vtable::VTable;
5550

56-
#[cfg_attr(miri, ignore)]
57-
#[test]
58-
fn test_constant_metadata() {
59-
let scalar_bytes: Vec<u8> = ScalarValue::to_proto_bytes(Some(&ScalarValue::from(i32::MAX)));
60-
check_metadata(
61-
"constant.metadata",
62-
ProstMetadata(ConstantMetadata {
63-
scalar_value: Some(scalar_bytes),
64-
}),
51+
#[rstest]
52+
#[case::below_threshold(CONSTANT_INLINE_THRESHOLD - 1, true)]
53+
#[case::at_threshold(CONSTANT_INLINE_THRESHOLD, true)]
54+
#[case::above_threshold(CONSTANT_INLINE_THRESHOLD + 1, false)]
55+
fn test_metadata_inlining(
56+
#[case] nbytes: usize,
57+
#[case] should_inline: bool,
58+
) -> VortexResult<()> {
59+
// UTF-8 scalar `nbytes` equals the string length.
60+
let string = "x".repeat(nbytes);
61+
let array = ConstantArray::new(Scalar::from(string.as_str()), 10);
62+
let metadata = ConstantVTable::metadata(&array)?;
63+
64+
assert_eq!(
65+
metadata.is_some(),
66+
should_inline,
67+
"scalar of {nbytes} bytes: expected inlined={should_inline}"
6568
);
69+
Ok(())
70+
}
71+
72+
#[test]
73+
fn test_metadata_round_trips() -> VortexResult<()> {
74+
let scalar = Scalar::from(42i64);
75+
let array = ConstantArray::new(scalar.clone(), 5);
76+
let metadata = ConstantVTable::metadata(&array)?;
77+
78+
// Serialize and deserialize the metadata.
79+
let bytes =
80+
ConstantVTable::serialize(metadata)?.expect("serialize should produce Some bytes");
81+
let session = VortexSession::empty();
82+
let deserialized = ConstantVTable::deserialize(
83+
&bytes,
84+
&vortex_dtype::DType::Primitive(vortex_dtype::PType::I64, Nullability::NonNullable),
85+
5,
86+
&session,
87+
)?;
88+
89+
assert_eq!(deserialized.unwrap(), scalar);
90+
Ok(())
91+
}
92+
93+
#[test]
94+
fn test_empty_bytes_deserializes_to_none() -> VortexResult<()> {
95+
let session = VortexSession::empty();
96+
let metadata = ConstantVTable::deserialize(
97+
&[],
98+
&vortex_dtype::DType::Primitive(vortex_dtype::PType::I32, Nullability::NonNullable),
99+
10,
100+
&session,
101+
)?;
102+
assert!(metadata.is_none(), "empty bytes should deserialize to None");
103+
Ok(())
66104
}
67105
}

vortex-array/src/arrays/constant/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ pub use arbitrary::ArbitraryConstantArray;
88

99
mod array;
1010
pub use array::ConstantArray;
11-
pub(crate) use array::ConstantMetadata;
1211
pub(crate) use vtable::canonical::constant_canonicalize;
1312

1413
pub(crate) mod compute;

vortex-array/src/arrays/constant/vtable/mod.rs

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,15 @@ use std::fmt::Debug;
55

66
use vortex_dtype::DType;
77
use vortex_error::VortexResult;
8-
use vortex_error::vortex_bail;
98
use vortex_error::vortex_ensure;
109
use vortex_scalar::Scalar;
1110
use vortex_scalar::ScalarValue;
1211
use vortex_session::VortexSession;
1312

1413
use crate::ArrayRef;
15-
use crate::DeserializeMetadata;
1614
use crate::ExecutionCtx;
1715
use crate::IntoArray;
18-
use crate::ProstMetadata;
19-
use crate::SerializeMetadata;
2016
use crate::arrays::ConstantArray;
21-
use crate::arrays::constant::ConstantMetadata;
2217
use crate::arrays::constant::compute::rules::PARENT_RULES;
2318
use crate::arrays::constant::vtable::canonical::constant_canonicalize;
2419
use crate::buffer::BufferHandle;
@@ -44,12 +39,20 @@ impl ConstantVTable {
4439

4540
/// Maximum size (in bytes) of a protobuf-encoded scalar value that will be inlined
4641
/// into the array metadata. Values larger than this are stored only in the buffer.
47-
const CONSTANT_INLINE_THRESHOLD: usize = 1024;
42+
pub(crate) const CONSTANT_INLINE_THRESHOLD: usize = 1024;
4843

4944
impl VTable for ConstantVTable {
5045
type Array = ConstantArray;
5146

52-
type Metadata = ProstMetadata<ConstantMetadata>;
47+
/// Optional inlined scalar constant.
48+
///
49+
/// When the scalar value is small enough (<= `CONSTANT_INLINE_THRESHOLD` bytes), it is stored
50+
/// directly in the metadata to avoid an extra buffer allocation and potential
51+
/// device-to-host copy during deserialization.
52+
///
53+
/// Currently, scalars are **always** stored in a separate buffer, regardless of if we inline a
54+
/// small scalar into the metadata.
55+
type Metadata = Option<Scalar>;
5356

5457
type ArrayVTable = Self;
5558
type OperationsVTable = Self;
@@ -61,28 +64,34 @@ impl VTable for ConstantVTable {
6164
}
6265

6366
fn metadata(array: &ConstantArray) -> VortexResult<Self::Metadata> {
64-
let constant = &array.scalar();
65-
let proto_bytes: Vec<u8> = ScalarValue::to_proto_bytes(constant.value());
66-
let scalar_value = (proto_bytes.len() <= CONSTANT_INLINE_THRESHOLD).then_some(proto_bytes);
67-
Ok(ProstMetadata(ConstantMetadata { scalar_value }))
67+
let constant = array.scalar();
68+
69+
// If the scalar is small enough, we can simply carry it around as metadata.
70+
Ok((constant.nbytes() <= CONSTANT_INLINE_THRESHOLD).then_some(constant.clone()))
6871
}
6972

7073
fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
71-
Ok(Some(metadata.serialize()))
74+
// If we do not have a scalar to serialize, just return empty bytes.
75+
Ok(Some(metadata.map_or_else(Vec::new, |c| {
76+
// Note that we **only** serialize the optional scalar value (not including the dtype).
77+
ScalarValue::to_proto_bytes(c.value())
78+
})))
7279
}
7380

7481
fn deserialize(
7582
bytes: &[u8],
76-
_dtype: &DType,
83+
dtype: &DType,
7784
_len: usize,
7885
_session: &VortexSession,
7986
) -> VortexResult<Self::Metadata> {
8087
// Empty bytes indicates an old writer that didn't produce metadata.
8188
if bytes.is_empty() {
82-
return Ok(ProstMetadata(ConstantMetadata { scalar_value: None }));
89+
return Ok(None);
8390
}
84-
let metadata = <Self::Metadata as DeserializeMetadata>::deserialize(bytes)?;
85-
Ok(ProstMetadata(metadata))
91+
92+
// Otherwise, deserialize the constant scalar from the metadata.
93+
let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?;
94+
Some(Scalar::try_new(dtype.clone(), scalar_value)).transpose()
8695
}
8796

8897
fn build(
@@ -93,22 +102,22 @@ impl VTable for ConstantVTable {
93102
_children: &dyn ArrayChildren,
94103
) -> VortexResult<ConstantArray> {
95104
// Prefer reading the scalar from inlined metadata to avoid device-to-host copies.
96-
let scalar = if let Some(proto_bytes) = &metadata.scalar_value {
97-
let scalar_value = ScalarValue::from_proto_bytes(proto_bytes, dtype)?;
98-
99-
Scalar::try_new(dtype.clone(), scalar_value)
100-
} else {
101-
if buffers.len() != 1 {
102-
vortex_bail!("Expected 1 buffer, got {}", buffers.len());
103-
}
105+
if let Some(constant) = metadata {
106+
return Ok(ConstantArray::new(constant.clone(), len));
107+
}
104108

105-
let buffer = buffers[0].clone().try_to_host_sync()?;
106-
let bytes: &[u8] = buffer.as_ref();
109+
// Otherwise, get the constant scalar from the buffers.
110+
vortex_ensure!(
111+
buffers.len() == 1,
112+
"Expected 1 buffer, got {}",
113+
buffers.len()
114+
);
107115

108-
let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?;
116+
let buffer = buffers[0].clone().try_to_host_sync()?;
117+
let bytes: &[u8] = buffer.as_ref();
109118

110-
Scalar::try_new(dtype.clone(), scalar_value)
111-
}?;
119+
let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?;
120+
let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
112121

113122
Ok(ConstantArray::new(scalar, len))
114123
}

vortex-scalar/src/scalar.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,10 @@ impl Scalar {
316316
)
317317
}
318318

319-
/// Returns the size of the scalar in bytes, uncompressed.
320-
#[cfg(test)]
319+
/// Returns an **ESTIMATE** of the size of the scalar in bytes, uncompressed.
320+
///
321+
/// Note that the protobuf serialization of scalars will likely have a different (but roughly
322+
/// similar) length.
321323
pub fn nbytes(&self) -> usize {
322324
use vortex_dtype::NativeDecimalType;
323325
use vortex_dtype::i256;

vortex-scalar/src/tests/round_trip.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
mod tests {
1212
use std::sync::Arc;
1313

14+
use rstest::rstest;
1415
use vortex_buffer::ByteBuffer;
1516
use vortex_dtype::DType;
1617
use vortex_dtype::DecimalDType;
@@ -21,6 +22,7 @@ mod tests {
2122

2223
use crate::DecimalValue;
2324
use crate::Scalar;
25+
use crate::ScalarValue;
2426
use crate::tests::SESSION;
2527

2628
// Test that primitive scalars round-trip through ScalarValue
@@ -292,4 +294,38 @@ mod tests {
292294
let bool_scalar = Scalar::bool(true, Nullability::NonNullable);
293295
assert!(bool_scalar.as_decimal_opt().is_none());
294296
}
297+
298+
/// Verifies that [`Scalar::nbytes`] matches the length of the proto-serialized scalar value.
299+
#[rstest]
300+
#[case::null_i32(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)))]
301+
#[case::bool_true(Scalar::from(true))]
302+
#[case::bool_false(Scalar::from(false))]
303+
#[case::i8(Scalar::from(i8::MAX))]
304+
#[case::i16(Scalar::from(i16::MAX))]
305+
#[case::i32(Scalar::from(i32::MAX))]
306+
#[case::i64(Scalar::from(i64::MAX))]
307+
#[case::u8(Scalar::from(u8::MAX))]
308+
#[case::u16(Scalar::from(u16::MAX))]
309+
#[case::u32(Scalar::from(u32::MAX))]
310+
#[case::u64(Scalar::from(u64::MAX))]
311+
#[case::f32(Scalar::from(f32::MAX))]
312+
#[case::f64(Scalar::from(f64::MAX))]
313+
#[case::utf8_empty(Scalar::from(""))]
314+
#[case::utf8_short(Scalar::from("hello"))]
315+
#[case::utf8_long(Scalar::from("x".repeat(2048).as_str()))]
316+
#[case::binary_empty(Scalar::binary(Vec::<u8>::new(), Nullability::NonNullable))]
317+
#[case::binary_short(Scalar::binary(vec![1u8, 2, 3], Nullability::NonNullable))]
318+
fn test_nbytes_approx_eq_to_proto_bytes(#[case] scalar: Scalar) {
319+
let proto_bytes: Vec<u8> = ScalarValue::to_proto_bytes(scalar.value());
320+
let diff = (scalar.nbytes() as isize - proto_bytes.len() as isize).abs();
321+
322+
// NOTE: THE 4 HERE IS COMPLETELY ARBITRARY!!!
323+
assert!(
324+
diff <= 4,
325+
"nbytes() should be within 4 of proto-serialized length for {:?}, got {} vs {}",
326+
scalar,
327+
scalar.nbytes(),
328+
proto_bytes.len(),
329+
);
330+
}
295331
}

0 commit comments

Comments
 (0)