diff --git a/sqlglot/expressions/builders.py b/sqlglot/expressions/builders.py index 2f44e0aa69..f4f0fa2df0 100644 --- a/sqlglot/expressions/builders.py +++ b/sqlglot/expressions/builders.py @@ -51,12 +51,15 @@ if t.TYPE_CHECKING: - from collections.abc import Sequence, Iterable + from collections.abc import Sequence, Iterable, Iterator from sqlglot.dialects.dialect import DialectType from sqlglot.expressions.core import ExpOrStr, Func from sqlglot.expressions.datatypes import DATA_TYPE from sqlglot._typing import ParserArgs, ParserNoDialectArgs, E - from typing_extensions import Unpack + from typing_extensions import Unpack, ParamSpec, Concatenate + from sqlglot.expressions.core import Dot + + P = ParamSpec("P") def select( @@ -115,7 +118,7 @@ def from_( def update( table: str | Table, - properties: t.Optional[dict] = None, + properties: t.Optional[dict[str, object]] = None, where: t.Optional[ExpOrStr] = None, from_: t.Optional[ExpOrStr] = None, with_: t.Optional[dict[str, ExpOrStr]] = None, @@ -348,7 +351,7 @@ def to_interval(interval: str | Expr) -> Interval: def to_table( - sql_path: str | Table, dialect: DialectType = None, copy: bool = True, **kwargs + sql_path: str | Table, dialect: DialectType = None, copy: bool = True, **kwargs: object ) -> Table: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. @@ -376,10 +379,7 @@ def to_table( table = table_(this, db=db, catalog=catalog) - for k, v in kwargs.items(): - table.set(k, v) - - return table + return table.set_kwargs(kwargs) def to_column( @@ -387,8 +387,8 @@ def to_column( quoted: t.Optional[bool] = None, dialect: DialectType = None, copy: bool = True, - **kwargs, -) -> Column: + **kwargs: t.Any, +) -> t.Union[Column, Dot]: """ Create a column from a `[table].[column]` sql path. Table is optional. If a column is passed in then that column is returned. @@ -648,7 +648,12 @@ def rename_column( ) -def replace_children(expression: Expr, fun: t.Callable, *args, **kwargs) -> None: +def replace_children( + expression: Expr, + fun: t.Callable[Concatenate[Expr, P], object], + *args: P.args, + **kwargs: P.kwargs, +) -> None: """ Replace children of an expression with the result of a lambda fun(child) -> exp. """ @@ -673,7 +678,7 @@ def replace_children(expression: Expr, fun: t.Callable, *args, **kwargs) -> None def replace_tree( expression: Expr, - fun: t.Callable, + fun: t.Callable[[Expr], Expr], prune: t.Optional[t.Callable[[Expr], bool]] = None, ) -> Expr: """ @@ -697,7 +702,7 @@ def replace_tree( return new_node -def find_tables(expression: Expr) -> t.Set[Table]: +def find_tables(expression: Expr) -> set[Table]: """ Find all tables referenced in a query. @@ -717,7 +722,7 @@ def find_tables(expression: Expr) -> t.Set[Table]: } -def column_table_names(expression: Expr, exclude: str = "") -> t.Set[str]: +def column_table_names(expression: Expr, exclude: str = "") -> set[str]: """ Return all table names referenced through columns in an expression. @@ -797,7 +802,7 @@ def normalize_table_name(table: str | Table, dialect: DialectType = None, copy: def replace_tables( - expression: E, mapping: t.Dict[str, str], dialect: DialectType = None, copy: bool = True + expression: E, mapping: dict[str, str], dialect: DialectType = None, copy: bool = True ) -> E: """Replace all tables in expression according to the mapping. @@ -836,7 +841,7 @@ def _replace_tables(node: Expr) -> Expr: return expression.transform(_replace_tables, copy=copy) # type: ignore -def replace_placeholders(expression: Expr, *args, **kwargs) -> Expr: +def replace_placeholders(expression: Expr, *args: object, **kwargs: t.Any) -> Expr: """Replace placeholders in an expression. Args: @@ -856,7 +861,7 @@ def replace_placeholders(expression: Expr, *args, **kwargs) -> Expr: The mapped expression. """ - def _replace_placeholders(node: Expr, args, **kwargs) -> Expr: + def _replace_placeholders(node: Expr, args: Iterator[object], **kwargs: object) -> Expr: if isinstance(node, Placeholder): if node.this: new_name = kwargs.get(node.this) @@ -874,7 +879,7 @@ def _replace_placeholders(node: Expr, args, **kwargs) -> Expr: def expand( expression: Expr, - sources: t.Dict[str, Query | t.Callable[[], Query]], + sources: dict[str, Query | t.Callable[[], Query]], dialect: DialectType = None, copy: bool = True, ) -> Expr: @@ -918,7 +923,9 @@ def _expand(node: Expr): return expression.transform(_expand, copy=copy) -def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwargs) -> Func: +def func( + name: str, *args: t.Any, copy: bool = True, dialect: DialectType = None, **kwargs: t.Any +) -> Func: """ Returns a Func expression. @@ -950,7 +957,7 @@ def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwa dialect = Dialect.get_or_raise(dialect) - converted: t.List[Expr] = [maybe_parse(arg, dialect=dialect, copy=copy) for arg in args] + converted: list[Expr] = [maybe_parse(arg, dialect=dialect, copy=copy) for arg in args] kwargs = {key: maybe_parse(value, dialect=dialect, copy=copy) for key, value in kwargs.items()} constructor = dialect.parser_class.FUNCTIONS.get(name.upper()) @@ -1007,7 +1014,10 @@ def case( def array( - *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs + *expressions: ExpOrStr, + copy: bool = True, + dialect: DialectType = None, + **kwargs: Unpack[ParserNoDialectArgs], ) -> Array: """ Returns an array. @@ -1034,7 +1044,10 @@ def array( def tuple_( - *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs + *expressions: ExpOrStr, + copy: bool = True, + dialect: DialectType = None, + **kwargs: Unpack[ParserNoDialectArgs], ) -> Tuple: """ Returns an tuple. @@ -1083,10 +1096,10 @@ def null() -> Null: def apply_index_offset( this: Expr, - expressions: t.List[E], + expressions: list[E], offset: int, dialect: DialectType = None, -) -> t.List[E]: +) -> list[E]: if not offset or len(expressions) != 1: return expressions