-
-
Notifications
You must be signed in to change notification settings - Fork 4
new branch #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
new branch #4
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
|
|
||
| import builtins | ||
| import re | ||
| import warnings | ||
| from typing import Any, Mapping, Sequence | ||
|
|
||
| from sqlalchemy import MetaData, Table, and_, bindparam, delete as sa_delete, func, insert, inspect as sa_inspect, select, text, update as sa_update | ||
|
|
@@ -21,7 +22,11 @@ class SQLAlchemyAdapter(BaseAdapter): | |
|
|
||
| DIALECT = "sql" | ||
| IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") | ||
| _DANGEROUS_SQL = re.compile(r"(?:--|/\*|\*/|;|\b(UNION|DROP|TRUNCATE|ALTER)\b)", re.IGNORECASE) | ||
| _DANGEROUS_SQL = re.compile( | ||
| r"(?:--|#|/\*|\*/|\b(UNION|DROP|TRUNCATE|ALTER|SLEEP|BENCHMARK|EXEC|EVAL|FUNCTION)\b)", | ||
| re.IGNORECASE, | ||
| ) | ||
| _DDL_DANGEROUS_SQL = re.compile(r"(?:--|#|/\*|\*/|\b(DROP|TRUNCATE|ALTER)\b)", re.IGNORECASE) | ||
| _SQLALCHEMY_BACKGROUND_RE = re.compile( | ||
| r"\s*\(?\s*Background on this error at:\s*https?://sqlalche\.me/[^\s\)]*\s*\)?\s*", | ||
| re.IGNORECASE, | ||
|
|
@@ -174,6 +179,63 @@ def _parse_order_by_components(self, order_by: str) -> tuple[str, str]: | |
| self._validate_identifier(field) | ||
| return field, direction | ||
|
|
||
| def _is_trusted_ddl_caller(self) -> bool: | ||
| return bool(self.options.get("admin_mode", False)) | ||
|
|
||
| def _require_admin_mode(self, action: str) -> None: | ||
| if not self._is_trusted_ddl_caller(): | ||
| raise QueryError(f"{action} requires explicit admin_mode=True") | ||
|
|
||
| @staticmethod | ||
| def _contains_unquoted_semicolon(value: str) -> bool: | ||
| in_single = False | ||
| in_double = False | ||
| escaped = False | ||
| for char in value: | ||
| if escaped: | ||
| escaped = False | ||
| continue | ||
| if char == "\\": | ||
| escaped = True | ||
| continue | ||
| if char == "'" and not in_double: | ||
| in_single = not in_single | ||
| continue | ||
| if char == '"' and not in_single: | ||
| in_double = not in_double | ||
| continue | ||
| if char == ";" and not in_single and not in_double: | ||
| return True | ||
| return False | ||
|
|
||
| def _validate_admin_sql_fragment(self, fragment: str, *, field_name: str) -> str: | ||
| text_fragment = fragment.strip() | ||
| if not text_fragment: | ||
| raise QueryError(f"{field_name} must be a non-empty SQL string") | ||
| if self._contains_unquoted_semicolon(text_fragment): | ||
| raise QueryError(f"{field_name} cannot contain semicolons outside quoted strings") | ||
| if self._DDL_DANGEROUS_SQL.search(text_fragment): | ||
| raise QueryError(f"{field_name} contains prohibited SQL content") | ||
| return text_fragment | ||
|
|
||
| def _parse_uql_create_body(self, body: str) -> dict[str, Any]: | ||
| parsed: dict[str, Any] = {} | ||
| pairs = [p.strip() for p in body.split(",") if p.strip()] | ||
| for pair in pairs: | ||
| if ":" not in pair: | ||
| raise QueryError("Invalid CREATE UQL payload") | ||
| key, raw = pair.split(":", 1) | ||
| field_name = self._validate_identifier(key.strip()) | ||
| raw_value = raw.strip() | ||
| starts_quoted = raw_value.startswith(("'", '"')) | ||
| ends_quoted = raw_value.endswith(("'", '"')) | ||
| if starts_quoted != ends_quoted: | ||
| raise QueryError("Invalid quoted literal in CREATE UQL payload") | ||
| parsed[field_name] = self._parse_literal_value(raw_value) | ||
| if not parsed: | ||
| raise QueryError("CREATE UQL requires at least one field") | ||
| return parsed | ||
|
|
||
| def _build_where_expression( | ||
| self, | ||
| entity: str, | ||
|
|
@@ -417,8 +479,11 @@ def _build_parameterized_where_from_string(self, entity: str, where: str) -> tup | |
| if not text_where: | ||
| return "", {} | ||
| if self._allow_unsafe_where_strings: | ||
| # Legacy compatibility mode. Keep disabled by default. | ||
| return f" WHERE {text_where}", {} | ||
| warnings.warn( | ||
| "allow_unsafe_where_strings is deprecated and no longer bypasses parameterization", | ||
| DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
| if self._DANGEROUS_SQL.search(text_where): | ||
| raise QueryError("Potential SQL injection detected in where clause") | ||
| tokens = re.split(r"\s+(AND|OR)\s+", text_where, flags=re.IGNORECASE) | ||
|
|
@@ -458,14 +523,6 @@ def _build_parameterized_where_from_string(self, entity: str, where: str) -> tup | |
| raise QueryError("Invalid where clause structure") | ||
| return " WHERE " + " ".join(clauses), params | ||
|
|
||
| def _validate_uql_where_clause(self, where: str) -> str: | ||
| text_where = where.strip() | ||
| if not text_where: | ||
| raise QueryError("WHERE clause cannot be empty") | ||
| if self._DANGEROUS_SQL.search(text_where): | ||
| raise QueryError("Potential SQL injection detected in WHERE clause") | ||
| return text_where | ||
|
|
||
| def _validate_order_by_clause(self, order_by: str) -> str: | ||
| safe_order = order_by.strip() | ||
| match = re.fullmatch(r"([A-Za-z_][A-Za-z0-9_]*)(?:\s+(ASC|DESC))?", safe_order, re.IGNORECASE) | ||
|
|
@@ -716,7 +773,7 @@ def aggregate( | |
| stmt = stmt.limit(limit) | ||
| return self.run_native(stmt, params=params) | ||
|
|
||
| def convert_uql(self, uql_query: str) -> str: | ||
| def convert_uql(self, uql_query: str) -> Any: | ||
| uql = uql_query.strip() | ||
| upper = uql.upper() | ||
| if upper.startswith("FIND "): | ||
|
|
@@ -727,19 +784,21 @@ def convert_uql(self, uql_query: str) -> str: | |
| ) | ||
| if not match: | ||
| raise QueryError("Invalid FIND UQL") | ||
| entity = match.group(1) | ||
| entity = self._validate_identifier(match.group(1)) | ||
| where = match.group(2) | ||
| order_by = match.group(3) | ||
| limit = int(match.group(4)) if match.group(4) else None | ||
| sql = f"SELECT * FROM {self._quote(self._validate_identifier(entity))}" # nosec B608 | ||
| sql = f"SELECT * FROM {self._quote(entity)}" # nosec B608 | ||
| params: dict[str, Any] = {} | ||
| if where: | ||
| safe_where = self._validate_uql_where_clause(where) | ||
| sql += f" WHERE {safe_where}" | ||
| where_sql, params = self._build_parameterized_where_from_string(entity, where) | ||
| sql += where_sql | ||
| if order_by: | ||
| sql += f" ORDER BY {self._validate_order_by_clause(order_by)}" | ||
| if limit is not None: | ||
| sql += f" LIMIT {limit}" | ||
| return sql | ||
| sql += " LIMIT :limit_value" | ||
| params["limit_value"] = limit | ||
| return sql, params | ||
| if upper.startswith("DELETE "): | ||
| match = re.match( | ||
| r"DELETE\s+([A-Za-z_][A-Za-z0-9_]*)(?:\s+WHERE\s+(.+))?$", | ||
|
|
@@ -752,26 +811,22 @@ def convert_uql(self, uql_query: str) -> str: | |
| where = match.group(2) | ||
| if not where: | ||
| raise QueryError("DELETE UQL requires WHERE") | ||
| safe_where = self._validate_uql_where_clause(where) | ||
| return f"DELETE FROM {self._quote(entity)} WHERE {safe_where}" # nosec B608 | ||
| where_sql, params = self._build_parameterized_where_from_string(entity, where) | ||
| if not where_sql: | ||
| raise QueryError("DELETE UQL requires WHERE") | ||
| return f"DELETE FROM {self._quote(entity)}{where_sql}", params # nosec B608 | ||
| if upper.startswith("CREATE "): | ||
| match = re.match(r"CREATE\s+([A-Za-z_][A-Za-z0-9_]*)\s*\{(.+)\}$", uql, flags=re.IGNORECASE) | ||
| if not match: | ||
| raise QueryError("Invalid CREATE UQL") | ||
| entity = self._validate_identifier(match.group(1)) | ||
| body = match.group(2) | ||
| pairs = [p.strip() for p in body.split(",") if p.strip()] | ||
| cols: list[str] = [] | ||
| vals: list[str] = [] | ||
| for pair in pairs: | ||
| if ":" not in pair: | ||
| raise QueryError("Invalid CREATE UQL payload") | ||
| key, raw = pair.split(":", 1) | ||
| key = self._validate_identifier(key.strip()) | ||
| cols.append(self._quote(key)) | ||
| vals.append(raw.strip()) | ||
| self._ensure_table(entity, {c.strip('"`[]'): "" for c in cols}) | ||
| return f"INSERT INTO {self._quote(entity)} ({', '.join(cols)}) VALUES ({', '.join(vals)})" # nosec B608 | ||
| parsed_data = self._parse_uql_create_body(match.group(2)) | ||
| self._ensure_table(entity, parsed_data) | ||
| table = self._get_table(entity) | ||
| normalized = { | ||
| key: self._normalize_value_for_column(entity, key, value) for key, value in parsed_data.items() | ||
| } | ||
| return insert(table).values(**normalized) | ||
| raise QueryError("Unsupported UQL command") | ||
|
|
||
| def begin(self): | ||
|
|
@@ -791,3 +846,120 @@ def ping(self) -> Any: | |
|
|
||
| def close(self) -> None: | ||
| self._conn_manager.dispose_engine(self.url) | ||
|
|
||
| def create_view(self, name: str, select_query: str, *, replace: bool = False) -> Any: | ||
| view_name = self._validate_identifier(name) | ||
| self._require_admin_mode("create_view") | ||
| select_sql = self._validate_admin_sql_fragment(select_query, field_name="select_query") | ||
| if not re.match(r"^(SELECT|WITH)\b", select_sql, flags=re.IGNORECASE): | ||
| raise QueryError("view definition must start with SELECT or WITH") | ||
| if replace: | ||
| self.drop_view(view_name, if_exists=True) | ||
| sql = f"CREATE VIEW {self._quote(view_name)} AS {select_sql}" # nosec B608 | ||
| return self.run_native(sql) | ||
|
|
||
| def drop_view(self, name: str, *, if_exists: bool = True) -> Any: | ||
| view_name = self._validate_identifier(name) | ||
| sql = f"DROP VIEW {'IF EXISTS ' if if_exists else ''}{self._quote(view_name)}" # nosec B608 | ||
| return self.run_native(sql) | ||
|
Comment on lines
+861
to
+864
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Useful? React with 👍 / 👎. |
||
|
|
||
| def create_procedure(self, name: str, definition: str, *, replace: bool = False) -> Any: | ||
| proc_name = self._validate_identifier(name) | ||
| if self.DIALECT == "sqlite": | ||
| raise QueryError("stored procedures are not supported for sqlite") | ||
| self._require_admin_mode("create_procedure") | ||
| definition_sql = self._validate_admin_sql_fragment(definition, field_name="definition") | ||
| if replace: | ||
| self.drop_procedure(proc_name, if_exists=True) | ||
| keyword = "CREATE OR REPLACE PROCEDURE" if self.DIALECT == "postgres" else "CREATE PROCEDURE" | ||
| sql = f"{keyword} {self._quote(proc_name)} {definition_sql}" # nosec B608 | ||
| return self.run_native(sql) | ||
|
|
||
| def drop_procedure(self, name: str, *, if_exists: bool = True) -> Any: | ||
| proc_name = self._validate_identifier(name) | ||
| if self.DIALECT == "sqlite": | ||
| raise QueryError("stored procedures are not supported for sqlite") | ||
| sql = f"DROP PROCEDURE {'IF EXISTS ' if if_exists else ''}{self._quote(proc_name)}" # nosec B608 | ||
| return self.run_native(sql) | ||
|
|
||
| def call_procedure(self, name: str, params: list[Any] | tuple[Any, ...] | None = None) -> Any: | ||
| proc_name = self._validate_identifier(name) | ||
| if self.DIALECT == "sqlite": | ||
| raise QueryError("stored procedures are not supported for sqlite") | ||
| values = list(params or []) | ||
| placeholder_sql = ", ".join(f":p_{idx}" for idx in range(len(values))) | ||
| bound = {f"p_{idx}": value for idx, value in enumerate(values)} | ||
| keyword = "EXEC" if self.DIALECT == "mssql" else "CALL" | ||
| sql = f"{keyword} {self._quote(proc_name)}" | ||
| if placeholder_sql: | ||
| sql += f" {placeholder_sql}" if self.DIALECT == "mssql" else f"({placeholder_sql})" | ||
| return self.run_native(sql, params=bound) | ||
|
|
||
| def create_function(self, name: str, definition: str, *, replace: bool = False) -> Any: | ||
| func_name = self._validate_identifier(name) | ||
| if self.DIALECT == "sqlite": | ||
| raise QueryError("function creation is not supported for sqlite") | ||
| self._require_admin_mode("create_function") | ||
| definition_sql = self._validate_admin_sql_fragment(definition, field_name="definition") | ||
| if replace: | ||
| self.drop_function(func_name, if_exists=True) | ||
| keyword = "CREATE OR REPLACE FUNCTION" if self.DIALECT == "postgres" else "CREATE FUNCTION" | ||
| sql = f"{keyword} {self._quote(func_name)} {definition_sql}" # nosec B608 | ||
| return self.run_native(sql) | ||
|
|
||
| def drop_function(self, name: str, *, if_exists: bool = True) -> Any: | ||
| func_name = self._validate_identifier(name) | ||
| if self.DIALECT == "sqlite": | ||
| raise QueryError("function dropping is not supported for sqlite") | ||
| sql = f"DROP FUNCTION {'IF EXISTS ' if if_exists else ''}{self._quote(func_name)}" # nosec B608 | ||
| return self.run_native(sql) | ||
|
|
||
| def call_function(self, name: str, params: list[Any] | tuple[Any, ...] | None = None) -> Any: | ||
| func_name = self._validate_identifier(name) | ||
| values = list(params or []) | ||
| bindings = [bindparam(f"p_{idx}", value=value) for idx, value in enumerate(values)] | ||
| stmt = select(getattr(func, func_name)(*bindings).label("result")) | ||
| rows = self.run_native(stmt) | ||
| if not rows: | ||
| return None | ||
| return rows[0].get("result") | ||
|
|
||
| def create_event( | ||
| self, | ||
| name: str, | ||
| schedule: str, | ||
| body: str, | ||
| *, | ||
| replace: bool = False, | ||
| preserve: bool = True, | ||
| enable: bool = True, | ||
| ) -> Any: | ||
| event_name = self._validate_identifier(name) | ||
| if self.DIALECT != "mysql": | ||
| raise QueryError("database events are currently supported only for mysql") | ||
| if not isinstance(schedule, str) or not schedule.strip(): | ||
| raise QueryError("schedule must be a non-empty SQL fragment") | ||
| schedule_sql = schedule.strip() | ||
| if not re.fullmatch(r"EVERY\s+\d+\s+(SECOND|MINUTE|HOUR|DAY|WEEK|MONTH)", schedule_sql, flags=re.IGNORECASE): | ||
| raise QueryError("schedule must match EVERY <n> <SECOND|MINUTE|HOUR|DAY|WEEK|MONTH>") | ||
| self._require_admin_mode("create_event") | ||
| body_sql = self._validate_admin_sql_fragment(body, field_name="body") | ||
| if replace: | ||
| self.drop_event(event_name, if_exists=True) | ||
| preserve_sql = "ON COMPLETION PRESERVE" if preserve else "ON COMPLETION NOT PRESERVE" | ||
| status_sql = "ENABLE" if enable else "DISABLE" | ||
| sql = ( | ||
| f"CREATE EVENT {self._quote(event_name)} " | ||
| f"ON SCHEDULE {schedule_sql} " | ||
| f"{preserve_sql} " | ||
| f"{status_sql} " | ||
| f"DO {body_sql}" | ||
| ) # nosec B608 | ||
| return self.run_native(sql) | ||
|
|
||
| def drop_event(self, name: str, *, if_exists: bool = True) -> Any: | ||
| event_name = self._validate_identifier(name) | ||
| if self.DIALECT != "mysql": | ||
| raise QueryError("database events are currently supported only for mysql") | ||
| sql = f"DROP EVENT {'IF EXISTS ' if if_exists else ''}{self._quote(event_name)}" # nosec B608 | ||
| return self.run_native(sql) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
replace=Trueatomic before dropping the old objectWhen
replace=True, this path callsdrop_view()and then issuesCREATE VIEWas two separaterun_native()calls. Outside an explicitdb.transaction(),run_native()opens its ownengine.begin()block per call, so an invalid replacement query (or a missing referenced table) will commit the drop and then fail the create, leaving the previously working view gone. The same destructive pattern is repeated in the procedure/function/event helpers below.Useful? React with 👍 / 👎.