Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
65340ab
refactor (optimizer): add various type annotations to public functions
OutSquareCapital Apr 3, 2026
11b05b8
refactor(optimizer): add various type annotations to simplify module
OutSquareCapital Apr 3, 2026
6041213
refactor(optimizer): improve documnetation for `OptimizerFn` Protocol
OutSquareCapital Apr 3, 2026
91159c5
Chore: ran ruff format
OutSquareCapital Apr 3, 2026
3c2eab2
refactor (optimizer): going back to original import pattern for `norm…
OutSquareCapital Apr 3, 2026
5d4d202
refactor (optimizer): make `_simplify_integer_cast` generic, optimize…
OutSquareCapital Apr 3, 2026
dfa5c17
fix: make `S` TypeVar a runtime concrete value for mypc
OutSquareCapital Apr 3, 2026
869f0a3
fix: for some reason mypc don't work well with generics for `_simplif…
OutSquareCapital Apr 3, 2026
0614a97
fix: revert `_simplify_integer_cast` to original typing to avoid mypc…
OutSquareCapital Apr 3, 2026
b967575
Fix: Use a Protocol for `eval_boolean`, since it can take any compara…
OutSquareCapital Apr 3, 2026
fb147b4
Chore: run formatter
OutSquareCapital Apr 3, 2026
cb6c323
Refactor (optimizer): widen some helper functions input types
OutSquareCapital Apr 3, 2026
82c9193
refactor (optimizer): use `TypeIs` in `simplify_literals` for type na…
OutSquareCapital Apr 3, 2026
08fbb54
chore: ruff format
OutSquareCapital Apr 3, 2026
d3da9d8
fix: added various return types to `simplify` functions
OutSquareCapital Apr 3, 2026
2e7dc46
chore: ruff format
OutSquareCapital Apr 3, 2026
aff8bfe
fix: annotate `_simplify_comparison` return type
OutSquareCapital Apr 3, 2026
4850911
refactor: add type hints to `extract_interval` and `date_literal`
OutSquareCapital Apr 4, 2026
6ddb177
Merge branch 'main' into optimizer-annotations
OutSquareCapital Apr 4, 2026
4209d1c
fix: widen return type of `date_literal` to satisfy mypc
OutSquareCapital Apr 4, 2026
9c2717f
Merge branch 'main' into optimizer-annotations
OutSquareCapital Apr 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
BinaryCoercionFunc = t.Callable[
[exp.Expr, exp.Expr], t.Optional[t.Union[exp.DataType, exp.DType]]
]
BinaryCoercions = t.Dict[
t.Tuple[exp.DType, exp.DType],
BinaryCoercionFunc,
]
BinaryCoercions = dict[tuple[exp.DType, exp.DType], BinaryCoercionFunc]

from sqlglot.dialects.dialect import DialectType
from sqlglot.typing import ExprMetadataType
Expand All @@ -47,9 +44,9 @@

def annotate_types(
expression: E,
schema: t.Optional[t.Dict | Schema] = None,
schema: dict[str, object] | Schema | None = None,
expression_metadata: t.Optional[ExprMetadataType] = None,
coerces_to: t.Optional[t.Dict[exp.DType, t.Set[exp.DType]]] = None,
coerces_to: dict[exp.DType, set[exp.DType]] | None = None,
dialect: DialectType = None,
overwrite_types: bool = True,
) -> E:
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/isolate_table_selects.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

def isolate_table_selects(
expression: E,
schema: t.Optional[t.Dict | Schema] = None,
schema: dict[str, object] | Schema | None = None,
dialect: DialectType = None,
) -> E:
schema = ensure_schema(schema, dialect=dialect)
Expand Down
17 changes: 8 additions & 9 deletions sqlglot/optimizer/optimize_joins.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

import typing as t
from collections.abc import Iterable

from sqlglot import exp
from sqlglot.helper import tsort

JOIN_ATTRS = ("on", "side", "kind", "using", "method")


