Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,15 +627,23 @@ def _render(e: exp.Expression) -> str | int | float | bool:
{k: _render(v) for k, v in signal.items()} for name, signal in self.signals if not name
]

def render_signal_calls(self) -> t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]]:
return {
def render_signal_calls(self) -> EvaluatableSignals:
python_env = self.python_env
env = prepare_env(python_env)
signals_to_kwargs = {
name: {
k: seq_get(self._create_renderer(v).render() or [], 0) for k, v in kwargs.items()
}
for name, kwargs in self.signals
if name
}

return EvaluatableSignals(
signals_to_kwargs=signals_to_kwargs,
python_env=python_env,
prepared_python_env=env,
)

def render_merge_filter(
self,
*,
Expand Down Expand Up @@ -1857,6 +1865,15 @@ class AuditResult(PydanticModel):
blocking: bool = True


class EvaluatableSignals(PydanticModel):
signals_to_kwargs: t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]]
"""A mapping of signal names to the kwargs passed to the signal."""
python_env: t.Dict[str, Executable]
"""The Python environment that should be used to evaluated the rendered signal calls."""
prepared_python_env: t.Dict[str, t.Any]
"""The prepared Python environment that should be used to evaluated the rendered signal calls."""


def _extract_blueprints(blueprints: t.Any, path: Path) -> t.List[t.Any]:
if not blueprints:
return [None]
Expand Down
90 changes: 85 additions & 5 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
import typing as t
import time
from sqlglot import exp
from sqlmesh.core import constants as c
from sqlmesh.core.console import Console, get_console
Expand All @@ -24,6 +25,7 @@
snapshots_to_dag,
Intervals,
)
from sqlmesh.core.snapshot.definition import check_ready_intervals
from sqlmesh.core.snapshot.definition import (
Interval,
expand_range,
Expand All @@ -39,7 +41,16 @@
to_timestamp,
validate_date_range,
)
from sqlmesh.utils.errors import AuditError, NodeAuditsErrors, CircuitBreakerError, SQLMeshError
from sqlmesh.utils.errors import (
AuditError,
NodeAuditsErrors,
CircuitBreakerError,
SQLMeshError,
SignalEvalError,
)

if t.TYPE_CHECKING:
from sqlmesh.core.context import ExecutionContext

logger = logging.getLogger(__name__)
SnapshotToIntervals = t.Dict[Snapshot, Intervals]
Expand Down Expand Up @@ -304,12 +315,11 @@ def batch_intervals(
default_catalog=self.default_catalog,
)

