diff --git a/sqlglot/_typing.py b/sqlglot/_typing.py index 70b4915144..34ffdaae70 100644 --- a/sqlglot/_typing.py +++ b/sqlglot/_typing.py @@ -1,10 +1,11 @@ from __future__ import annotations import typing as t +from collections.abc import Mapping, Sequence if t.TYPE_CHECKING: from typing_extensions import ParamSpec - from collections.abc import Mapping + import sqlglot from sqlglot.dialects.dialect import DialectType from sqlglot.errors import ErrorLevel @@ -16,6 +17,9 @@ F = t.TypeVar("F", bound="sqlglot.exp.Func") T = t.TypeVar("T") +BuilderArgs = Sequence[t.Any] +"""Sequence of arguments passed to builder functions.""" + class _DialectArg(t.TypedDict, total=False): dialect: DialectType diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index af232b5c7f..48ae1fd333 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -4,7 +4,7 @@ import logging import typing as t import sys -from collections.abc import Sequence +from collections.abc import Iterable, MutableSequence from enum import Enum, auto from functools import reduce from builtins import type as Type @@ -58,7 +58,7 @@ DATETIME_ADD = (exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd, exp.TimestampAdd) if t.TYPE_CHECKING: - from sqlglot._typing import B, E, F, GeneratorArgs, ParserArgs + from sqlglot._typing import B, E, F, GeneratorArgs, ParserArgs, BuilderArgs from typing_extensions import Unpack logger = logging.getLogger("sqlglot") @@ -1398,7 +1398,7 @@ def array_concat_sql( Dialects that propagate NULLs need to set `ARRAY_FUNCS_PROPAGATES_NULLS` to True. """ - def _build_func_call(self: Generator, func_name: str, args: Sequence[exp.Expr]) -> str: + def _build_func_call(self: Generator, func_name: str, args: BuilderArgs) -> str: """Build ARRAY_CONCAT call from a list of arguments, handling variadic vs binary nesting.""" if self.ARRAY_CONCAT_IS_VAR_LEN: return self.func(func_name, *args) @@ -1535,8 +1535,8 @@ def months_between_sql(self: Generator, expression: exp.MonthsBetween) -> str: def build_formatted_time( - exp_class: Type[E], dialect: str, default: t.Optional[bool | str] = None -) -> t.Callable[[list], E]: + exp_class: Type[E], dialect: str, default: bool | str | None = None +) -> t.Callable[[BuilderArgs], E]: """Helper used for time expressions. Args: @@ -1548,7 +1548,7 @@ def build_formatted_time( A callable that can be used to return the appropriately formatted time expression. """ - def _builder(args: t.List): + def _builder(args: BuilderArgs) -> E: return exp_class( this=seq_get(args, 0), format=Dialect[dialect].format_time( @@ -1576,11 +1576,11 @@ def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> def build_date_delta( exp_class: Type[E], - unit_mapping: t.Optional[dict[str, str]] = None, - default_unit: t.Optional[str] = "DAY", + unit_mapping: dict[str, str] | None = None, + default_unit: str | None = "DAY", supports_timezone: bool = False, -) -> t.Callable[[list], E]: - def _builder(args: list) -> E: +) -> t.Callable[[BuilderArgs], E]: + def _builder(args: BuilderArgs) -> E: unit_based = len(args) >= 3 has_timezone = len(args) == 4 this = args[2] if unit_based else seq_get(args, 0) @@ -1598,8 +1598,8 @@ def _builder(args: list) -> E: def build_date_delta_with_interval( expression_class: Type[E], -) -> t.Callable[[list], t.Optional[E]]: - def _builder(args: list) -> t.Optional[E]: +) -> t.Callable[[BuilderArgs], t.Optional[E]]: + def _builder(args: BuilderArgs) -> t.Optional[E]: if len(args) < 2: return None @@ -1613,7 +1613,7 @@ def _builder(args: list) -> t.Optional[E]: return _builder -def date_trunc_to_time(args: list) -> exp.DateTrunc | exp.TimestampTrunc: +def date_trunc_to_time(args: BuilderArgs) -> exp.DateTrunc | exp.TimestampTrunc: unit = seq_get(args, 0) this = seq_get(args, 1) @@ -1808,8 +1808,8 @@ def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: ) -def pivot_column_names(aggregations: t.List[exp.Expr], dialect: DialectType) -> t.List[str]: - names = [] +def pivot_column_names(aggregations: Iterable[exp.Expr], dialect: DialectType) -> list[str]: + names: list[str] = [] for agg in aggregations: if isinstance(agg, exp.Alias): names.append(agg.alias) @@ -1832,17 +1832,17 @@ def pivot_column_names(aggregations: t.List[exp.Expr], dialect: DialectType) -> return names -def binary_from_function(expr_type: Type[B]) -> t.Callable[[list], B]: +def binary_from_function(expr_type: Type[B]) -> t.Callable[[BuilderArgs], B]: return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) # Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects -def build_timestamp_trunc(args: list) -> exp.TimestampTrunc: +def build_timestamp_trunc(args: BuilderArgs) -> exp.TimestampTrunc: return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) def build_trunc( - args: t.List, + args: BuilderArgs, dialect: DialectType, date_trunc_unabbreviate: bool = True, default_date_trunc_unit: t.Optional[str] = None, @@ -1898,7 +1898,7 @@ def is_parse_json(expression: exp.Expr) -> bool: ) -def isnull_to_is_null(args: t.List) -> exp.Expr: +def isnull_to_is_null(args: BuilderArgs) -> exp.Expr: return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) @@ -2070,9 +2070,9 @@ def build_json_extract_path( zero_based_indexing: bool = True, arrow_req_json_type: bool = False, json_type: t.Optional[str] = None, -) -> t.Callable[[t.List], F]: - def _builder(args: t.List) -> F: - segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] +) -> t.Callable[[MutableSequence[t.Any]], F]: + def _builder(args: MutableSequence[t.Any]) -> F: + segments: list[exp.JSONPathPart] = [exp.JSONPathRoot()] for arg in args[1:]: if not isinstance(arg, exp.Literal): # We use the fallback parser because we can't really transpile non-literals safely @@ -2236,7 +2236,7 @@ def _builder(dtype: exp.DataType) -> exp.DataType: return _builder -def build_timestamp_from_parts(args: t.List) -> exp.Func: +def build_timestamp_from_parts(args: BuilderArgs) -> exp.Func: if len(args) == 2: # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, # so we parse this into Anonymous for now instead of introducing complexity @@ -2301,8 +2301,8 @@ def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateD return self.func("SEQUENCE", start, end, step) -def build_like(expr_type: Type[E], not_like: bool = False) -> t.Callable[[list], exp.Expr]: - def _builder(args: t.List) -> exp.Expr: +def build_like(expr_type: Type[E], not_like: bool = False) -> t.Callable[[BuilderArgs], exp.Expr]: + def _builder(args: BuilderArgs) -> exp.Expr: like_expr: exp.Expr = expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) if escape := seq_get(args, 2): @@ -2316,8 +2316,8 @@ def _builder(args: t.List) -> exp.Expr: return _builder -def build_regexp_extract(expr_type: Type[E]) -> t.Callable[[list, Dialect], E]: - def _builder(args: t.List, dialect: Dialect) -> E: +def build_regexp_extract(expr_type: Type[E]) -> t.Callable[[BuilderArgs, Dialect], E]: + def _builder(args: BuilderArgs, dialect: Dialect) -> E: # The "position" argument specifies the index of the string character to start matching from. # `null_if_pos_overflow` reflects the dialect's behavior when position is greater than the string # length. If true, returns NULL. If false, returns an empty string. `null_if_pos_overflow` is @@ -2399,8 +2399,8 @@ def length_or_char_length_sql(self: Generator, expression: exp.Length) -> str: def groupconcat_sql( self: Generator, expression: exp.GroupConcat, - func_name="LISTAGG", - sep: t.Optional[str] = ",", + func_name: str = "LISTAGG", + sep: str | None = ",", within_group: bool = True, on_overflow: bool = False, ) -> str: @@ -2443,7 +2443,9 @@ def groupconcat_sql( return self.sql(listagg) -def build_timetostr_or_tochar(args: t.List, dialect: DialectType) -> exp.TimeToStr | exp.ToChar: +def build_timetostr_or_tochar( + args: BuilderArgs, dialect: DialectType +) -> exp.TimeToStr | exp.ToChar: if len(args) == 2: this = args[0] if not this.type: @@ -2458,7 +2460,7 @@ def build_timetostr_or_tochar(args: t.List, dialect: DialectType) -> exp.TimeToS return exp.ToChar.from_arg_list(args) -def build_replace_with_optional_replacement(args: t.List) -> exp.Replace: +def build_replace_with_optional_replacement(args: BuilderArgs) -> exp.Replace: return exp.Replace( this=seq_get(args, 0), expression=seq_get(args, 1), diff --git a/sqlglot/parser.py b/sqlglot/parser.py index d423c00f0d..32fa052562 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -22,9 +22,11 @@ from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import TrieResult, in_trie from collections.abc import Sequence +from builtins import type as Type if t.TYPE_CHECKING: - from sqlglot._typing import E + from sqlglot.expressions import ExpOrStr + from sqlglot._typing import E, BuilderArgs from sqlglot.dialects.dialect import Dialect, DialectType from re import Pattern @@ -40,12 +42,12 @@ TIME_ZONE_RE: Pattern[str] = re.compile(r":.*?[a-zA-Z\+\-]") -def build_var_map(args: list) -> exp.StarMap | exp.VarMap: +def build_var_map(args: BuilderArgs) -> exp.StarMap | exp.VarMap: if len(args) == 1 and args[0].is_star: return exp.StarMap(this=args[0]) - keys = [] - values = [] + keys: list[ExpOrStr] = [] + values: list[ExpOrStr] = [] for i in range(0, len(args), 2): keys.append(args[i]) values.append(args[i + 1]) @@ -53,13 +55,13 @@ def build_var_map(args: list) -> exp.StarMap | exp.VarMap: return exp.VarMap(keys=exp.array(*keys, copy=False), values=exp.array(*values, copy=False)) -def build_like(args: t.List) -> exp.Escape | exp.Like: +def build_like(args: BuilderArgs) -> exp.Escape | exp.Like: like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0)) return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like def binary_range_parser( - expr_type: t.Type[exp.Expr], reverse_args: bool = False + expr_type: Type[exp.Expr], reverse_args: bool = False ) -> t.Callable[[Parser, t.Optional[exp.Expr]], t.Optional[exp.Expr]]: def _parse_binary_range(self: Parser, this: t.Optional[exp.Expr]) -> t.Optional[exp.Expr]: expression = self._parse_bitwise() @@ -70,7 +72,7 @@ def _parse_binary_range(self: Parser, this: t.Optional[exp.Expr]) -> t.Optional[ return _parse_binary_range -def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func: +def build_logarithm(args: BuilderArgs, dialect: Dialect) -> exp.Func: # Default argument order is base, expression this = seq_get(args, 0) expression = seq_get(args, 1) @@ -83,25 +85,27 @@ def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func: return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this) -def build_hex(args: t.List, dialect: Dialect) -> exp.Hex | exp.LowerHex: +def build_hex(args: BuilderArgs, dialect: Dialect) -> exp.Hex | exp.LowerHex: arg = seq_get(args, 0) return exp.LowerHex(this=arg) if dialect.HEX_LOWERCASE else exp.Hex(this=arg) -def build_lower(args: t.List) -> exp.Lower | exp.Hex: +def build_lower(args: BuilderArgs) -> exp.Lower | exp.Hex: # LOWER(HEX(..)) can be simplified to LowerHex to simplify its transpilation arg = seq_get(args, 0) return exp.LowerHex(this=arg.this) if isinstance(arg, exp.Hex) else exp.Lower(this=arg) -def build_upper(args: t.List) -> exp.Upper | exp.Hex: +def build_upper(args: BuilderArgs) -> exp.Upper | exp.Hex: # UPPER(HEX(..)) can be simplified to Hex to simplify its transpilation arg = seq_get(args, 0) return exp.Hex(this=arg.this) if isinstance(arg, exp.Hex) else exp.Upper(this=arg) -def build_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: - def _builder(args: t.List, dialect: Dialect) -> E: +def build_extract_json_with_path( + expr_type: Type[E], +) -> t.Callable[[BuilderArgs, Dialect], E]: + def _builder(args: BuilderArgs, dialect: Dialect) -> E: expression = expr_type( this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) ) @@ -115,7 +119,7 @@ def _builder(args: t.List, dialect: Dialect) -> E: return _builder -def build_mod(args: t.List) -> exp.Mod: +def build_mod(args: BuilderArgs) -> exp.Mod: this = seq_get(args, 0) expression = seq_get(args, 1) @@ -126,7 +130,7 @@ def build_mod(args: t.List) -> exp.Mod: return exp.Mod(this=this, expression=expression) -def build_pad(args: t.List, is_left: bool = True): +def build_pad(args: BuilderArgs, is_left: bool = True): return exp.Pad( this=seq_get(args, 0), expression=seq_get(args, 1), @@ -136,7 +140,7 @@ def build_pad(args: t.List, is_left: bool = True): def build_array_constructor( - exp_class: t.Type[E], args: t.List, bracket_kind: TokenType, dialect: Dialect + exp_class: Type[E], args: list[t.Any], bracket_kind: TokenType, dialect: Dialect ) -> exp.Expr: array_exp = exp_class(expressions=args) @@ -147,7 +151,7 @@ def build_array_constructor( def build_convert_timezone( - args: t.List, default_source_tz: t.Optional[str] = None + args: BuilderArgs, default_source_tz: t.Optional[str] = None ) -> t.Union[exp.ConvertTimezone, exp.Anonymous]: if len(args) == 2: source_tz = exp.Literal.string(default_source_tz) if default_source_tz else None @@ -158,7 +162,7 @@ def build_convert_timezone( return exp.ConvertTimezone.from_arg_list(args) -def build_trim(args: t.List, is_left: bool = True, reverse_args: bool = False): +def build_trim(args: BuilderArgs, is_left: bool = True, reverse_args: bool = False) -> exp.Trim: this, expression = seq_get(args, 0), seq_get(args, 1) if expression and reverse_args: @@ -168,12 +172,12 @@ def build_trim(args: t.List, is_left: bool = True, reverse_args: bool = False): def build_coalesce( - args: t.List, is_nvl: t.Optional[bool] = None, is_null: t.Optional[bool] = None + args: BuilderArgs, is_nvl: t.Optional[bool] = None, is_null: t.Optional[bool] = None ) -> exp.Coalesce: return exp.Coalesce(this=seq_get(args, 0), expressions=args[1:], is_nvl=is_nvl, is_null=is_null) -def build_locate_strposition(args: t.List): +def build_locate_strposition(args: BuilderArgs) -> exp.StrPosition: return exp.StrPosition( this=seq_get(args, 1), substr=seq_get(args, 0), @@ -181,7 +185,7 @@ def build_locate_strposition(args: t.List): ) -def build_array_append(args: t.List, dialect: Dialect) -> exp.ArrayAppend: +def build_array_append(args: BuilderArgs, dialect: Dialect) -> exp.ArrayAppend: """ Builds ArrayAppend with NULL propagation semantics based on the dialect configuration. @@ -202,7 +206,7 @@ def build_array_append(args: t.List, dialect: Dialect) -> exp.ArrayAppend: ) -def build_array_prepend(args: t.List, dialect: Dialect) -> exp.ArrayPrepend: +def build_array_prepend(args: BuilderArgs, dialect: Dialect) -> exp.ArrayPrepend: """ Builds ArrayPrepend with NULL propagation semantics based on the dialect configuration. @@ -223,7 +227,7 @@ def build_array_prepend(args: t.List, dialect: Dialect) -> exp.ArrayPrepend: ) -def build_array_concat(args: t.List, dialect: Dialect) -> exp.ArrayConcat: +def build_array_concat(args: BuilderArgs, dialect: Dialect) -> exp.ArrayConcat: """ Builds ArrayConcat with NULL propagation semantics based on the dialect configuration. @@ -244,7 +248,7 @@ def build_array_concat(args: t.List, dialect: Dialect) -> exp.ArrayConcat: ) -def build_array_remove(args: t.List, dialect: Dialect) -> exp.ArrayRemove: +def build_array_remove(args: BuilderArgs, dialect: Dialect) -> exp.ArrayRemove: """ Builds ArrayRemove with NULL propagation semantics based on the dialect configuration. @@ -265,7 +269,7 @@ def build_array_remove(args: t.List, dialect: Dialect) -> exp.ArrayRemove: ) -def _resolve_dialect(dialect: t.Any) -> t.Any: +def _resolve_dialect(dialect: DialectType) -> Dialect: from sqlglot.dialects.dialect import Dialect return Dialect.get_or_raise(dialect) diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 1d39fc5574..47638b0768 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -16,7 +16,7 @@ from sqlglot.dialects.dialect import DialectType from collections.abc import Sequence - ColumnMapping = t.Union[dict, str, list] + ColumnMapping = t.Union[dict[str, str], str, list[str]] @trait @@ -151,8 +151,8 @@ def empty(self) -> bool: class AbstractMappingSchema: def __init__( self, - mapping: t.Optional[t.Dict] = None, - udf_mapping: t.Optional[t.Dict] = None, + mapping: dict[str, object] | None = None, + udf_mapping: dict[str, object] | None = None, ) -> None: self.mapping = mapping or {} self.mapping_trie = new_trie( @@ -201,10 +201,10 @@ def udf_parts(self, udf: exp.Anonymous) -> t.List[str]: def _find_in_trie( self, - parts: t.List[str], - trie: t.Dict, + parts: list[str], + trie: dict[str, object], raise_on_missing: bool, - ) -> t.Optional[t.List[str]]: + ) -> list[str] | None: value, trie = in_trie(trie, parts) if value == TrieResult.FAILED: @@ -271,7 +271,10 @@ def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Opti ) def nested_get( - self, parts: Sequence[str], d: t.Optional[dict] = None, raise_on_missing=True + self, + parts: Sequence[str], + d: dict[str, object] | None = None, + raise_on_missing: bool = True, ) -> t.Optional[t.Any]: return nested_get( d or self.mapping, @@ -301,11 +304,11 @@ class MappingSchema(AbstractMappingSchema, Schema): def __init__( self, - schema: t.Optional[t.Dict] = None, - visible: t.Optional[t.Dict] = None, + schema: dict[str, object] | None = None, + visible: dict[str, object] | None = None, dialect: DialectType = None, normalize: bool = True, - udf_mapping: t.Optional[t.Dict] = None, + udf_mapping: dict[str, object] | None = None, ) -> None: self.visible = {} if visible is None else visible self.normalize = normalize @@ -499,7 +502,7 @@ def has_column( table_schema = self.find(normalized_table, raise_on_missing=False) return normalized_column_name in table_schema if table_schema else False - def _normalize(self, schema: t.Dict) -> t.Dict: + def _normalize(self, schema: dict[str, object]) -> dict[str, object]: """ Normalizes all identifiers in the schema. @@ -509,7 +512,7 @@ def _normalize(self, schema: t.Dict) -> t.Dict: Returns: The normalized schema mapping. """ - normalized_mapping: t.Dict = {} + normalized_mapping: dict[str, object] = {} flattened_schema = flatten_schema(schema) error_msg = "Table {} must match the schema's nesting level: {}." @@ -537,7 +540,7 @@ def _normalize(self, schema: t.Dict) -> t.Dict: return normalized_mapping - def _normalize_udfs(self, udfs: t.Dict) -> t.Dict: + def _normalize_udfs(self, udfs: dict[str, object]) -> dict[str, object]: """ Normalizes all identifiers in the UDF mapping. @@ -547,7 +550,7 @@ def _normalize_udfs(self, udfs: t.Dict) -> t.Dict: Returns: The normalized UDF mapping. """ - normalized_mapping: t.Dict = {} + normalized_mapping: dict[str, object] = {} for keys in flatten_schema(udfs, depth=dict_depth(udfs)): udf_type = nested_get(udfs, *zip(keys, keys)) @@ -695,14 +698,20 @@ def normalize_name( return Dialect.get_or_raise(dialect).normalize_identifier(identifier) -def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: +def ensure_schema(schema: Schema | dict[str, object] | None, **kwargs: t.Any) -> Schema: if isinstance(schema, Schema): return schema return MappingSchema(schema, **kwargs) -def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: +@t.overload +def ensure_column_mapping(mapping: list[str] | None) -> dict[str, None]: ... +@t.overload +def ensure_column_mapping(mapping: dict[str, str]) -> dict[str, str]: ... +@t.overload +def ensure_column_mapping(mapping: ColumnMapping | None) -> dict[str, None] | dict[str, str]: ... +def ensure_column_mapping(mapping: ColumnMapping | None) -> dict[str, None] | dict[str, str]: if mapping is None: return {} elif isinstance(mapping, dict): @@ -720,8 +729,8 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: def flatten_schema( - schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None -) -> t.List[t.List[str]]: + schema: dict[str, object], depth: int | None = None, keys: list[str] | None = None +) -> list[list[str]]: tables = [] keys = keys or [] depth = dict_depth(schema) - 1 if depth is None else depth @@ -736,7 +745,7 @@ def flatten_schema( def nested_get( - d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True + d: dict[str, object], *path: tuple[str, str], raise_on_missing: bool = True ) -> t.Optional[t.Any]: """ Get a value for a nested dictionary. @@ -762,7 +771,7 @@ def nested_get( return result -def nested_set(d: dict, keys: Sequence[str], value: t.Any) -> dict: +def nested_set(d: dict[str, t.Any], keys: Sequence[str], value: t.Any) -> dict[str, t.Any]: """ In-place set a value for a nested dictionary