diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 6dce3e6843..0808722119 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -627,8 +627,10 @@ def _render(e: exp.Expression) -> str | int | float | bool: {k: _render(v) for k, v in signal.items()} for name, signal in self.signals if not name ] - def render_signal_calls(self) -> t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]]: - return { + def render_signal_calls(self) -> EvaluatableSignals: + python_env = self.python_env + env = prepare_env(python_env) + signals_to_kwargs = { name: { k: seq_get(self._create_renderer(v).render() or [], 0) for k, v in kwargs.items() } @@ -636,6 +638,12 @@ def render_signal_calls(self) -> t.Dict[str, t.Dict[str, t.Optional[exp.Expressi if name } + return EvaluatableSignals( + signals_to_kwargs=signals_to_kwargs, + python_env=python_env, + prepared_python_env=env, + ) + def render_merge_filter( self, *, @@ -1857,6 +1865,15 @@ class AuditResult(PydanticModel): blocking: bool = True +class EvaluatableSignals(PydanticModel): + signals_to_kwargs: t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]] + """A mapping of signal names to the kwargs passed to the signal.""" + python_env: t.Dict[str, Executable] + """The Python environment that should be used to evaluated the rendered signal calls.""" + prepared_python_env: t.Dict[str, t.Any] + """The prepared Python environment that should be used to evaluated the rendered signal calls.""" + + def _extract_blueprints(blueprints: t.Any, path: Path) -> t.List[t.Any]: if not blueprints: return [None] diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 8bac0bf081..2c7a2a66ac 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging import typing as t +import time from sqlglot import exp from sqlmesh.core import constants as c from sqlmesh.core.console import Console, get_console @@ -24,6 +25,7 @@ snapshots_to_dag, Intervals, ) +from sqlmesh.core.snapshot.definition import check_ready_intervals from sqlmesh.core.snapshot.definition import ( Interval, expand_range, @@ -39,7 +41,16 @@ to_timestamp, validate_date_range, ) -from sqlmesh.utils.errors import AuditError, NodeAuditsErrors, CircuitBreakerError, SQLMeshError +from sqlmesh.utils.errors import ( + AuditError, + NodeAuditsErrors, + CircuitBreakerError, + SQLMeshError, + SignalEvalError, +) + +if t.TYPE_CHECKING: + from sqlmesh.core.context import ExecutionContext logger = logging.getLogger(__name__) SnapshotToIntervals = t.Dict[Snapshot, Intervals] @@ -304,12 +315,11 @@ def batch_intervals( default_catalog=self.default_catalog, ) - intervals = snapshot.check_ready_intervals( + intervals = self._check_ready_intervals( + snapshot, intervals, context, - console=self.console, - default_catalog=self.default_catalog, - environment_naming_info=environment_naming_info, + environment_naming_info, ) unready -= set(intervals) @@ -709,6 +719,76 @@ def _audit_snapshot( return audit_results + def _check_ready_intervals( + self, + snapshot: Snapshot, + intervals: Intervals, + context: ExecutionContext, + environment_naming_info: EnvironmentNamingInfo, + ) -> Intervals: + """Checks if the intervals are ready for evaluation for the given snapshot. + + This implementation also includes the signal progress tracking. + Note that this will handle gaps in the provided intervals. The returned intervals + may introduce new gaps. + + Args: + snapshot: The snapshot to check. + intervals: The intervals to check. + context: The context to use. + environment_naming_info: The environment naming info to use. + + Returns: + The intervals that are ready for evaluation. + """ + signals = snapshot.is_model and snapshot.model.render_signal_calls() + + if not signals: + return intervals + + self.console.start_signal_progress( + snapshot, + self.default_catalog, + environment_naming_info or EnvironmentNamingInfo(), + ) + + for signal_idx, (signal_name, kwargs) in enumerate(signals.signals_to_kwargs.items()): + # Capture intervals before signal check for display + intervals_to_check = merge_intervals(intervals) + + signal_start_ts = time.perf_counter() + + try: + intervals = check_ready_intervals( + signals.prepared_python_env[signal_name], + intervals, + context, + python_env=signals.python_env, + dialect=snapshot.model.dialect, + path=snapshot.model._path, + kwargs=kwargs, + ) + except SQLMeshError as e: + raise SignalEvalError( + f"{e} '{signal_name}' for '{snapshot.model.name}' at {snapshot.model._path}" + ) + + duration = time.perf_counter() - signal_start_ts + + self.console.update_signal_progress( + snapshot=snapshot, + signal_name=signal_name, + signal_idx=signal_idx, + total_signals=len(signals.signals_to_kwargs), + ready_intervals=merge_intervals(intervals), + check_intervals=intervals_to_check, + duration=duration, + ) + + self.console.stop_signal_progress() + + return intervals + def merged_missing_intervals( snapshots: t.Collection[Snapshot], diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index e84a1fce27..9b5fa893fc 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -1,7 +1,6 @@ from __future__ import annotations import sys -import time import typing as t from collections import defaultdict from datetime import datetime, timedelta @@ -42,8 +41,6 @@ ) from sqlmesh.utils.errors import SQLMeshError, SignalEvalError from sqlmesh.utils.metaprogramming import ( - prepare_env, - print_exception, format_evaluated_code_exception, Executable, ) @@ -51,7 +48,6 @@ from sqlmesh.utils.pydantic import PydanticModel, field_validator if t.TYPE_CHECKING: - from sqlmesh.core.console import Console from sqlglot.dialects.dialect import DialectType from sqlmesh.core.environment import EnvironmentNamingInfo from sqlmesh.core.context import ExecutionContext @@ -971,9 +967,6 @@ def check_ready_intervals( self, intervals: Intervals, context: ExecutionContext, - console: t.Optional[Console] = None, - default_catalog: t.Optional[str] = None, - environment_naming_info: t.Optional[EnvironmentNamingInfo] = None, ) -> Intervals: """Returns a list of intervals that are considered ready by the provided signal. @@ -981,59 +974,24 @@ def check_ready_intervals( may introduce new gaps. """ signals = self.is_model and self.model.render_signal_calls() - if not signals: return intervals - python_env = self.model.python_env - env = prepare_env(python_env) - - if console: - console.start_signal_progress( - self, - default_catalog, - environment_naming_info or EnvironmentNamingInfo(), - ) - - for signal_idx, (signal_name, kwargs) in enumerate(signals.items()): - # Capture intervals before signal check for display - intervals_to_check = merge_intervals(intervals) - - signal_start_ts = time.perf_counter() - + for signal_name, kwargs in signals.signals_to_kwargs.items(): try: - intervals = _check_ready_intervals( - env[signal_name], + intervals = check_ready_intervals( + signals.prepared_python_env[signal_name], intervals, context, - python_env=python_env, + python_env=signals.python_env, dialect=self.model.dialect, path=self.model._path, kwargs=kwargs, ) except SQLMeshError as e: - print_exception(e, python_env) - raise SQLMeshError( + raise SignalEvalError( f"{e} '{signal_name}' for '{self.model.name}' at {self.model._path}" ) - - duration = time.perf_counter() - signal_start_ts - - if console: - console.update_signal_progress( - snapshot=self, - signal_name=signal_name, - signal_idx=signal_idx, - total_signals=len(signals), - ready_intervals=merge_intervals(intervals), - check_intervals=intervals_to_check, - duration=duration, - ) - - # Stop signal progress tracking - if console: - console.stop_signal_progress() - return intervals def categorize_as(self, category: SnapshotChangeCategory) -> None: @@ -2229,7 +2187,7 @@ def _contiguous_intervals(intervals: Intervals) -> t.List[Intervals]: return contiguous_intervals -def _check_ready_intervals( +def check_ready_intervals( check: t.Callable, intervals: Intervals, context: ExecutionContext, diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 8922472aa3..7024e9f73f 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -2118,7 +2118,7 @@ def test_check_intervals(sushi_context, mocker): ): sushi_context.check_intervals(environment="dev", no_signals=False, select_models=[]) - spy = mocker.spy(sqlmesh.core.snapshot.definition, "_check_ready_intervals") + spy = mocker.spy(sqlmesh.core.snapshot.definition, "check_ready_intervals") intervals = sushi_context.check_intervals(environment=None, no_signals=False, select_models=[]) min_intervals = 19 diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index f09083f500..f2f54822f5 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -59,7 +59,7 @@ apply_auto_restatements, display_name, get_next_model_interval_start, - _check_ready_intervals, + check_ready_intervals, _contiguous_intervals, ) from sqlmesh.utils import AttributeDict @@ -2540,7 +2540,7 @@ def test_contiguous_intervals(): def test_check_ready_intervals(mocker: MockerFixture): def assert_always_signal(intervals): assert ( - _check_ready_intervals(lambda _: True, intervals, mocker.Mock(), mocker.Mock()) + check_ready_intervals(lambda _: True, intervals, mocker.Mock(), mocker.Mock()) == intervals ) @@ -2550,9 +2550,7 @@ def assert_always_signal(intervals): assert_always_signal([(0, 1), (2, 3)]) def assert_never_signal(intervals): - assert ( - _check_ready_intervals(lambda _: False, intervals, mocker.Mock(), mocker.Mock()) == [] - ) + assert check_ready_intervals(lambda _: False, intervals, mocker.Mock(), mocker.Mock()) == [] assert_never_signal([]) assert_never_signal([(0, 1)]) @@ -2560,7 +2558,7 @@ def assert_never_signal(intervals): assert_never_signal([(0, 1), (2, 3)]) def assert_empty_signal(intervals): - assert _check_ready_intervals(lambda _: [], intervals, mocker.Mock(), mocker.Mock()) == [] + assert check_ready_intervals(lambda _: [], intervals, mocker.Mock(), mocker.Mock()) == [] assert_empty_signal([]) assert_empty_signal([(0, 1)]) @@ -2577,7 +2575,7 @@ def assert_check_intervals( ): mock = mocker.Mock() mock.side_effect = [to_intervals(r) for r in ready] - _check_ready_intervals(mock, intervals, mocker.Mock(), mocker.Mock()) == expected + check_ready_intervals(mock, intervals, mocker.Mock(), mocker.Mock()) == expected assert_check_intervals([], [], []) assert_check_intervals([(0, 1)], [[]], []) @@ -2618,7 +2616,7 @@ def assert_check_intervals( ) with pytest.raises(SignalEvalError): - _check_ready_intervals( + check_ready_intervals( lambda _: (_ for _ in ()).throw(MemoryError("Some exception")), [(0, 1), (1, 2)], mocker.Mock(),