3434 get_dialect ,
3535 validate_string ,
3636 positive_int_validator ,
37+ validate_expression ,
3738)
3839
3940
@@ -467,15 +468,20 @@ def _when_matched_validator(
467468 return v
468469 if isinstance (v , list ):
469470 v = " " .join (v )
471+
472+ dialect = get_dialect (info .data )
473+
470474 if isinstance (v , str ):
471475 # Whens wrap the WHEN clauses, but the parentheses aren't parsed by sqlglot
472476 v = v .strip ()
473477 if v .startswith ("(" ):
474478 v = v [1 :- 1 ]
475479
476- return t .cast (exp .Whens , d .parse_one (v , into = exp .Whens , dialect = get_dialect (info .data )))
480+ v = t .cast (exp .Whens , d .parse_one (v , into = exp .Whens , dialect = dialect ))
481+ else :
482+ v = t .cast (exp .Whens , v .transform (d .replace_merge_table_aliases , dialect = dialect ))
477483
478- return t . cast ( exp . Whens , v . transform ( d . replace_merge_table_aliases ) )
484+ return validate_expression ( v , dialect = dialect )
479485
480486 @field_validator ("merge_filter" , mode = "before" )
481487 def _merge_filter_validator (
@@ -485,11 +491,16 @@ def _merge_filter_validator(
485491 ) -> t .Optional [exp .Expression ]:
486492 if v is None :
487493 return v
494+
495+ dialect = get_dialect (info .data )
496+
488497 if isinstance (v , str ):
489498 v = v .strip ()
490- return d .parse_one (v , dialect = get_dialect (info .data ))
499+ v = d .parse_one (v , dialect = dialect )
500+ else :
501+ v = v .transform (d .replace_merge_table_aliases , dialect = dialect )
491502
492- return v . transform ( d . replace_merge_table_aliases )
503+ return validate_expression ( v , dialect = dialect )
493504
494505 @property
495506 def data_hash_values (self ) -> t .List [t .Optional [str ]]:
0 commit comments