Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ DBDUCK_MONGO_MIN_POOL_SIZE=0
DBDUCK_MONGO_RETRY_ATTEMPTS=3
DBDUCK_MONGO_RETRY_BACKOFF_MS=100

# Security hardening (keep false in production)
DBDUCK_ALLOW_UNSAFE_WHERE_STRINGS=false
# Security hardening
DBDUCK_HASH_SENSITIVE_FIELDS=true
DBDUCK_BCRYPT_ROUNDS=12
DBDUCK_SECURITY_AUDIT_ENABLED=true
Expand Down
238 changes: 205 additions & 33 deletions DBDuck/adapters/_sqlalchemy_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import builtins
import re
import warnings
from typing import Any, Mapping, Sequence

from sqlalchemy import MetaData, Table, and_, bindparam, delete as sa_delete, func, insert, inspect as sa_inspect, select, text, update as sa_update
Expand All @@ -21,7 +22,11 @@ class SQLAlchemyAdapter(BaseAdapter):

DIALECT = "sql"
IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
_DANGEROUS_SQL = re.compile(r"(?:--|/\*|\*/|;|\b(UNION|DROP|TRUNCATE|ALTER)\b)", re.IGNORECASE)
_DANGEROUS_SQL = re.compile(
r"(?:--|#|/\*|\*/|\b(UNION|DROP|TRUNCATE|ALTER|SLEEP|BENCHMARK|EXEC|EVAL|FUNCTION)\b)",
re.IGNORECASE,
)
_DDL_DANGEROUS_SQL = re.compile(r"(?:--|#|/\*|\*/|\b(DROP|TRUNCATE|ALTER)\b)", re.IGNORECASE)
_SQLALCHEMY_BACKGROUND_RE = re.compile(
r"\s*\(?\s*Background on this error at:\s*https?://sqlalche\.me/[^\s\)]*\s*\)?\s*",
re.IGNORECASE,
Expand Down Expand Up @@ -174,6 +179,63 @@ def _parse_order_by_components(self, order_by: str) -> tuple[str, str]:
self._validate_identifier(field)
return field, direction

def _is_trusted_ddl_caller(self) -> bool:
return bool(self.options.get("admin_mode", False))

def _require_admin_mode(self, action: str) -> None:
if not self._is_trusted_ddl_caller():
raise QueryError(f"{action} requires explicit admin_mode=True")

@staticmethod
def _contains_unquoted_semicolon(value: str) -> bool:
in_single = False
in_double = False
escaped = False
for char in value:
if escaped:
escaped = False
continue
if char == "\\":
escaped = True
continue
if char == "'" and not in_double:
in_single = not in_single
continue
if char == '"' and not in_single:
in_double = not in_double
continue
if char == ";" and not in_single and not in_double:
return True
return False

def _validate_admin_sql_fragment(self, fragment: str, *, field_name: str) -> str:
text_fragment = fragment.strip()
if not text_fragment:
raise QueryError(f"{field_name} must be a non-empty SQL string")
if self._contains_unquoted_semicolon(text_fragment):
raise QueryError(f"{field_name} cannot contain semicolons outside quoted strings")
if self._DDL_DANGEROUS_SQL.search(text_fragment):
raise QueryError(f"{field_name} contains prohibited SQL content")
return text_fragment

def _parse_uql_create_body(self, body: str) -> dict[str, Any]:
    """Parse the ``key: value, ...`` body of a CREATE UQL statement.

    Args:
        body: Text found between the braces of the CREATE statement.

    Returns:
        Mapping of validated field name to the value produced by
        ``_parse_literal_value``.

    Raises:
        QueryError: If an entry has no ``:``, a value is quoted on only one
            end, or no fields are present.
    """
    # NOTE(review): entries are split on every comma, so a quoted value that
    # itself contains a comma is cut in two and rejected - confirm whether
    # such values need to be supported.
    quote_chars = ("'", '"')
    fields: dict[str, Any] = {}
    for chunk in body.split(","):
        entry = chunk.strip()
        if not entry:
            continue  # ignore empty segments such as a trailing comma
        key_text, separator, value_text = entry.partition(":")
        if not separator:
            raise QueryError("Invalid CREATE UQL payload")
        column = self._validate_identifier(key_text.strip())
        literal = value_text.strip()
        # A quote character on exactly one end means an unbalanced literal.
        if literal.startswith(quote_chars) != literal.endswith(quote_chars):
            raise QueryError("Invalid quoted literal in CREATE UQL payload")
        fields[column] = self._parse_literal_value(literal)
    if not fields:
        raise QueryError("CREATE UQL requires at least one field")
    return fields

def _build_where_expression(
self,
entity: str,
Expand Down Expand Up @@ -417,8 +479,11 @@ def _build_parameterized_where_from_string(self, entity: str, where: str) -> tup
if not text_where:
return "", {}
if self._allow_unsafe_where_strings:
# Legacy compatibility mode. Keep disabled by default.
return f" WHERE {text_where}", {}
warnings.warn(
"allow_unsafe_where_strings is deprecated and no longer bypasses parameterization",
DeprecationWarning,
stacklevel=2,
)
if self._DANGEROUS_SQL.search(text_where):
raise QueryError("Potential SQL injection detected in where clause")
tokens = re.split(r"\s+(AND|OR)\s+", text_where, flags=re.IGNORECASE)
Expand Down Expand Up @@ -458,14 +523,6 @@ def _build_parameterized_where_from_string(self, entity: str, where: str) -> tup
raise QueryError("Invalid where clause structure")
return " WHERE " + " ".join(clauses), params

def _validate_uql_where_clause(self, where: str) -> str:
text_where = where.strip()
if not text_where:
raise QueryError("WHERE clause cannot be empty")
if self._DANGEROUS_SQL.search(text_where):
raise QueryError("Potential SQL injection detected in WHERE clause")
return text_where

def _validate_order_by_clause(self, order_by: str) -> str:
safe_order = order_by.strip()
match = re.fullmatch(r"([A-Za-z_][A-Za-z0-9_]*)(?:\s+(ASC|DESC))?", safe_order, re.IGNORECASE)
Expand Down Expand Up @@ -716,7 +773,7 @@ def aggregate(
stmt = stmt.limit(limit)
return self.run_native(stmt, params=params)

def convert_uql(self, uql_query: str) -> str:
def convert_uql(self, uql_query: str) -> Any:
uql = uql_query.strip()
upper = uql.upper()
if upper.startswith("FIND "):
Expand All @@ -727,19 +784,21 @@ def convert_uql(self, uql_query: str) -> str:
)
if not match:
raise QueryError("Invalid FIND UQL")
entity = match.group(1)
entity = self._validate_identifier(match.group(1))
where = match.group(2)
order_by = match.group(3)
limit = int(match.group(4)) if match.group(4) else None
sql = f"SELECT * FROM {self._quote(self._validate_identifier(entity))}" # nosec B608
sql = f"SELECT * FROM {self._quote(entity)}" # nosec B608
params: dict[str, Any] = {}
if where:
safe_where = self._validate_uql_where_clause(where)
sql += f" WHERE {safe_where}"
where_sql, params = self._build_parameterized_where_from_string(entity, where)
sql += where_sql
if order_by:
sql += f" ORDER BY {self._validate_order_by_clause(order_by)}"
if limit is not None:
sql += f" LIMIT {limit}"
return sql
sql += " LIMIT :limit_value"
params["limit_value"] = limit
return sql, params
if upper.startswith("DELETE "):
match = re.match(
r"DELETE\s+([A-Za-z_][A-Za-z0-9_]*)(?:\s+WHERE\s+(.+))?$",
Expand All @@ -752,26 +811,22 @@ def convert_uql(self, uql_query: str) -> str:
where = match.group(2)
if not where:
raise QueryError("DELETE UQL requires WHERE")
safe_where = self._validate_uql_where_clause(where)
return f"DELETE FROM {self._quote(entity)} WHERE {safe_where}" # nosec B608
where_sql, params = self._build_parameterized_where_from_string(entity, where)
if not where_sql:
raise QueryError("DELETE UQL requires WHERE")
return f"DELETE FROM {self._quote(entity)}{where_sql}", params # nosec B608
if upper.startswith("CREATE "):
match = re.match(r"CREATE\s+([A-Za-z_][A-Za-z0-9_]*)\s*\{(.+)\}$", uql, flags=re.IGNORECASE)
if not match:
raise QueryError("Invalid CREATE UQL")
entity = self._validate_identifier(match.group(1))
body = match.group(2)
pairs = [p.strip() for p in body.split(",") if p.strip()]
cols: list[str] = []
vals: list[str] = []
for pair in pairs:
if ":" not in pair:
raise QueryError("Invalid CREATE UQL payload")
key, raw = pair.split(":", 1)
key = self._validate_identifier(key.strip())
cols.append(self._quote(key))
vals.append(raw.strip())
self._ensure_table(entity, {c.strip('"`[]'): "" for c in cols})
return f"INSERT INTO {self._quote(entity)} ({', '.join(cols)}) VALUES ({', '.join(vals)})" # nosec B608
parsed_data = self._parse_uql_create_body(match.group(2))
self._ensure_table(entity, parsed_data)
table = self._get_table(entity)
normalized = {
key: self._normalize_value_for_column(entity, key, value) for key, value in parsed_data.items()
}
return insert(table).values(**normalized)
raise QueryError("Unsupported UQL command")

def begin(self):
Expand All @@ -791,3 +846,120 @@ def ping(self) -> Any:

def close(self) -> None:
self._conn_manager.dispose_engine(self.url)

def create_view(self, name: str, select_query: str, *, replace: bool = False) -> Any:
    """Create a view from a caller-supplied SELECT/WITH query.

    Args:
        name: View name; checked with ``_validate_identifier``.
        select_query: The view definition. It is vetted by
            ``_validate_admin_sql_fragment`` and must start with ``SELECT``
            or ``WITH``.
        replace: When true, an existing view of the same name is replaced.

    Returns:
        Whatever ``run_native`` returns for the CREATE statement.

    Raises:
        QueryError: If admin mode is not enabled, the query fails
            validation, or it does not start with SELECT or WITH.
    """
    view_name = self._validate_identifier(name)
    self._require_admin_mode("create_view")
    select_sql = self._validate_admin_sql_fragment(select_query, field_name="select_query")
    if not re.match(r"^(SELECT|WITH)\b", select_sql, flags=re.IGNORECASE):
        raise QueryError("view definition must start with SELECT or WITH")
    create_keyword = "CREATE VIEW"
    if replace:
        if self.DIALECT == "mysql":
            # One statement replaces the view, so a failing definition
            # cannot leave the previous view dropped.
            create_keyword = "CREATE OR REPLACE VIEW"
        else:
            # NOTE(review): drop and create are issued as two separate
            # run_native calls; if the CREATE fails the old view may already
            # be gone (depends on run_native's transaction handling).
            self.drop_view(view_name, if_exists=True)
    sql = f"{create_keyword} {self._quote(view_name)} AS {select_sql}"  # nosec B608
    return self.run_native(sql)
Comment on lines +856 to +859

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Make replace=True atomic before dropping the old object

When replace=True, this path calls drop_view() and then issues CREATE VIEW as two separate run_native() calls. Outside an explicit db.transaction(), run_native() opens its own engine.begin() block per call, so an invalid replacement query (or a missing referenced table) will commit the drop and then fail the create, leaving the previously working view gone. The same destructive pattern is repeated in the procedure/function/event helpers below.

Useful? React with 👍 / 👎.


def drop_view(self, name: str, *, if_exists: bool = True) -> Any:
    """Issue ``DROP VIEW`` for ``name``.

    Args:
        name: View name; checked with ``_validate_identifier``.
        if_exists: When true, ``IF EXISTS`` is added to the statement.

    Returns:
        Whatever ``run_native`` returns for the statement.
    """
    # NOTE(review): unlike create_view, this does not call
    # _require_admin_mode - confirm that is intentional.
    view_name = self._validate_identifier(name)
    guard = "IF EXISTS " if if_exists else ""
    statement = f"DROP VIEW {guard}{self._quote(view_name)}"  # nosec B608
    return self.run_native(statement)
Comment on lines +861 to +864

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Invalidate cached schema after creating or dropping a view

find()/count() reuse reflected metadata from _table_cache and _column_type_cache, but neither create_view() nor drop_view() clears those caches. If a view has been queried once and is then recreated with a different projection via replace=True, subsequent reads on the same UDOM instance will keep using the stale column set from the old view, which can silently omit new columns or raise on removed ones.

Useful? React with 👍 / 👎.


def create_procedure(self, name: str, definition: str, *, replace: bool = False) -> Any:
    """Create a stored procedure from a caller-supplied definition.

    Args:
        name: Procedure name; checked with ``_validate_identifier``.
        definition: SQL text placed directly after the quoted procedure
            name. It is vetted by ``_validate_admin_sql_fragment``.
        replace: When true, any existing procedure of that name is dropped
            first, and on postgres ``CREATE OR REPLACE`` is used. When
            false, a plain ``CREATE PROCEDURE`` is issued on every dialect
            so an existing procedure is never silently overwritten.

    Returns:
        Whatever ``run_native`` returns for the CREATE statement.

    Raises:
        QueryError: On sqlite, when admin mode is not enabled, or when the
            definition fails validation.
    """
    proc_name = self._validate_identifier(name)
    if self.DIALECT == "sqlite":
        raise QueryError("stored procedures are not supported for sqlite")
    self._require_admin_mode("create_procedure")
    definition_sql = self._validate_admin_sql_fragment(definition, field_name="definition")
    if replace:
        # NOTE(review): the drop and the create are separate run_native
        # calls; a failing CREATE may leave the old procedure dropped.
        self.drop_procedure(proc_name, if_exists=True)
    # OR REPLACE is only emitted when the caller actually asked to replace.
    use_or_replace = replace and self.DIALECT == "postgres"
    keyword = "CREATE OR REPLACE PROCEDURE" if use_or_replace else "CREATE PROCEDURE"
    sql = f"{keyword} {self._quote(proc_name)} {definition_sql}"  # nosec B608
    return self.run_native(sql)

def drop_procedure(self, name: str, *, if_exists: bool = True) -> Any:
    """Issue ``DROP PROCEDURE`` for ``name``.

    Args:
        name: Procedure name; checked with ``_validate_identifier``.
        if_exists: When true, ``IF EXISTS`` is added to the statement.

    Returns:
        Whatever ``run_native`` returns for the statement.

    Raises:
        QueryError: When the adapter dialect is sqlite.
    """
    # NOTE(review): unlike create_procedure, this does not call
    # _require_admin_mode - confirm that is intentional.
    proc_name = self._validate_identifier(name)
    if self.DIALECT == "sqlite":
        raise QueryError("stored procedures are not supported for sqlite")
    guard = "IF EXISTS " if if_exists else ""
    statement = f"DROP PROCEDURE {guard}{self._quote(proc_name)}"  # nosec B608
    return self.run_native(statement)

def call_procedure(self, name: str, params: list[Any] | tuple[Any, ...] | None = None) -> Any:
    """Invoke a stored procedure with positionally bound parameters.

    Each value is bound as ``:p_0``, ``:p_1``, ... and passed to
    ``run_native``; no value is interpolated into the SQL text.

    Args:
        name: Procedure name; checked with ``_validate_identifier``.
        params: Positional argument values, or None for no arguments.

    Returns:
        Whatever ``run_native`` returns for the call.

    Raises:
        QueryError: When the adapter dialect is sqlite.
    """
    proc_name = self._validate_identifier(name)
    if self.DIALECT == "sqlite":
        raise QueryError("stored procedures are not supported for sqlite")
    values = list(params or [])
    placeholder_sql = ", ".join(f":p_{idx}" for idx in range(len(values)))
    bound = {f"p_{idx}": value for idx, value in enumerate(values)}
    if self.DIALECT == "mssql":
        # EXEC takes a bare, comma-separated argument list.
        sql = f"EXEC {self._quote(proc_name)}"
        if placeholder_sql:
            sql += f" {placeholder_sql}"
    else:
        # Always emit the parentheses, even with no arguments: PostgreSQL's
        # CALL syntax requires them.
        sql = f"CALL {self._quote(proc_name)}({placeholder_sql})"
    return self.run_native(sql, params=bound)

def create_function(self, name: str, definition: str, *, replace: bool = False) -> Any:
    """Create a database function from a caller-supplied definition.

    Args:
        name: Function name; checked with ``_validate_identifier``.
        definition: SQL text placed directly after the quoted function
            name. It is vetted by ``_validate_admin_sql_fragment``.
        replace: When true, any existing function of that name is dropped
            first, and on postgres ``CREATE OR REPLACE`` is used. When
            false, a plain ``CREATE FUNCTION`` is issued on every dialect
            so an existing function is never silently overwritten.

    Returns:
        Whatever ``run_native`` returns for the CREATE statement.

    Raises:
        QueryError: On sqlite, when admin mode is not enabled, or when the
            definition fails validation.
    """
    func_name = self._validate_identifier(name)
    if self.DIALECT == "sqlite":
        raise QueryError("function creation is not supported for sqlite")
    self._require_admin_mode("create_function")
    definition_sql = self._validate_admin_sql_fragment(definition, field_name="definition")
    if replace:
        # NOTE(review): the drop and the create are separate run_native
        # calls; a failing CREATE may leave the old function dropped.
        self.drop_function(func_name, if_exists=True)
    # OR REPLACE is only emitted when the caller actually asked to replace.
    use_or_replace = replace and self.DIALECT == "postgres"
    keyword = "CREATE OR REPLACE FUNCTION" if use_or_replace else "CREATE FUNCTION"
    sql = f"{keyword} {self._quote(func_name)} {definition_sql}"  # nosec B608
    return self.run_native(sql)

def drop_function(self, name: str, *, if_exists: bool = True) -> Any:
    """Issue ``DROP FUNCTION`` for ``name``.

    Args:
        name: Function name; checked with ``_validate_identifier``.
        if_exists: When true, ``IF EXISTS`` is added to the statement.

    Returns:
        Whatever ``run_native`` returns for the statement.

    Raises:
        QueryError: When the adapter dialect is sqlite.
    """
    # NOTE(review): unlike create_function, this does not call
    # _require_admin_mode - confirm that is intentional.
    func_name = self._validate_identifier(name)
    if self.DIALECT == "sqlite":
        raise QueryError("function dropping is not supported for sqlite")
    guard = "IF EXISTS " if if_exists else ""
    statement = f"DROP FUNCTION {guard}{self._quote(func_name)}"  # nosec B608
    return self.run_native(statement)

def call_function(self, name: str, params: list[Any] | tuple[Any, ...] | None = None) -> Any:
    """Evaluate a database function and return its single result value.

    Builds ``SELECT <name>(<bound params>) AS result`` through SQLAlchemy's
    ``func`` namespace, with each argument passed as a bind parameter.

    Args:
        name: Function name; checked with ``_validate_identifier``.
        params: Positional argument values, or None for no arguments.

    Returns:
        The ``result`` entry of the first row returned by ``run_native``,
        or None when no rows come back.
    """
    func_name = self._validate_identifier(name)
    arguments = [
        bindparam(f"p_{position}", value=argument)
        for position, argument in enumerate(params or [])
    ]
    sql_function = getattr(func, func_name)
    statement = select(sql_function(*arguments).label("result"))
    rows = self.run_native(statement)
    return rows[0].get("result") if rows else None

def create_event(
    self,
    name: str,
    schedule: str,
    body: str,
    *,
    replace: bool = False,
    preserve: bool = True,
    enable: bool = True,
) -> Any:
    """Create a scheduled database event (mysql only).

    Args:
        name: Event name; checked with ``_validate_identifier``.
        schedule: Must match ``EVERY <n> <unit>`` with unit one of SECOND,
            MINUTE, HOUR, DAY, WEEK or MONTH (case-insensitive).
        body: SQL placed after ``DO``; vetted by
            ``_validate_admin_sql_fragment``.
        replace: When true, ``drop_event`` is called first.
        preserve: Selects ``ON COMPLETION PRESERVE`` versus ``NOT PRESERVE``.
        enable: Selects ``ENABLE`` versus ``DISABLE``.

    Returns:
        Whatever ``run_native`` returns for the CREATE statement.

    Raises:
        QueryError: If the dialect is not mysql, the schedule is empty or
            malformed, admin mode is not enabled, or the body fails
            validation.
    """
    event_name = self._validate_identifier(name)
    if self.DIALECT != "mysql":
        raise QueryError("database events are currently supported only for mysql")
    if not (isinstance(schedule, str) and schedule.strip()):
        raise QueryError("schedule must be a non-empty SQL fragment")
    schedule_sql = schedule.strip()
    schedule_ok = re.fullmatch(
        r"EVERY\s+\d+\s+(SECOND|MINUTE|HOUR|DAY|WEEK|MONTH)",
        schedule_sql,
        flags=re.IGNORECASE,
    )
    if schedule_ok is None:
        raise QueryError("schedule must match EVERY <n> <SECOND|MINUTE|HOUR|DAY|WEEK|MONTH>")
    self._require_admin_mode("create_event")
    body_sql = self._validate_admin_sql_fragment(body, field_name="body")
    if replace:
        self.drop_event(event_name, if_exists=True)
    if preserve:
        preserve_sql = "ON COMPLETION PRESERVE"
    else:
        preserve_sql = "ON COMPLETION NOT PRESERVE"
    if enable:
        status_sql = "ENABLE"
    else:
        status_sql = "DISABLE"
    clauses = [
        f"CREATE EVENT {self._quote(event_name)}",
        f"ON SCHEDULE {schedule_sql}",
        preserve_sql,
        status_sql,
        f"DO {body_sql}",
    ]
    statement = " ".join(clauses)  # nosec B608
    return self.run_native(statement)

def drop_event(self, name: str, *, if_exists: bool = True) -> Any:
    """Issue ``DROP EVENT`` for ``name`` (mysql only).

    Args:
        name: Event name; checked with ``_validate_identifier``.
        if_exists: When true, ``IF EXISTS`` is added to the statement.

    Returns:
        Whatever ``run_native`` returns for the statement.

    Raises:
        QueryError: When the adapter dialect is not mysql.
    """
    # NOTE(review): unlike create_event, this does not call
    # _require_admin_mode - confirm that is intentional.
    event_name = self._validate_identifier(name)
    if self.DIALECT != "mysql":
        raise QueryError("database events are currently supported only for mysql")
    guard = "IF EXISTS " if if_exists else ""
    statement = f"DROP EVENT {guard}{self._quote(event_name)}"  # nosec B608
    return self.run_native(statement)
25 changes: 15 additions & 10 deletions DBDuck/adapters/mssql_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import re
from typing import Any, Mapping

from sqlalchemy import text

from ..core.exceptions import QueryError
from ._sqlalchemy_adapter import SQLAlchemyAdapter

Expand All @@ -28,16 +30,17 @@ def _type_for_value(self, value: Any) -> str:
return "NVARCHAR(255)"

def _ensure_table(self, entity: str, data: Mapping[str, Any]) -> None:
self._validate_identifier(entity)
quoted_table = self._quote(entity)
has_explicit_id = any(k.lower() == "id" for k in data)
cols = [] if has_explicit_id else [self._pk_column_sql()]
for key, value in data.items():
cols.append(f"{self._quote(key)} {self._type_for_value(value)}")
safe_entity = entity.replace("'", "''")
sql = (
f"IF OBJECT_ID(N'{safe_entity}', N'U') IS NULL "
f"BEGIN CREATE TABLE {quoted_table} ({', '.join(cols)}) END"
)
check = text("SELECT OBJECT_ID(:tname, N'U') AS oid")
rows = self.run_native(check, params={"tname": entity})
if rows and rows[0].get("oid") is not None:
return
sql = f"CREATE TABLE {quoted_table} ({', '.join(cols)})" # nosec B608
self.run_native(sql)

def _render_metric_sql(self, alias: str, metric: Any) -> str:
Expand Down Expand Up @@ -110,7 +113,7 @@ def paginate(
) -> Any:
return self._find_with_offset(entity, where=where, order_by=order_by, limit=limit, offset=offset)

def convert_uql(self, uql_query: str) -> str:
def convert_uql(self, uql_query: str) -> Any:
uql = uql_query.strip()
upper = uql.upper()
if upper.startswith("FIND "):
Expand All @@ -126,13 +129,15 @@ def convert_uql(self, uql_query: str) -> str:
order_by = match.group(3)
limit = int(match.group(4)) if match.group(4) else None
top_sql = f" TOP {limit}" if limit is not None else ""
sql = f"SELECT{top_sql} * FROM {self._quote(self._validate_identifier(entity))}" # nosec B608
entity_name = self._validate_identifier(entity)
sql = f"SELECT{top_sql} * FROM {self._quote(entity_name)}" # nosec B608
params: dict[str, Any] = {}
if where:
safe_where = self._validate_uql_where_clause(where)
sql += f" WHERE {safe_where}"
where_sql, params = self._build_parameterized_where_from_string(entity_name, where)
sql += where_sql
if order_by:
sql += f" ORDER BY {self._validate_order_by_clause(order_by)}"
return sql
return sql, params
return super().convert_uql(uql_query)

def aggregate(
Expand Down
Loading
Loading