Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 37 additions & 24 deletions sqlglot/expressions/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,15 @@


if t.TYPE_CHECKING:
from collections.abc import Sequence, Iterable
from collections.abc import Sequence, Iterable, Iterator
from sqlglot.dialects.dialect import DialectType
from sqlglot.expressions.core import ExpOrStr, Func
from sqlglot.expressions.datatypes import DATA_TYPE
from sqlglot._typing import ParserArgs, ParserNoDialectArgs, E
from typing_extensions import Unpack
from typing_extensions import Unpack, ParamSpec, Concatenate
from sqlglot.expressions.core import Dot

P = ParamSpec("P")


def select(
Expand Down Expand Up @@ -115,7 +118,7 @@ def from_(

def update(
table: str | Table,
properties: t.Optional[dict] = None,
properties: t.Optional[dict[str, object]] = None,
where: t.Optional[ExpOrStr] = None,
from_: t.Optional[ExpOrStr] = None,
with_: t.Optional[dict[str, ExpOrStr]] = None,
Expand Down Expand Up @@ -348,7 +351,7 @@ def to_interval(interval: str | Expr) -> Interval:


def to_table(
sql_path: str | Table, dialect: DialectType = None, copy: bool = True, **kwargs
sql_path: str | Table, dialect: DialectType = None, copy: bool = True, **kwargs: object
) -> Table:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
Expand Down Expand Up @@ -376,19 +379,16 @@ def to_table(

table = table_(this, db=db, catalog=catalog)

for k, v in kwargs.items():
table.set(k, v)

return table
return table.set_kwargs(kwargs)


def to_column(
sql_path: str | Column,
quoted: t.Optional[bool] = None,
dialect: DialectType = None,
copy: bool = True,
**kwargs,
) -> Column:
**kwargs: t.Any,
) -> t.Union[Column, Dot]:
"""
Create a column from a `[table].[column]` sql path. Table is optional.
If a column is passed in then that column is returned.
Expand Down Expand Up @@ -648,7 +648,12 @@ def rename_column(
)


def replace_children(expression: Expr, fun: t.Callable, *args, **kwargs) -> None:
def replace_children(
expression: Expr,
fun: t.Callable[Concatenate[Expr, P], object],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Replace children of an expression with the result of a lambda fun(child) -> exp.
"""
Expand All @@ -673,7 +678,7 @@ def replace_children(expression: Expr, fun: t.Callable, *args, **kwargs) -> None

def replace_tree(
expression: Expr,
fun: t.Callable,
fun: t.Callable[[Expr], Expr],
prune: t.Optional[t.Callable[[Expr], bool]] = None,
) -> Expr:
"""
Expand All @@ -697,7 +702,7 @@ def replace_tree(
return new_node


def find_tables(expression: Expr) -> t.Set[Table]:
def find_tables(expression: Expr) -> set[Table]:
"""
Find all tables referenced in a query.

Expand All @@ -717,7 +722,7 @@ def find_tables(expression: Expr) -> t.Set[Table]:
}


def column_table_names(expression: Expr, exclude: str = "") -> t.Set[str]:
def column_table_names(expression: Expr, exclude: str = "") -> set[str]:
"""
Return all table names referenced through columns in an expression.

Expand Down Expand Up @@ -797,7 +802,7 @@ def normalize_table_name(table: str | Table, dialect: DialectType = None, copy:


def replace_tables(
expression: E, mapping: t.Dict[str, str], dialect: DialectType = None, copy: bool = True
expression: E, mapping: dict[str, str], dialect: DialectType = None, copy: bool = True
) -> E:
"""Replace all tables in expression according to the mapping.

Expand Down Expand Up @@ -836,7 +841,7 @@ def _replace_tables(node: Expr) -> Expr:
return expression.transform(_replace_tables, copy=copy) # type: ignore


def replace_placeholders(expression: Expr, *args, **kwargs) -> Expr:
def replace_placeholders(expression: Expr, *args: object, **kwargs: t.Any) -> Expr:
"""Replace placeholders in an expression.

Args:
Expand All @@ -856,7 +861,7 @@ def replace_placeholders(expression: Expr, *args, **kwargs) -> Expr:
The mapped expression.
"""

def _replace_placeholders(node: Expr, args, **kwargs) -> Expr:
def _replace_placeholders(node: Expr, args: Iterator[object], **kwargs: object) -> Expr:
if isinstance(node, Placeholder):
if node.this:
new_name = kwargs.get(node.this)
Expand All @@ -874,7 +879,7 @@ def _replace_placeholders(node: Expr, args, **kwargs) -> Expr:

def expand(
expression: Expr,
sources: t.Dict[str, Query | t.Callable[[], Query]],
sources: dict[str, Query | t.Callable[[], Query]],
dialect: DialectType = None,
copy: bool = True,
) -> Expr:
Expand Down Expand Up @@ -918,7 +923,9 @@ def _expand(node: Expr):
return expression.transform(_expand, copy=copy)


def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwargs) -> Func:
def func(
name: str, *args: t.Any, copy: bool = True, dialect: DialectType = None, **kwargs: t.Any
) -> Func:
"""
Returns a Func expression.

Expand Down Expand Up @@ -950,7 +957,7 @@ def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwa

dialect = Dialect.get_or_raise(dialect)

converted: t.List[Expr] = [maybe_parse(arg, dialect=dialect, copy=copy) for arg in args]
converted: list[Expr] = [maybe_parse(arg, dialect=dialect, copy=copy) for arg in args]
kwargs = {key: maybe_parse(value, dialect=dialect, copy=copy) for key, value in kwargs.items()}

constructor = dialect.parser_class.FUNCTIONS.get(name.upper())
Expand Down Expand Up @@ -1007,7 +1014,10 @@ def case(


def array(
*expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs
*expressions: ExpOrStr,
copy: bool = True,
dialect: DialectType = None,
**kwargs: Unpack[ParserNoDialectArgs],
) -> Array:
"""
Returns an array.
Expand All @@ -1034,7 +1044,10 @@ def array(


def tuple_(
*expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs
*expressions: ExpOrStr,
copy: bool = True,
dialect: DialectType = None,
**kwargs: Unpack[ParserNoDialectArgs],
) -> Tuple:
"""
Returns an tuple.
Expand Down Expand Up @@ -1083,10 +1096,10 @@ def null() -> Null:

def apply_index_offset(
this: Expr,
expressions: t.List[E],
expressions: list[E],
offset: int,
dialect: DialectType = None,
) -> t.List[E]:
) -> list[E]:
if not offset or len(expressions) != 1:
return expressions

Expand Down
Loading