Skip to content

Commit caffe3f

Browse files
committed
shared array to avoid duplicate execute
Signed-off-by: Onur Satici <onur@spiraldb.com>
1 parent 493dbb5 commit caffe3f

7 files changed

Lines changed: 50 additions & 55 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ arrow-schema = "57.1"
9999
arrow-select = "57.1"
100100
arrow-string = "57.1"
101101
async-compat = "0.2.5"
102+
async-lock = "3.4"
102103
async-fs = "2.2.0"
103104
async-stream = "0.3.6"
104105
async-trait = "0.1.89"

vortex-array/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ workspace = true
2222
[dependencies]
2323
arbitrary = { workspace = true, optional = true }
2424
arcref = { workspace = true }
25+
async-lock = { workspace = true }
2526
arrow-arith = { workspace = true }
2627
arrow-array = { workspace = true, features = ["ffi"] }
2728
arrow-buffer = { workspace = true }
Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::future::Future;
45
use std::sync::Arc;
56

6-
use parking_lot::RwLock;
7+
use async_lock::Mutex;
78
use vortex_dtype::DType;
89
use vortex_error::VortexResult;
9-
use vortex_error::vortex_panic;
1010

1111
use crate::ArrayRef;
1212
use crate::Canonical;
13-
use crate::ExecutionCtx;
1413
use crate::IntoArray;
1514
use crate::stats::ArrayStats;
1615

1716
#[derive(Debug, Clone)]
1817
pub struct SharedArray {
19-
pub(super) state: Arc<RwLock<SharedState>>,
18+
pub(super) state: Arc<Mutex<SharedState>>,
2019
pub(super) dtype: DType,
2120
pub(super) stats: ArrayStats,
2221
}
@@ -28,60 +27,57 @@ pub(super) enum SharedState {
2827
}
2928

