Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sqlglot/_typing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import typing as t
from collections.abc import Mapping, Sequence

if t.TYPE_CHECKING:
from typing_extensions import ParamSpec
from collections.abc import Mapping

import sqlglot
from sqlglot.dialects.dialect import DialectType
from sqlglot.errors import ErrorLevel
Expand All @@ -16,6 +17,9 @@
F = t.TypeVar("F", bound="sqlglot.exp.Func")
T = t.TypeVar("T")

BuilderArgs = Sequence[t.Any]
"""Sequence of arguments passed to builder functions."""


class _DialectArg(t.TypedDict, total=False):
dialect: DialectType
Expand Down
64 changes: 33 additions & 31 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import typing as t
import sys
from collections.abc import Sequence
from collections.abc import Iterable, MutableSequence
from enum import Enum, auto
from functools import reduce
from builtins import type as Type
Expand Down Expand Up @@ -58,7 +58,7 @@
DATETIME_ADD = (exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd, exp.TimestampAdd)

if t.TYPE_CHECKING:
from sqlglot._typing import B, E, F, GeneratorArgs, ParserArgs
from sqlglot._typing import B, E, F, GeneratorArgs, ParserArgs, BuilderArgs
from typing_extensions import Unpack

logger = logging.getLogger("sqlglot")
Expand Down Expand Up @@ -1398,7 +1398,7 @@ def array_concat_sql(
Dialects that propagate NULLs need to set `ARRAY_FUNCS_PROPAGATES_NULLS` to True.
"""

def _build_func_call(self: Generator, func_name: str, args: Sequence[exp.Expr]) -> str:
def _build_func_call(self: Generator, func_name: str, args: BuilderArgs) -> str:
"""Build ARRAY_CONCAT call from a list of arguments, handling variadic vs binary nesting."""
if self.ARRAY_CONCAT_IS_VAR_LEN:
return self.func(func_name, *args)
Expand Down Expand Up @@ -1535,8 +1535,8 @@ def months_between_sql(self: Generator, expression: exp.MonthsBetween) -> str:


def build_formatted_time(
exp_class: Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[list], E]:
exp_class: Type[E], dialect: str, default: bool | str | None = None
) -> t.Callable[[BuilderArgs], E]:
"""Helper used for time expressions.

Args:
Expand All @@ -1548,7 +1548,7 @@ def build_formatted_time(
A callable that can be used to return the appropriately formatted time expression.
"""

def _builder(args: t.List):
def _builder(args: BuilderArgs) -> E:
return exp_class(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
Expand Down Expand Up @@ -1576,11 +1576,11 @@ def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) ->

def build_date_delta(
exp_class: Type[E],
unit_mapping: t.Optional[dict[str, str]] = None,
default_unit: t.Optional[str] = "DAY",
unit_mapping: dict[str, str] | None = None,
default_unit: str | None = "DAY",
supports_timezone: bool = False,
) -> t.Callable[[list], E]:
def _builder(args: list) -> E:
) -> t.Callable[[BuilderArgs], E]:
def _builder(args: BuilderArgs) -> E:
unit_based = len(args) >= 3
has_timezone = len(args) == 4
this = args[2] if unit_based else seq_get(args, 0)
Expand All @@ -1598,8 +1598,8 @@ def _builder(args: list) -> E:

def build_date_delta_with_interval(
expression_class: Type[E],
) -> t.Callable[[list], t.Optional[E]]:
def _builder(args: list) -> t.Optional[E]:
) -> t.Callable[[BuilderArgs], t.Optional[E]]:
def _builder(args: BuilderArgs) -> t.Optional[E]:
if len(args) < 2:
return None

Expand All @@ -1613,7 +1613,7 @@ def _builder(args: list) -> t.Optional[E]:
return _builder


def date_trunc_to_time(args: list) -> exp.DateTrunc | exp.TimestampTrunc:
def date_trunc_to_time(args: BuilderArgs) -> exp.DateTrunc | exp.TimestampTrunc:
unit = seq_get(args, 0)
this = seq_get(args, 1)

Expand Down Expand Up @@ -1808,8 +1808,8 @@ def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
)


def pivot_column_names(aggregations: t.List[exp.Expr], dialect: DialectType) -> t.List[str]:
names = []
def pivot_column_names(aggregations: Iterable[exp.Expr], dialect: DialectType) -> list[str]:
names: list[str] = []
for agg in aggregations:
if isinstance(agg, exp.Alias):
names.append(agg.alias)
Expand All @@ -1832,17 +1832,17 @@ def pivot_column_names(aggregations: t.List[exp.Expr], dialect: DialectType) ->
return names


def binary_from_function(expr_type: Type[B]) -> t.Callable[[list], B]:
def binary_from_function(expr_type: Type[B]) -> t.Callable[[BuilderArgs], B]:
return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: list) -> exp.TimestampTrunc:
def build_timestamp_trunc(args: BuilderArgs) -> exp.TimestampTrunc:
return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def build_trunc(
args: t.List,
args: BuilderArgs,
dialect: DialectType,
date_trunc_unabbreviate: bool = True,
default_date_trunc_unit: t.Optional[str] = None,
Expand Down Expand Up @@ -1898,7 +1898,7 @@ def is_parse_json(expression: exp.Expr) -> bool:
)


def isnull_to_is_null(args: t.List) -> exp.Expr:
def isnull_to_is_null(args: BuilderArgs) -> exp.Expr:
return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


Expand Down Expand Up @@ -2070,9 +2070,9 @@ def build_json_extract_path(
zero_based_indexing: bool = True,
arrow_req_json_type: bool = False,
json_type: t.Optional[str] = None,
) -> t.Callable[[t.List], F]:
def _builder(args: t.List) -> F:
segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
) -> t.Callable[[MutableSequence[t.Any]], F]:
def _builder(args: MutableSequence[t.Any]) -> F:
segments: list[exp.JSONPathPart] = [exp.JSONPathRoot()]
for arg in args[1:]:
if not isinstance(arg, exp.Literal):
# We use the fallback parser because we can't really transpile non-literals safely
Expand Down Expand Up @@ -2236,7 +2236,7 @@ def _builder(dtype: exp.DataType) -> exp.DataType:
return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
def build_timestamp_from_parts(args: BuilderArgs) -> exp.Func:
if len(args) == 2:
# Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
# so we parse this into Anonymous for now instead of introducing complexity
Expand Down Expand Up @@ -2301,8 +2301,8 @@ def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateD
return self.func("SEQUENCE", start, end, step)


def build_like(expr_type: Type[E], not_like: bool = False) -> t.Callable[[list], exp.Expr]:
def _builder(args: t.List) -> exp.Expr:
def build_like(expr_type: Type[E], not_like: bool = False) -> t.Callable[[BuilderArgs], exp.Expr]:
def _builder(args: BuilderArgs) -> exp.Expr:
like_expr: exp.Expr = expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

if escape := seq_get(args, 2):
Expand All @@ -2316,8 +2316,8 @@ def _builder(args: t.List) -> exp.Expr:
return _builder


def build_regexp_extract(expr_type: Type[E]) -> t.Callable[[list, Dialect], E]:
def _builder(args: t.List, dialect: Dialect) -> E:
def build_regexp_extract(expr_type: Type[E]) -> t.Callable[[BuilderArgs, Dialect], E]:
def _builder(args: BuilderArgs, dialect: Dialect) -> E:
# The "position" argument specifies the index of the string character to start matching from.
# `null_if_pos_overflow` reflects the dialect's behavior when position is greater than the string
# length. If true, returns NULL. If false, returns an empty string. `null_if_pos_overflow` is
Expand Down Expand Up @@ -2399,8 +2399,8 @@ def length_or_char_length_sql(self: Generator, expression: exp.Length) -> str:
def groupconcat_sql(
self: Generator,
expression: exp.GroupConcat,
func_name="LISTAGG",
sep: t.Optional[str] = ",",
func_name: str = "LISTAGG",
sep: str | None = ",",
within_group: bool = True,
on_overflow: bool = False,
) -> str:
Expand Down Expand Up @@ -2443,7 +2443,9 @@ def groupconcat_sql(
return self.sql(listagg)


def build_timetostr_or_tochar(args: t.List, dialect: DialectType) -> exp.TimeToStr | exp.ToChar:
def build_timetostr_or_tochar(
args: BuilderArgs, dialect: DialectType
) -> exp.TimeToStr | exp.ToChar:
if len(args) == 2:
this = args[0]
if not this.type:
Expand All @@ -2458,7 +2460,7 @@ def build_timetostr_or_tochar(args: t.List, dialect: DialectType) -> exp.TimeToS
return exp.ToChar.from_arg_list(args)


def build_replace_with_optional_replacement(args: t.List) -> exp.Replace:
def build_replace_with_optional_replacement(args: BuilderArgs) -> exp.Replace:
return exp.Replace(
this=seq_get(args, 0),
expression=seq_get(args, 1),
Expand Down
Loading
Loading