|
| 1 | +/// Compiled Lisp expression fast paths for the traversal hot loops. |
| 2 | +/// |
| 3 | +/// Evaluation of Lisp predicates and weight expressions through `Interpreter::new()` |
| 4 | +/// is expensive: it allocates a fresh interpreter and clones the `SerdeExpr` tree on |
| 5 | +/// every call. For the overwhelmingly common single-expression forms we skip the |
| 6 | +/// interpreter entirely and evaluate inline against `SharedValue` map data. |
| 7 | +/// |
| 8 | +/// Two compiled types are provided: |
| 9 | +/// |
| 10 | +/// * [`CompiledPredicate`] — a boolean `(op field const)` filter used by both |
| 11 | +/// vertex and edge filter paths. |
| 12 | +/// * [`CompiledWeightExpr`] — a scalar weight computation covering the most common |
| 13 | +/// weight expression patterns without interpreter overhead. |
| 14 | +/// |
| 15 | +/// Both fall back to the full Lisp interpreter for anything that cannot be compiled |
| 16 | +/// to a fast path, preserving correctness for complex expressions. |
| 17 | +use bifrost_hasher::hash_str; |
| 18 | +use dovahkiin::expr::serde::Expr as SerdeExpr; |
| 19 | +use dovahkiin::types::{Map, OwnedMap, OwnedValue}; |
| 20 | +use neb::dovahkiin::expr::interpreter::Interpreter; |
| 21 | +use neb::dovahkiin::expr::symbols::utils::is_true; |
| 22 | +use neb::dovahkiin::expr::{SExpr, Value}; |
| 23 | +use neb::ram::types::SharedValue; |
| 24 | +use std::sync::LazyLock; |
| 25 | + |
| 26 | +// ─── Shared primitives ──────────────────────────────────────────────────────── |
| 27 | + |
| 28 | +/// Hash of the synthetic `"weight"` field injected for body-less edges. |
| 29 | +pub static WEIGHT_FIELD_ID: LazyLock<u64> = LazyLock::new(|| hash_str("weight")); |
| 30 | + |
| 31 | +/// Comparison operator for `CompiledPredicate`. |
| 32 | +#[derive(Clone, Copy, Debug)] |
| 33 | +pub enum CmpOp { |
| 34 | + Eq, |
| 35 | + Ne, |
| 36 | + Gt, |
| 37 | + Gte, |
| 38 | + Lt, |
| 39 | + Lte, |
| 40 | +} |
| 41 | + |
| 42 | +impl CmpOp { |
| 43 | + pub fn from_str(s: &str) -> Option<Self> { |
| 44 | + match s { |
| 45 | + "=" => Some(CmpOp::Eq), |
| 46 | + "!=" => Some(CmpOp::Ne), |
| 47 | + ">" => Some(CmpOp::Gt), |
| 48 | + ">=" => Some(CmpOp::Gte), |
| 49 | + "<" => Some(CmpOp::Lt), |
| 50 | + "<=" => Some(CmpOp::Lte), |
| 51 | + _ => None, |
| 52 | + } |
| 53 | + } |
| 54 | + |
| 55 | + /// Compare two `SharedValue`s inline. Zero allocation. |
| 56 | + pub fn compare(self, lhs: &SharedValue<'_>, rhs: &SharedValue<'_>) -> bool { |
| 57 | + match self { |
| 58 | + CmpOp::Eq => lhs == rhs, |
| 59 | + CmpOp::Ne => lhs != rhs, |
| 60 | + CmpOp::Gt => sv_cmp(lhs, rhs).map_or(false, |o| o == std::cmp::Ordering::Greater), |
| 61 | + CmpOp::Gte => sv_cmp(lhs, rhs).map_or(false, |o| o != std::cmp::Ordering::Less), |
| 62 | + CmpOp::Lt => sv_cmp(lhs, rhs).map_or(false, |o| o == std::cmp::Ordering::Less), |
| 63 | + CmpOp::Lte => sv_cmp(lhs, rhs).map_or(false, |o| o != std::cmp::Ordering::Greater), |
| 64 | + } |
| 65 | + } |
| 66 | + |
| 67 | + /// Compare an f64 (body-less edge weight) against a constant `SharedValue` RHS. |
| 68 | + pub fn compare_f64(self, lhs: f64, rhs: &SharedValue<'_>) -> bool { |
| 69 | + let rhs_f = match rhs { |
| 70 | + SharedValue::F64(v) => **v, |
| 71 | + SharedValue::F32(v) => **v as f64, |
| 72 | + SharedValue::I64(v) => **v as f64, |
| 73 | + SharedValue::I32(v) => **v as f64, |
| 74 | + SharedValue::U64(v) => **v as f64, |
| 75 | + SharedValue::U32(v) => **v as f64, |
| 76 | + _ => return false, |
| 77 | + }; |
| 78 | + match self { |
| 79 | + CmpOp::Eq => lhs == rhs_f, |
| 80 | + CmpOp::Ne => lhs != rhs_f, |
| 81 | + CmpOp::Gt => lhs > rhs_f, |
| 82 | + CmpOp::Gte => lhs >= rhs_f, |
| 83 | + CmpOp::Lt => lhs < rhs_f, |
| 84 | + CmpOp::Lte => lhs <= rhs_f, |
| 85 | + } |
| 86 | + } |
| 87 | +} |
| 88 | + |
| 89 | +/// Numeric ordering for `SharedValue` pairs of the same type. |
| 90 | +/// Returns `None` for non-numeric or type-mismatched pairs. |
| 91 | +pub fn sv_cmp(lhs: &SharedValue<'_>, rhs: &SharedValue<'_>) -> Option<std::cmp::Ordering> { |
| 92 | + macro_rules! cmp_variant { |
| 93 | + ($($v:ident),*) => { |
| 94 | + $( |
| 95 | + if let (SharedValue::$v(a), SharedValue::$v(b)) = (lhs, rhs) { |
| 96 | + return (**a).partial_cmp(*b); |
| 97 | + } |
| 98 | + )* |
| 99 | + }; |
| 100 | + } |
| 101 | + cmp_variant!(I8, I16, I32, I64, U8, U16, U32, U64, F32, F64); |
| 102 | + None |
| 103 | +} |
| 104 | + |
| 105 | +/// Extract a named field from a `SharedValue` map and convert to `f64`. |
| 106 | +pub fn field_as_f64(sv: Option<&SharedValue<'_>>, field_id: u64) -> Option<f64> { |
| 107 | + if let Some(SharedValue::Map(m)) = sv { |
| 108 | + return sv_to_f64(m.get_by_key_id(field_id)); |
| 109 | + } |
| 110 | + None |
| 111 | +} |
| 112 | + |
| 113 | +/// Convert a `SharedValue` scalar to `f64`. Returns `None` for non-numeric types. |
| 114 | +pub fn sv_to_f64(sv: &SharedValue<'_>) -> Option<f64> { |
| 115 | + match sv { |
| 116 | + SharedValue::F64(v) => Some(**v), |
| 117 | + SharedValue::F32(v) => Some(**v as f64), |
| 118 | + SharedValue::I64(v) => Some(**v as f64), |
| 119 | + SharedValue::I32(v) => Some(**v as f64), |
| 120 | + SharedValue::U64(v) => Some(**v as f64), |
| 121 | + SharedValue::U32(v) => Some(**v as f64), |
| 122 | + SharedValue::U16(v) => Some(**v as f64), |
| 123 | + SharedValue::U8(v) => Some(**v as f64), |
| 124 | + _ => None, |
| 125 | + } |
| 126 | +} |
| 127 | + |
| 128 | +// ─── Interpreter fallback ───────────────────────────────────────────────────── |
| 129 | + |
| 130 | +/// Core interpreter evaluation: binds `shared` as global_val and evaluates `filter`. |
| 131 | +/// |
| 132 | +/// SAFETY: `interp` and `shared` both live in this stack frame. |
| 133 | +pub fn eval_with_shared_val(filter: &[SerdeExpr], shared: &SharedValue<'_>) -> bool { |
| 134 | + let exprs: Vec<_> = filter.iter().cloned().map(|e| e.to_sexpr()).collect(); |
| 135 | + let mut interp = Interpreter::new(); |
| 136 | + unsafe { interp.unsafe_set_global_val(shared) }; |
| 137 | + matches!(interp.eval(exprs), Ok(result) if is_true(&result)) |
| 138 | +} |
| 139 | + |
| 140 | +/// Interpreter fallback for weight expressions that were not compiled to a fast path. |
| 141 | +pub fn eval_weight_expr_interp( |
| 142 | + exprs: &[SerdeExpr], |
| 143 | + schema_weight: f64, |
| 144 | + edge_sv: Option<&SharedValue<'_>>, |
| 145 | +) -> f64 { |
| 146 | + let sexpr_list: Vec<SExpr> = exprs.iter().cloned().map(|e| e.to_sexpr()).collect(); |
| 147 | + let mut interp = Interpreter::new(); |
| 148 | + interp.bind( |
| 149 | + "schema_weight", |
| 150 | + SExpr::Value(Value::Owned(OwnedValue::F64(schema_weight))), |
| 151 | + ); |
| 152 | + if let Some(sv) = edge_sv { |
| 153 | + unsafe { interp.unsafe_set_global_val(sv) }; |
| 154 | + } |
| 155 | + match interp.eval(sexpr_list) { |
| 156 | + Ok(result) => sv_to_f64( |
| 157 | + &result |
| 158 | + .shared_val() |
| 159 | + .unwrap_or(SharedValue::Null), |
| 160 | + ) |
| 161 | + .unwrap_or(schema_weight), |
| 162 | + Err(_) => schema_weight, |
| 163 | + } |
| 164 | +} |
| 165 | + |
| 166 | +// ─── CompiledPredicate ──────────────────────────────────────────────────────── |
| 167 | + |
| 168 | +/// A pre-compiled boolean predicate for vertex and edge filters. |
| 169 | +/// |
| 170 | +/// Handles the common `(op field_name constant)` pattern with zero heap allocation |
| 171 | +/// per evaluation. Complex expressions fall back to the Lisp interpreter. |
| 172 | +#[derive(Clone)] |
| 173 | +pub enum CompiledPredicate { |
| 174 | + /// Fast path: `(op field_name constant)`. Zero allocation per eval. |
| 175 | + FieldCmp { |
| 176 | + field_id: u64, |
| 177 | + op: CmpOp, |
| 178 | + /// Constant RHS value, owned once at compile time. |
| 179 | + rhs: OwnedValue, |
| 180 | + }, |
| 181 | + /// Fallback: full Lisp interpreter. Allocates per evaluation. |
| 182 | + Expr(Vec<SerdeExpr>), |
| 183 | +} |
| 184 | + |
| 185 | +impl CompiledPredicate { |
| 186 | + /// Compile `exprs` into the fast path if possible, otherwise fall back to `Expr`. |
| 187 | + pub fn compile(exprs: &[SerdeExpr]) -> Self { |
| 188 | + if let [SerdeExpr::List(inner)] = exprs { |
| 189 | + if let [SerdeExpr::Symbol(_, op_str), SerdeExpr::Symbol(field_id, _), SerdeExpr::Value(rhs)] = |
| 190 | + inner.as_slice() |
| 191 | + { |
| 192 | + if let Some(op) = CmpOp::from_str(op_str) { |
| 193 | + return Self::FieldCmp { |
| 194 | + field_id: *field_id, |
| 195 | + op, |
| 196 | + rhs: rhs.clone(), |
| 197 | + }; |
| 198 | + } |
| 199 | + } |
| 200 | + } |
| 201 | + Self::Expr(exprs.to_vec()) |
| 202 | + } |
| 203 | + |
| 204 | + /// Evaluate the predicate against a `SharedValue` (vertex or edge body data). |
| 205 | + /// Zero allocation for `FieldCmp`. |
| 206 | + pub fn eval(&self, sv: &SharedValue<'_>) -> bool { |
| 207 | + match self { |
| 208 | + Self::FieldCmp { field_id, op, rhs } => { |
| 209 | + if let SharedValue::Map(m) = sv { |
| 210 | + op.compare(m.get_by_key_id(*field_id), &rhs.shared()) |
| 211 | + } else { |
| 212 | + false |
| 213 | + } |
| 214 | + } |
| 215 | + Self::Expr(filter) => eval_with_shared_val(filter, sv), |
| 216 | + } |
| 217 | + } |
| 218 | + |
| 219 | + /// Evaluate against a body-less edge where only the synthetic `"weight"` field |
| 220 | + /// is available (passed as `weight: f64`). |
| 221 | + pub fn eval_bodyless_edge(&self, weight: f64) -> bool { |
| 222 | + match self { |
| 223 | + Self::FieldCmp { field_id, op, rhs } => { |
| 224 | + if *field_id == *WEIGHT_FIELD_ID { |
| 225 | + op.compare_f64(weight, &rhs.shared()) |
| 226 | + } else { |
| 227 | + false // fail-closed: field not available on body-less edge |
| 228 | + } |
| 229 | + } |
| 230 | + Self::Expr(filter) => { |
| 231 | + let mut map = OwnedMap::new(); |
| 232 | + map.insert("weight", OwnedValue::F64(weight)); |
| 233 | + let owned = OwnedValue::Map(map); |
| 234 | + eval_with_shared_val(filter, &owned.shared()) |
| 235 | + } |
| 236 | + } |
| 237 | + } |
| 238 | +} |
| 239 | + |
| 240 | +// ─── CompiledWeightExpr ─────────────────────────────────────────────────────── |
| 241 | + |
| 242 | +/// A pre-compiled weight expression for `WeightSpec::Expression`. |
| 243 | +/// |
| 244 | +/// Covers the most common patterns without interpreter overhead: |
| 245 | +/// |
| 246 | +/// | Expression form | Compiled variant | |
| 247 | +/// |---------------------------------------|-------------------| |
| 248 | +/// | `field_name` (symbol) | `FieldOnly` | |
| 249 | +/// | `(* schema_weight field)` or reversed | `ScaleField` | |
| 250 | +/// | `(* constant field)` or reversed | `ScaleField` | |
| 251 | +/// | `3.14` (literal float) | `Constant` | |
| 252 | +/// | anything else | `Expr` (fallback) | |
| 253 | +#[derive(Clone)] |
| 254 | +pub enum CompiledWeightExpr { |
| 255 | + /// Read a named field as f64, falling back to `schema_weight` if absent. |
| 256 | + FieldOnly { field_id: u64 }, |
| 257 | + /// Multiply a named field by a constant factor. |
| 258 | + ScaleField { field_id: u64, factor: f64 }, |
| 259 | + /// Always return this constant weight. |
| 260 | + Constant(f64), |
| 261 | + /// Full interpreter fallback. |
| 262 | + Expr(Vec<SerdeExpr>), |
| 263 | +} |
| 264 | + |
| 265 | +impl CompiledWeightExpr { |
| 266 | + /// Compile `exprs` into the fast path, with `schema_weight` inlined for |
| 267 | + /// expressions that reference it symbolically. |
| 268 | + pub fn compile(exprs: &[SerdeExpr], schema_weight: f64) -> Self { |
| 269 | + match exprs { |
| 270 | + // `field_name` — bare symbol used as weight |
| 271 | + [SerdeExpr::Symbol(field_id, _)] => Self::FieldOnly { field_id: *field_id }, |
| 272 | + |
| 273 | + // `3.14` — literal constant |
| 274 | + [SerdeExpr::Value(OwnedValue::F64(c))] => Self::Constant(*c), |
| 275 | + [SerdeExpr::Value(OwnedValue::F32(c))] => Self::Constant(*c as f64), |
| 276 | + |
| 277 | + // `(* a b)` — multiplication; supports schema_weight symbol or literal factor |
| 278 | + [SerdeExpr::List(inner)] => { |
| 279 | + if let [SerdeExpr::Symbol(_, op_str), a, b] = inner.as_slice() { |
| 280 | + if op_str == "*" { |
| 281 | + if let Some((field_id, factor)) = extract_scale(a, b, schema_weight) |
| 282 | + .or_else(|| extract_scale(b, a, schema_weight)) |
| 283 | + { |
| 284 | + return Self::ScaleField { field_id, factor }; |
| 285 | + } |
| 286 | + } |
| 287 | + } |
| 288 | + Self::Expr(exprs.to_vec()) |
| 289 | + } |
| 290 | + |
| 291 | + _ => Self::Expr(exprs.to_vec()), |
| 292 | + } |
| 293 | + } |
| 294 | + |
| 295 | + /// Evaluate the weight expression. Zero allocation for all non-`Expr` variants. |
| 296 | + pub fn eval(&self, schema_weight: f64, edge_sv: Option<&SharedValue<'_>>) -> f64 { |
| 297 | + match self { |
| 298 | + Self::FieldOnly { field_id } => { |
| 299 | + field_as_f64(edge_sv, *field_id).unwrap_or(schema_weight) |
| 300 | + } |
| 301 | + Self::ScaleField { field_id, factor } => { |
| 302 | + field_as_f64(edge_sv, *field_id) |
| 303 | + .map(|v| v * factor) |
| 304 | + .unwrap_or(schema_weight) |
| 305 | + } |
| 306 | + Self::Constant(c) => *c, |
| 307 | + Self::Expr(exprs) => eval_weight_expr_interp(exprs, schema_weight, edge_sv), |
| 308 | + } |
| 309 | + } |
| 310 | +} |
| 311 | + |
| 312 | +/// Try to extract `(field_symbol, factor)` from `(factor_expr, field_expr)`. |
| 313 | +/// `factor_expr` may be the `schema_weight` symbol or a numeric literal. |
| 314 | +fn extract_scale( |
| 315 | + factor_expr: &SerdeExpr, |
| 316 | + field_expr: &SerdeExpr, |
| 317 | + schema_weight: f64, |
| 318 | +) -> Option<(u64, f64)> { |
| 319 | + let field_id = match field_expr { |
| 320 | + SerdeExpr::Symbol(id, _) => *id, |
| 321 | + _ => return None, |
| 322 | + }; |
| 323 | + let factor = match factor_expr { |
| 324 | + SerdeExpr::Symbol(_, name) if name == "schema_weight" => schema_weight, |
| 325 | + SerdeExpr::Value(OwnedValue::F64(v)) => *v, |
| 326 | + SerdeExpr::Value(OwnedValue::F32(v)) => *v as f64, |
| 327 | + SerdeExpr::Value(OwnedValue::I64(v)) => *v as f64, |
| 328 | + SerdeExpr::Value(OwnedValue::I32(v)) => *v as f64, |
| 329 | + _ => return None, |
| 330 | + }; |
| 331 | + Some((field_id, factor)) |
| 332 | +} |
0 commit comments