3029
impl SharedArray {
30+
/// Creates a new `SharedArray` wrapping the given source array.
3131
pub fn new(source: ArrayRef) -> Self {
3232
Self {
3333
dtype: source.dtype().clone(),
34-
state: Arc::new(RwLock::new(SharedState::Source(source))),
34+
state: Arc::new(Mutex::new(SharedState::Source(source))),
3535
stats: ArrayStats::default(),
3636
}
3737
}
3838

39-
pub fn cached(&self) -> Option<Canonical> {
40-
match &*self.state.read() {
41-
SharedState::Cached(canonical) => Some(canonical.clone()),
42-
SharedState::Source(_) => None,
43-
}
44-
}
45-
46-
pub fn cache_or_return(&self, canonical: Canonical) -> Canonical {
47-
let mut state = self.state.write();
39+
pub fn get_or_compute(
40+
&self,
41+
f: impl FnOnce(&ArrayRef) -> VortexResult<Canonical>,
42+
) -> VortexResult<Canonical> {
43+
let mut state = self.state.lock_blocking();
4844
match &*state {
49-
SharedState::Cached(existing) => existing.clone(),
50-
SharedState::Source(_) => {
45+
SharedState::Cached(canonical) => Ok(canonical.clone()),
46+
SharedState::Source(source) => {
47+
let canonical = f(source)?;
5148
*state = SharedState::Cached(canonical.clone());
52-
canonical
49+
Ok(canonical)
5350
}
5451
}
5552
}
5653

57-
pub fn as_source(&self) -> ArrayRef {
58-
let SharedState::Source(source) = &*self.state.read() else {
59-
vortex_panic!("already cached");
60-
};
61-
source.clone()
62-
}
63-
64-
pub(super) fn canonicalize(&self, ctx: &mut ExecutionCtx) -> VortexResult<Canonical> {
65-
let source = {
66-
let state = self.state.read();
67-
match &*state {
68-
SharedState::Cached(existing) => return Ok(existing.clone()),
69-
SharedState::Source(source) => source.clone(),
54+
pub async fn get_or_compute_async<F, Fut>(&self, f: F) -> VortexResult<Canonical>
55+
where
56+
F: FnOnce(ArrayRef) -> Fut,
57+
Fut: Future<Output = VortexResult<Canonical>>,
58+
{
59+
let mut state = self.state.lock().await;
60+
match &*state {
61+
SharedState::Cached(canonical) => Ok(canonical.clone()),
62+
SharedState::Source(source) => {
63+
let source = source.clone();
64+
let canonical = f(source).await?;
65+
*state = SharedState::Cached(canonical.clone());
66+
Ok(canonical)
7067
}
71-
};
72-
let canonical = source.execute::<Canonical>(ctx)?;
73-
Ok(self.cache_or_return(canonical))
68+
}
7469
}
7570

7671
pub(super) fn current_array_ref(&self) -> ArrayRef {
77-
match &*self.state.read() {
72+
let state = self.state.lock_blocking();
73+
match &*state {
7874
SharedState::Source(source) => source.clone(),
7975
SharedState::Cached(canonical) => canonical.clone().into_array(),
8076
}
8177
}
8278

8379
pub(super) fn set_source(&mut self, source: ArrayRef) {
8480
self.dtype = source.dtype().clone();
85-
*self.state.write() = SharedState::Source(source);
81+
*self.state.lock_blocking() = SharedState::Source(source);
8682
}
8783
}

vortex-array/src/arrays/shared/tests.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_buffer::buffer;
5+
use vortex_error::VortexResult;
56
use vortex_session::VortexSession;
67

8+
use crate::Canonical;
79
use crate::ExecutionCtx;
810
use crate::IntoArray;
911
use crate::arrays::PrimitiveArray;
@@ -14,28 +16,22 @@ use crate::session::ArraySession;
1416
use crate::validity::Validity;
1517

1618
#[test]
17-
fn shared_array_caches_on_canonicalize() -> vortex_error::VortexResult<()> {
19+
fn shared_array_caches_on_canonicalize() -> VortexResult<()> {
1820
let array = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::NonNullable).into_array();
1921
let shared = SharedArray::new(array);
2022

21-
assert!(shared.cached().is_none());
22-
2323
let session = VortexSession::empty().with::<ArraySession>();
2424
let mut ctx = ExecutionCtx::new(session);
2525

26-
let first = shared.canonicalize(&mut ctx)?;
27-
let cached = shared.cached().expect("canonicalize should cache result");
28-
assert!(
29-
cached
30-
.as_ref()
31-
.array_eq(first.as_ref(), HashPrecision::Value)
32-
);
26+
let first = shared.get_or_compute(|source| source.clone().execute::<Canonical>(&mut ctx))?;
27+
28+
// Second call should return cached without invoking the closure.
29+
let second = shared.get_or_compute(|_| panic!("should not execute twice"))?;
3330

34-
let second = shared.canonicalize(&mut ctx)?;
3531
assert!(
36-
second
32+
first
3733
.as_ref()
38-
.array_eq(first.as_ref(), HashPrecision::Value)
34+
.array_eq(second.as_ref(), HashPrecision::Value)
3935
);
4036

4137
Ok(())

vortex-array/src/arrays/shared/vtable.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use vortex_session::VortexSession;
1212
use crate::ArrayBufferVisitor;
1313
use crate::ArrayChildVisitor;
1414
use crate::ArrayRef;
15+
use crate::Canonical;
1516
use crate::EmptyMetadata;
1617
use crate::ExecutionCtx;
1718
use crate::IntoArray;
@@ -96,7 +97,9 @@ impl VTable for SharedVTable {
9697
}
9798

9899
fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
99-
Ok(array.canonicalize(ctx)?.into_array())
100+
Ok(array
101+
.get_or_compute(|source| source.clone().execute::<Canonical>(ctx))?
102+
.into_array())
100103
}
101104
}
102105

vortex-cuda/src/kernel/arrays/shared.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,8 @@ impl CudaExecute for SharedExecutor {
2828
.ok()
2929
.vortex_expect("Array is not a Shared array");
3030

31-
if let Some(cached) = shared.cached() {
32-
return Ok(cached);
33-
}
34-
35-
let canonical = shared.as_source().execute_cuda(ctx).await?;
36-
Ok(shared.cache_or_return(canonical))
31+
shared
32+
.get_or_compute_async(|source| source.execute_cuda(ctx))
33+
.await
3734
}
3835
}

0 commit comments

Comments
 (0)