11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use std:: future:: Future ;
45use std:: sync:: Arc ;
56
6- use parking_lot :: RwLock ;
7+ use async_lock :: Mutex ;
78use vortex_dtype:: DType ;
89use vortex_error:: VortexResult ;
9- use vortex_error:: vortex_panic;
1010
1111use crate :: ArrayRef ;
1212use crate :: Canonical ;
13- use crate :: ExecutionCtx ;
1413use crate :: IntoArray ;
1514use crate :: stats:: ArrayStats ;
1615
1716#[ derive( Debug , Clone ) ]
1817pub struct SharedArray {
19- pub ( super ) state : Arc < RwLock < SharedState > > ,
18+ pub ( super ) state : Arc < Mutex < SharedState > > ,
2019 pub ( super ) dtype : DType ,
2120 pub ( super ) stats : ArrayStats ,
2221}
@@ -28,60 +27,57 @@ pub(super) enum SharedState {
2827}
2928
3029impl SharedArray {
30+ /// Creates a new `SharedArray` wrapping the given source array.
3131 pub fn new ( source : ArrayRef ) -> Self {
3232 Self {
3333 dtype : source. dtype ( ) . clone ( ) ,
34- state : Arc :: new ( RwLock :: new ( SharedState :: Source ( source) ) ) ,
34+ state : Arc :: new ( Mutex :: new ( SharedState :: Source ( source) ) ) ,
3535 stats : ArrayStats :: default ( ) ,
3636 }
3737 }
3838
39- pub fn cached ( & self ) -> Option < Canonical > {
40- match & * self . state . read ( ) {
41- SharedState :: Cached ( canonical) => Some ( canonical. clone ( ) ) ,
42- SharedState :: Source ( _) => None ,
43- }
44- }
45-
46- pub fn cache_or_return ( & self , canonical : Canonical ) -> Canonical {
47- let mut state = self . state . write ( ) ;
39+ pub fn get_or_compute (
40+ & self ,
41+ f : impl FnOnce ( & ArrayRef ) -> VortexResult < Canonical > ,
42+ ) -> VortexResult < Canonical > {
43+ let mut state = self . state . lock_blocking ( ) ;
4844 match & * state {
49- SharedState :: Cached ( existing) => existing. clone ( ) ,
50- SharedState :: Source ( _) => {
45+ SharedState :: Cached ( canonical) => Ok ( canonical. clone ( ) ) ,
46+ SharedState :: Source ( source) => {
47+ let canonical = f ( source) ?;
5148 * state = SharedState :: Cached ( canonical. clone ( ) ) ;
52- canonical
49+ Ok ( canonical)
5350 }
5451 }
5552 }
5653
57- pub fn as_source ( & self ) -> ArrayRef {
58- let SharedState :: Source ( source ) = & * self . state . read ( ) else {
59- vortex_panic ! ( "already cached" ) ;
60- } ;
61- source . clone ( )
62- }
63-
64- pub ( super ) fn canonicalize ( & self , ctx : & mut ExecutionCtx ) -> VortexResult < Canonical > {
65- let source = {
66- let state = self . state . read ( ) ;
67- match & * state {
68- SharedState :: Cached ( existing ) => return Ok ( existing . clone ( ) ) ,
69- SharedState :: Source ( source ) => source . clone ( ) ,
54+ pub async fn get_or_compute_async < F , Fut > ( & self , f : F ) -> VortexResult < Canonical >
55+ where
56+ F : FnOnce ( ArrayRef ) -> Fut ,
57+ Fut : Future < Output = VortexResult < Canonical > > ,
58+ {
59+ let mut state = self . state . lock ( ) . await ;
60+ match & * state {
61+ SharedState :: Cached ( canonical ) => Ok ( canonical . clone ( ) ) ,
62+ SharedState :: Source ( source) => {
63+ let source = source . clone ( ) ;
64+ let canonical = f ( source ) . await ? ;
65+ * state = SharedState :: Cached ( canonical . clone ( ) ) ;
66+ Ok ( canonical )
7067 }
71- } ;
72- let canonical = source. execute :: < Canonical > ( ctx) ?;
73- Ok ( self . cache_or_return ( canonical) )
68+ }
7469 }
7570
7671 pub ( super ) fn current_array_ref ( & self ) -> ArrayRef {
77- match & * self . state . read ( ) {
72+ let state = self . state . lock_blocking ( ) ;
73+ match & * state {
7874 SharedState :: Source ( source) => source. clone ( ) ,
7975 SharedState :: Cached ( canonical) => canonical. clone ( ) . into_array ( ) ,
8076 }
8177 }
8278
8379 pub ( super ) fn set_source ( & mut self , source : ArrayRef ) {
8480 self . dtype = source. dtype ( ) . clone ( ) ;
85- * self . state . write ( ) = SharedState :: Source ( source) ;
81+ * self . state . lock_blocking ( ) = SharedState :: Source ( source) ;
8682 }
8783}
0 commit comments