def optimize_joins(expression):
def optimize_joins(expression: exp.Expr) -> exp.Expr:
"""
Removes cross joins if possible and reorder joins based on predicate dependencies.

Expand All @@ -24,8 +23,8 @@ def optimize_joins(expression):
if not _is_reorderable(joins):
continue

references = {}
cross_joins = []
references: dict[str, list[exp.Join]] = {}
cross_joins: list[tuple[str, exp.Join]] = []

for join in joins:
tables = other_table_names(join)
Expand Down Expand Up @@ -58,7 +57,7 @@ def optimize_joins(expression):
return expression


def reorder_joins(expression):
def reorder_joins(expression) -> exp.Expr:
"""
Reorder joins by topological sort order based on predicate references.
"""
Expand All @@ -82,7 +81,7 @@ def reorder_joins(expression):
return expression


def normalize(expression):
def normalize(expression: exp.Expr) -> exp.Expr:
"""
Remove INNER and OUTER from joins as they are optional.
"""
Expand All @@ -101,12 +100,12 @@ def normalize(expression):
return expression


def other_table_names(join: exp.Join) -> t.Set[str]:
def other_table_names(join: exp.Join) -> set[str]:
on = join.args.get("on")
return exp.column_table_names(on, join.alias_or_name) if on else set()


def _is_reorderable(joins: t.List[exp.Join]) -> bool:
def _is_reorderable(joins: Iterable[exp.Join]) -> bool:
"""
Checks if joins can be reordered without changing query semantics.

Expand Down
28 changes: 22 additions & 6 deletions sqlglot/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,23 @@
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema

RULES = (

class OptimizerFn(t.Protocol):
"""Protocol for optimizer rules functions.

An optimizer rule:

- **Must** accept an `Expr` as the first argument
- Can take undefined `*args` and `**kwargs` afterwards
- **Must** return an `Expr`.
Note:
We use `typing.Protocol` here because this is not doable with `collections.abc.Callable`.
"""

def __call__(self, expression: exp.Expr, *args: t.Any, **kwargs: t.Any) -> exp.Expr: ...


RULES: tuple[OptimizerFn, ...] = (
qualify,
pushdown_projections,
normalize,
Expand All @@ -41,13 +57,13 @@

def optimize(
expression: str | exp.Expr,
schema: t.Optional[dict | Schema] = None,
db: t.Optional[str | exp.Identifier] = None,
catalog: t.Optional[str | exp.Identifier] = None,
schema: dict[str, object] | Schema | None = None,
db: str | exp.Identifier | None = None,
catalog: str | exp.Identifier | None = None,
dialect: DialectType = None,
rules: Sequence[t.Callable] = RULES,
rules: Sequence[OptimizerFn] = RULES,
sql: t.Optional[str] = None,
**kwargs,
**kwargs: object,
) -> exp.Expr:
"""
Rewrite a sqlglot AST into an optimized form.
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/pushdown_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def default_selection(is_agg: bool) -> exp.Alias:

def pushdown_projections(
expression: E,
schema: t.Optional[t.Dict | Schema] = None,
schema: dict[str, object] | Schema | None = None,
remove_unused_selections: bool = True,
dialect: DialectType = None,
) -> E:
Expand Down
12 changes: 6 additions & 6 deletions sqlglot/optimizer/qualify.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@
def qualify(
expression: exp.Expr,
dialect: DialectType = None,
db: t.Optional[str] = None,
catalog: t.Optional[str] = None,
schema: t.Optional[dict | Schema] = None,
db: str | None = None,
catalog: str | None = None,
schema: dict[str, object] | Schema | None = None,
expand_alias_refs: bool = True,
expand_stars: bool = True,
infer_schema: t.Optional[bool] = None,
infer_schema: bool | None = None,
isolate_tables: bool = False,
qualify_columns: bool = True,
allow_partial_qualification: bool = False,
validate_qualify_columns: bool = True,
quote_identifiers: bool = True,
identify: bool = True,
canonicalize_table_aliases: bool = False,
on_qualify: t.Optional[t.Callable[[exp.Expr], None]] = None,
sql: t.Optional[str] = None,
on_qualify: t.Callable[[exp.Expr], None] | None = None,
sql: str | None = None,
) -> exp.Expr:
"""
Rewrite sqlglot AST to have normalized and qualified tables and columns.
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def qualify_columns(
expression: exp.Expr,
schema: dict | Schema,
schema: dict[str, object] | Schema,
expand_alias_refs: bool = True,
expand_stars: bool = True,
infer_schema: t.Optional[bool] = None,
Expand Down
Loading
Loading