intervals = snapshot.check_ready_intervals(
intervals = self._check_ready_intervals(
snapshot,
intervals,
context,
console=self.console,
default_catalog=self.default_catalog,
environment_naming_info=environment_naming_info,
environment_naming_info,
)
unready -= set(intervals)

Expand Down Expand Up @@ -709,6 +719,76 @@ def _audit_snapshot(

return audit_results

def _check_ready_intervals(
self,
snapshot: Snapshot,
intervals: Intervals,
context: ExecutionContext,
environment_naming_info: EnvironmentNamingInfo,
) -> Intervals:
"""Checks if the intervals are ready for evaluation for the given snapshot.

This implementation also includes the signal progress tracking.
Note that this will handle gaps in the provided intervals. The returned intervals
may introduce new gaps.

Args:
snapshot: The snapshot to check.
intervals: The intervals to check.
context: The context to use.
environment_naming_info: The environment naming info to use.

Returns:
The intervals that are ready for evaluation.
"""
signals = snapshot.is_model and snapshot.model.render_signal_calls()

if not signals:
return intervals

self.console.start_signal_progress(
snapshot,
self.default_catalog,
environment_naming_info or EnvironmentNamingInfo(),
)

for signal_idx, (signal_name, kwargs) in enumerate(signals.signals_to_kwargs.items()):
# Capture intervals before signal check for display
intervals_to_check = merge_intervals(intervals)

signal_start_ts = time.perf_counter()

try:
intervals = check_ready_intervals(
signals.prepared_python_env[signal_name],
intervals,
context,
python_env=signals.python_env,
dialect=snapshot.model.dialect,
path=snapshot.model._path,
kwargs=kwargs,
)
except SQLMeshError as e:
raise SignalEvalError(
f"{e} '{signal_name}' for '{snapshot.model.name}' at {snapshot.model._path}"
)

duration = time.perf_counter() - signal_start_ts

self.console.update_signal_progress(
snapshot=snapshot,
signal_name=signal_name,
signal_idx=signal_idx,
total_signals=len(signals.signals_to_kwargs),
ready_intervals=merge_intervals(intervals),
check_intervals=intervals_to_check,
duration=duration,
)

self.console.stop_signal_progress()

return intervals


def merged_missing_intervals(
snapshots: t.Collection[Snapshot],
Expand Down
54 changes: 6 additions & 48 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import sys
import time
import typing as t
from collections import defaultdict
from datetime import datetime, timedelta
Expand Down Expand Up @@ -42,16 +41,13 @@
)
from sqlmesh.utils.errors import SQLMeshError, SignalEvalError
from sqlmesh.utils.metaprogramming import (
prepare_env,
print_exception,
format_evaluated_code_exception,
Executable,
)
from sqlmesh.utils.hashing import hash_data
from sqlmesh.utils.pydantic import PydanticModel, field_validator

if t.TYPE_CHECKING:
from sqlmesh.core.console import Console
from sqlglot.dialects.dialect import DialectType
from sqlmesh.core.environment import EnvironmentNamingInfo
from sqlmesh.core.context import ExecutionContext
Expand Down Expand Up @@ -971,69 +967,31 @@ def check_ready_intervals(
self,
intervals: Intervals,
context: ExecutionContext,
console: t.Optional[Console] = None,
default_catalog: t.Optional[str] = None,
environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
) -> Intervals:
"""Returns a list of intervals that are considered ready by the provided signal.

Note that this will handle gaps in the provided intervals. The returned intervals
may introduce new gaps.
"""
signals = self.is_model and self.model.render_signal_calls()

if not signals:
return intervals

python_env = self.model.python_env
env = prepare_env(python_env)

if console:
console.start_signal_progress(
self,
default_catalog,
environment_naming_info or EnvironmentNamingInfo(),
)

for signal_idx, (signal_name, kwargs) in enumerate(signals.items()):
# Capture intervals before signal check for display
intervals_to_check = merge_intervals(intervals)

signal_start_ts = time.perf_counter()

for signal_name, kwargs in signals.signals_to_kwargs.items():
try:
intervals = _check_ready_intervals(
env[signal_name],
intervals = check_ready_intervals(
signals.prepared_python_env[signal_name],
intervals,
context,
python_env=python_env,
python_env=signals.python_env,
dialect=self.model.dialect,
path=self.model._path,
kwargs=kwargs,
)
except SQLMeshError as e:
print_exception(e, python_env)
raise SQLMeshError(
raise SignalEvalError(
f"{e} '{signal_name}' for '{self.model.name}' at {self.model._path}"
)

duration = time.perf_counter() - signal_start_ts

if console:
console.update_signal_progress(
snapshot=self,
signal_name=signal_name,
signal_idx=signal_idx,
total_signals=len(signals),
ready_intervals=merge_intervals(intervals),
check_intervals=intervals_to_check,
duration=duration,
)

# Stop signal progress tracking
if console:
console.stop_signal_progress()

return intervals

def categorize_as(self, category: SnapshotChangeCategory) -> None:
Expand Down Expand Up @@ -2229,7 +2187,7 @@ def _contiguous_intervals(intervals: Intervals) -> t.List[Intervals]:
return contiguous_intervals


def _check_ready_intervals(
def check_ready_intervals(
check: t.Callable,
intervals: Intervals,
context: ExecutionContext,
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2118,7 +2118,7 @@ def test_check_intervals(sushi_context, mocker):
):
sushi_context.check_intervals(environment="dev", no_signals=False, select_models=[])

spy = mocker.spy(sqlmesh.core.snapshot.definition, "_check_ready_intervals")
spy = mocker.spy(sqlmesh.core.snapshot.definition, "check_ready_intervals")
intervals = sushi_context.check_intervals(environment=None, no_signals=False, select_models=[])

min_intervals = 19
Expand Down
14 changes: 6 additions & 8 deletions tests/core/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
apply_auto_restatements,
display_name,
get_next_model_interval_start,
_check_ready_intervals,
check_ready_intervals,
_contiguous_intervals,
)
from sqlmesh.utils import AttributeDict
Expand Down Expand Up @@ -2540,7 +2540,7 @@ def test_contiguous_intervals():
def test_check_ready_intervals(mocker: MockerFixture):
def assert_always_signal(intervals):
assert (
_check_ready_intervals(lambda _: True, intervals, mocker.Mock(), mocker.Mock())
check_ready_intervals(lambda _: True, intervals, mocker.Mock(), mocker.Mock())
== intervals
)

Expand All @@ -2550,17 +2550,15 @@ def assert_always_signal(intervals):
assert_always_signal([(0, 1), (2, 3)])

def assert_never_signal(intervals):
assert (
_check_ready_intervals(lambda _: False, intervals, mocker.Mock(), mocker.Mock()) == []
)
assert check_ready_intervals(lambda _: False, intervals, mocker.Mock(), mocker.Mock()) == []

assert_never_signal([])
assert_never_signal([(0, 1)])
assert_never_signal([(0, 1), (1, 2)])
assert_never_signal([(0, 1), (2, 3)])

def assert_empty_signal(intervals):
assert _check_ready_intervals(lambda _: [], intervals, mocker.Mock(), mocker.Mock()) == []
assert check_ready_intervals(lambda _: [], intervals, mocker.Mock(), mocker.Mock()) == []

assert_empty_signal([])
assert_empty_signal([(0, 1)])
Expand All @@ -2577,7 +2575,7 @@ def assert_check_intervals(
):
mock = mocker.Mock()
mock.side_effect = [to_intervals(r) for r in ready]
_check_ready_intervals(mock, intervals, mocker.Mock(), mocker.Mock()) == expected
check_ready_intervals(mock, intervals, mocker.Mock(), mocker.Mock()) == expected

assert_check_intervals([], [], [])
assert_check_intervals([(0, 1)], [[]], [])
Expand Down Expand Up @@ -2618,7 +2616,7 @@ def assert_check_intervals(
)

with pytest.raises(SignalEvalError):
_check_ready_intervals(
check_ready_intervals(
lambda _: (_ for _ in ()).throw(MemoryError("Some exception")),
[(0, 1), (1, 2)],
mocker.Mock(),
Expand Down