diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index f6afd5eea2..87c7ade710 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -25,10 +25,7 @@ BinaryCoercionFunc = t.Callable[ [exp.Expr, exp.Expr], t.Optional[t.Union[exp.DataType, exp.DType]] ] - BinaryCoercions = t.Dict[ - t.Tuple[exp.DType, exp.DType], - BinaryCoercionFunc, - ] + BinaryCoercions = dict[tuple[exp.DType, exp.DType], BinaryCoercionFunc] from sqlglot.dialects.dialect import DialectType from sqlglot.typing import ExprMetadataType @@ -47,9 +44,9 @@ def annotate_types( expression: E, - schema: t.Optional[t.Dict | Schema] = None, + schema: dict[str, object] | Schema | None = None, expression_metadata: t.Optional[ExprMetadataType] = None, - coerces_to: t.Optional[t.Dict[exp.DType, t.Set[exp.DType]]] = None, + coerces_to: dict[exp.DType, set[exp.DType]] | None = None, dialect: DialectType = None, overwrite_types: bool = True, ) -> E: diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index e0c1d9c1d2..62ad8ed54a 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -15,7 +15,7 @@ def isolate_table_selects( expression: E, - schema: t.Optional[t.Dict | Schema] = None, + schema: dict[str, object] | Schema | None = None, dialect: DialectType = None, ) -> E: schema = ensure_schema(schema, dialect=dialect) diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index e7c18665aa..c7d3ea0c53 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -1,6 +1,5 @@ from __future__ import annotations - -import typing as t +from collections.abc import Iterable from sqlglot import exp from sqlglot.helper import tsort @@ -8,7 +7,7 @@ JOIN_ATTRS = ("on", "side", "kind", "using", "method") -def optimize_joins(expression): +def optimize_joins(expression: exp.Expr) -> exp.Expr: """ Removes cross joins if possible and reorder joins based on predicate dependencies. @@ -24,8 +23,8 @@ def optimize_joins(expression): if not _is_reorderable(joins): continue - references = {} - cross_joins = [] + references: dict[str, list[exp.Join]] = {} + cross_joins: list[tuple[str, exp.Join]] = [] for join in joins: tables = other_table_names(join) @@ -58,7 +57,7 @@ def optimize_joins(expression): return expression -def reorder_joins(expression): +def reorder_joins(expression) -> exp.Expr: """ Reorder joins by topological sort order based on predicate references. """ @@ -82,7 +81,7 @@ def reorder_joins(expression): return expression -def normalize(expression): +def normalize(expression: exp.Expr) -> exp.Expr: """ Remove INNER and OUTER from joins as they are optional. """ @@ -101,12 +100,12 @@ def normalize(expression): return expression -def other_table_names(join: exp.Join) -> t.Set[str]: +def other_table_names(join: exp.Join) -> set[str]: on = join.args.get("on") return exp.column_table_names(on, join.alias_or_name) if on else set() -def _is_reorderable(joins: t.List[exp.Join]) -> bool: +def _is_reorderable(joins: Iterable[exp.Join]) -> bool: """ Checks if joins can be reordered without changing query semantics. diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index f99f4aa412..f297ae0615 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -21,7 +21,23 @@ from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.schema import ensure_schema -RULES = ( + +class OptimizerFn(t.Protocol): + """Protocol for optimizer rules functions. + + An optimizer rule: + + - **Must** accept an `Expr` as the first argument + - Can take undefined `*args` and `**kwargs` afterwards + - **Must** return an `Expr`. + Note: + We use `typing.Protocol` here because this is not doable with `collections.abc.Callable`. + """ + + def __call__(self, expression: exp.Expr, *args: t.Any, **kwargs: t.Any) -> exp.Expr: ... + + +RULES: tuple[OptimizerFn, ...] = ( qualify, pushdown_projections, normalize, @@ -41,13 +57,13 @@ def optimize( expression: str | exp.Expr, - schema: t.Optional[dict | Schema] = None, - db: t.Optional[str | exp.Identifier] = None, - catalog: t.Optional[str | exp.Identifier] = None, + schema: dict[str, object] | Schema | None = None, + db: str | exp.Identifier | None = None, + catalog: str | exp.Identifier | None = None, dialect: DialectType = None, - rules: Sequence[t.Callable] = RULES, + rules: Sequence[OptimizerFn] = RULES, sql: t.Optional[str] = None, - **kwargs, + **kwargs: object, ) -> exp.Expr: """ Rewrite a sqlglot AST into an optimized form. diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 16f7b313c5..157f7ec395 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -26,7 +26,7 @@ def default_selection(is_agg: bool) -> exp.Alias: def pushdown_projections( expression: E, - schema: t.Optional[t.Dict | Schema] = None, + schema: dict[str, object] | Schema | None = None, remove_unused_selections: bool = True, dialect: DialectType = None, ) -> E: diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py index 8d889696a7..63dcb39284 100644 --- a/sqlglot/optimizer/qualify.py +++ b/sqlglot/optimizer/qualify.py @@ -18,12 +18,12 @@ def qualify( expression: exp.Expr, dialect: DialectType = None, - db: t.Optional[str] = None, - catalog: t.Optional[str] = None, - schema: t.Optional[dict | Schema] = None, + db: str | None = None, + catalog: str | None = None, + schema: dict[str, object] | Schema | None = None, expand_alias_refs: bool = True, expand_stars: bool = True, - infer_schema: t.Optional[bool] = None, + infer_schema: bool | None = None, isolate_tables: bool = False, qualify_columns: bool = True, allow_partial_qualification: bool = False, @@ -31,8 +31,8 @@ def qualify( quote_identifiers: bool = True, identify: bool = True, canonicalize_table_aliases: bool = False, - on_qualify: t.Optional[t.Callable[[exp.Expr], None]] = None, - sql: t.Optional[str] = None, + on_qualify: t.Callable[[exp.Expr], None] | None = None, + sql: str | None = None, ) -> exp.Expr: """ Rewrite sqlglot AST to have normalized and qualified tables and columns. diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 93075300d3..13ddd57067 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -20,7 +20,7 @@ def qualify_columns( expression: exp.Expr, - schema: dict | Schema, + schema: dict[str, object] | Schema, expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: t.Optional[bool] = None, diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 6f9db10bf0..ae6bdb8ac3 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -1,6 +1,6 @@ from __future__ import annotations -import datetime +from datetime import date, datetime, timedelta import logging import functools import itertools @@ -15,15 +15,17 @@ from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope from sqlglot.schema import ensure_schema + if t.TYPE_CHECKING: + from dateutil.relativedelta import relativedelta from sqlglot.dialects.dialect import DialectType + from typing_extensions import TypeIs - DateRange = t.Tuple[datetime.date, datetime.date] + DateRange = tuple[date, date] DateTruncBinaryTransform = t.Callable[ - [exp.Expr, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expr] + [exp.Expr, date, str, Dialect, exp.DataType], t.Optional[exp.Expr] ] - logger = logging.getLogger("sqlglot") @@ -115,7 +117,7 @@ def _func(self, expression: exp.Expr, *args, **kwargs) -> t.Optional[exp.Expr]: return _func -def flatten(expression): +def flatten(expression: exp.Expr) -> exp.Expr: """ A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C @@ -230,7 +232,7 @@ def _is_constant(expression: exp.Expr) -> bool: return isinstance(expr, exp.CONSTANTS) or _is_date_literal(expr) -def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: +def _datetrunc_range(date: date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: """ Get the date range for a DATE_TRUNC equality comparison: @@ -261,7 +263,7 @@ def _datetrunc_eq_expression( def _datetrunc_eq( left: exp.Expr, - date: datetime.date, + date: date, unit: str, dialect: Dialect, target_type: t.Optional[exp.DataType], @@ -275,7 +277,7 @@ def _datetrunc_eq( def _datetrunc_neq( left: exp.Expr, - date: datetime.date, + date: date, unit: str, dialect: Dialect, target_type: t.Optional[exp.DataType], @@ -291,33 +293,44 @@ def _datetrunc_neq( ) -def always_true(expression): +def always_true(expression: object) -> bool: return (isinstance(expression, exp.Boolean) and expression.this) or ( isinstance(expression, exp.Literal) and expression.is_number and not is_zero(expression) ) -def always_false(expression): +def always_false(expression: object) -> bool: return is_false(expression) or is_null(expression) or is_zero(expression) -def is_zero(expression): +def is_zero(expression: object) -> bool: return isinstance(expression, exp.Literal) and expression.to_py() == 0 -def is_complement(a, b): +def is_complement(a: object, b: object) -> bool: return isinstance(b, exp.Not) and b.this == a -def is_false(a: exp.Expr) -> bool: +def is_false(a: object) -> bool: return type(a) is exp.Boolean and not a.this -def is_null(a: exp.Expr) -> bool: +def is_null(a: object) -> bool: return type(a) is exp.Null -def eval_boolean(expression, a, b): +class SupportsComparison(t.Protocol): + """Protocol for expressions or values that can be compared using <, <=, >, >=.""" + + def __lt__(self, other: t.Any) -> bool: ... + def __le__(self, other: t.Any) -> bool: ... + def __gt__(self, other: t.Any) -> bool: ... + def __ge__(self, other: t.Any) -> bool: ... + + +def eval_boolean( + expression: object, a: SupportsComparison, b: SupportsComparison +) -> exp.Boolean | None: if isinstance(expression, (exp.EQ, exp.Is)): return boolean_literal(a == b) if isinstance(expression, exp.NEQ): @@ -333,29 +346,31 @@ def eval_boolean(expression, a, b): return None -def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: - if isinstance(value, datetime.datetime): +def cast_as_date(value: datetime | date | str) -> date | None: + if isinstance(value, datetime): return value.date() - if isinstance(value, datetime.date): + if isinstance(value, date): return value try: - return datetime.datetime.fromisoformat(value).date() + return datetime.fromisoformat(value).date() except ValueError: return None -def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: - if isinstance(value, datetime.datetime): +def cast_as_datetime( + value: datetime | date | str, +) -> t.Optional[datetime]: + if isinstance(value, datetime): return value - if isinstance(value, datetime.date): - return datetime.datetime(year=value.year, month=value.month, day=value.day) + if isinstance(value, date): + return datetime(year=value.year, month=value.month, day=value.day) try: - return datetime.datetime.fromisoformat(value) + return datetime.fromisoformat(value) except ValueError: return None -def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: +def cast_value(value: datetime | date | str, to: exp.DataType) -> date | date | None: if not value: return None if to.is_type(exp.DType.DATE): @@ -365,7 +380,7 @@ def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.da return None -def extract_date(cast: exp.Expr) -> t.Optional[t.Union[datetime.date, datetime.date]]: +def extract_date(cast: exp.Expr) -> date | date | None: if isinstance(cast, exp.Cast): to = cast.to elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): @@ -386,7 +401,7 @@ def _is_date_literal(expression: exp.Expr) -> bool: return extract_date(expression) is not None -def extract_interval(expression): +def extract_interval(expression: exp.Expr) -> relativedelta | None: try: n = int(expression.this.to_py()) unit = expression.text("unit").lower() @@ -395,7 +410,7 @@ def extract_interval(expression): return None -def extract_type(*expressions): +def extract_type(*expressions: exp.Expr): target_type = None for expression in expressions: target_type = expression.to if isinstance(expression, exp.Cast) else expression.type @@ -405,14 +420,14 @@ def extract_type(*expressions): return target_type -def date_literal(date, target_type=None): +def date_literal(date: object, target_type=None) -> exp.Expr: if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): - target_type = exp.DType.DATETIME if isinstance(date, datetime.datetime) else exp.DType.DATE + target_type = exp.DType.DATETIME if isinstance(date, datetime) else exp.DType.DATE return exp.cast(exp.Literal.string(date), target_type) -def interval(unit: str, n: int = 1): +def interval(unit: str, n: int = 1) -> relativedelta: from dateutil.relativedelta import relativedelta if unit == "year": @@ -435,7 +450,7 @@ def interval(unit: str, n: int = 1): raise UnsupportedUnit(f"Unsupported unit: {unit}") -def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: +def date_floor(d: date, unit: str, dialect: Dialect) -> date: if unit == "year": return d.replace(month=1, day=1) if unit == "quarter": @@ -451,14 +466,14 @@ def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: return d.replace(month=d.month, day=1) if unit == "week": # Assuming week starts on Monday (0) and ends on Sunday (6) - return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) + return d - timedelta(days=d.weekday() - dialect.WEEK_OFFSET) if unit == "day": return d raise UnsupportedUnit(f"Unsupported unit: {unit}") -def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: +def date_ceil(d: date, unit: str, dialect: Dialect) -> date: floor = date_floor(d, unit, dialect) if floor == d: @@ -467,7 +482,7 @@ def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: return floor + interval(unit) -def boolean_literal(condition): +def boolean_literal(condition: bool | exp.Predicate) -> exp.Boolean: return exp.true() if condition else exp.false() @@ -572,9 +587,9 @@ def simplify( expression: exp.Expr, constant_propagation: bool = False, coalesce_simplification: bool = False, - ): - wheres = [] - joins = [] + ) -> exp.Expr: + wheres: list[exp.Where] = [] + joins: list[exp.Join] = [] for node in expression.walk( prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL)) @@ -781,8 +796,8 @@ def simplify_not(self, expression: exp.Expr) -> exp.Expr: return expression @annotate_types_on_change - def simplify_connectors(self, expression, root=True): - def _simplify_connectors(expression, left, right): + def simplify_connectors(self, expression: exp.Expr, root: bool = True) -> exp.Expr: + def _simplify_connectors(expression: exp.Expr, left: exp.Expr, right: exp.Expr): if isinstance(expression, exp.And): if is_false(left) or is_false(right): return exp.false() @@ -841,7 +856,9 @@ def _simplify_connectors(expression, left, right): return expression @annotate_types_on_change - def _simplify_comparison(self, expression, left, right, or_=False): + def _simplify_comparison( + self, expression: exp.Expr, left: exp.Expr, right: exp.Expr, or_: bool = False + ) -> exp.Expr | None: if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS): ll, lr = left.args.values() rl, rr = right.args.values() @@ -905,7 +922,7 @@ def _simplify_comparison(self, expression, left, right, or_=False): return None @annotate_types_on_change - def remove_complements(self, expression, root=True): + def remove_complements(self, expression: object, root: bool = True) -> object: """ Removing complements. @@ -922,7 +939,7 @@ def remove_complements(self, expression, root=True): return expression @annotate_types_on_change - def uniq_sort(self, expression, root=True): + def uniq_sort(self, expression: object, root: bool = True) -> object: """ Uniq and sort a connector. @@ -959,7 +976,7 @@ def uniq_sort(self, expression, root=True): return expression @annotate_types_on_change - def absorb_and_eliminate(self, expression, root=True): + def absorb_and_eliminate(self, expression: object, root: bool = True) -> object: """ absorption: A AND (A OR B) -> A @@ -1080,15 +1097,18 @@ def simplify_equality(self, expression: exp.Expr) -> exp.Expr: ) return expression + def _is_inverse_date_op(self, expression: exp.Expr) -> TypeIs[exp.IntervalOp]: + return type(expression) in self.INVERSE_DATE_OPS + @annotate_types_on_change - def simplify_literals(self, expression, root=True): + def simplify_literals(self, expression: exp.Expr, root: bool = True) -> exp.Expr: if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): return self._flat_simplify(expression, self._simplify_binary, root) if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): return expression.this.this - if type(expression) in self.INVERSE_DATE_OPS: + if self._is_inverse_date_op(expression): return ( self._simplify_binary(expression, expression.this, expression.interval()) or expression @@ -1436,7 +1456,12 @@ def sort_comparison(self, expression: exp.Expr) -> exp.Expr: ) return expression - def _flat_simplify(self, expression, simplifier, root=True): + def _flat_simplify( + self, + expression: exp.Expr, + simplifier: t.Callable[[exp.Expr, exp.Expr, exp.Expr], exp.Expr | None], + root: bool = True, + ) -> exp.Expr: if root or not expression.same_parent: operands = [] queue = deque(expression.flatten(unnest=False)) @@ -1462,7 +1487,7 @@ def _flat_simplify(self, expression, simplifier, root=True): return expression -def gen(expression: t.Any, comments: bool = False) -> str: +def gen(expression: exp.Expr, comments: bool = False) -> str: """Simple pseudo sql generator for quickly generating sortable and uniq strings. Sorting and deduping sql is a necessary step for optimization. Calling the actual diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index a56ad06db9..baf6dd37da 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -3,7 +3,7 @@ from sqlglot.optimizer.scope import ScopeType, find_in_scope, traverse_scope -def unnest_subqueries(expression): +def unnest_subqueries(expression: exp.Expr) -> exp.Expr: """ Rewrite sqlglot AST to convert some predicates with subqueries into joins.