diff --git a/CLAUDE.md b/CLAUDE.md index e5c89fa7..ea744fbc 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -47,6 +47,10 @@ - Fix all issues, including pre-existing ones unrelated to your changes. The codebase must always be warning-free. - **Do not suppress lint warnings with ignore comments by default.** Always try to fix the issue properly first. Only add a suppress comment when the lint rule genuinely does not apply and a proper fix would be worse (e.g. less readable, wrong behavior). Include a clear justification in the comment. +## Audits + +- **When fixing an item from any audit in `audits/`**, always mark it as fixed in the audit file (strikethrough title, add ✅ FIXED, add a note with what was done). + ## Bug Fixes - **When fixing a bug, always search the entire codebase for the same pattern** before considering the fix done. Proactively find and fix all similar occurrences. diff --git a/backend/alembic/versions/e21044117ef6_add_indexes_on_attribute_handler_id_.py b/backend/alembic/versions/e21044117ef6_add_indexes_on_attribute_handler_id_.py new file mode 100644 index 00000000..8e6d0c1a --- /dev/null +++ b/backend/alembic/versions/e21044117ef6_add_indexes_on_attribute_handler_id_.py @@ -0,0 +1,36 @@ +"""add indexes on attribute handler_id enabled and action handler_id + +Revision ID: e21044117ef6 +Revises: 96db4b158ec4 +Create Date: 2026-03-09 09:13:01.929799 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'e21044117ef6' +down_revision: Union[str, Sequence[str], None] = '96db4b158ec4' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_index(op.f('ix_actions_handler_id'), 'actions', ['handler_id'], unique=False) + op.create_index(op.f('ix_attributes_enabled'), 'attributes', ['enabled'], unique=False) + op.create_index(op.f('ix_attributes_handler_id'), 'attributes', ['handler_id'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_attributes_handler_id'), table_name='attributes') + op.drop_index(op.f('ix_attributes_enabled'), table_name='attributes') + op.drop_index(op.f('ix_actions_handler_id'), table_name='actions') + # ### end Alembic commands ### diff --git a/backend/app/api/actions.py b/backend/app/api/actions.py index 40d8a058..53561999 100644 --- a/backend/app/api/actions.py +++ b/backend/app/api/actions.py @@ -4,8 +4,9 @@ from app.dependencies import CurrentUser, DbSession, HandlerManagerDep from app.models.action import Action from app.schemas.action import ActionCreate, ActionRead, ActionUpdate, ExecuteActionRequest -from app.socketio_app import sio +from app.socketio_app import emit_mutate from app.utils.action_params import merge_params +from app.utils.db import get_or_404 router = APIRouter(prefix="/actions", tags=["actions"]) @@ -18,11 +19,7 @@ async def list_actions(db: DbSession, _current_user: CurrentUser): @router.get("/{action_id}", response_model=ActionRead) async def get_action(action_id: int, db: DbSession, _current_user: CurrentUser): - result = await db.execute(select(Action).where(Action.id == action_id)) - action = result.scalar_one_or_none() - if not action: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Action not found") - return action + return await get_or_404(db, Action, action_id) @router.post("/", response_model=ActionRead, status_code=status.HTTP_201_CREATED) @@ -31,8 +28,7 @@ async def create_action(body: ActionCreate, db: DbSession, _current_user: Curren db.add(action) await db.commit() await db.refresh(action) - await sio.emit("mutate", {"entity": "actions"}) - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("actions", "handlers") return action @@ -40,31 +36,23 @@ async def create_action(body: ActionCreate, db: DbSession, _current_user: Curren async def update_action( action_id: int, body: ActionUpdate, db: DbSession, manager: HandlerManagerDep, _current_user: CurrentUser ): - result = await db.execute(select(Action).where(Action.id == action_id)) - action = result.scalar_one_or_none() - if not action: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Action not found") + action = await get_or_404(db, Action, action_id) for field, value in body.model_dump(exclude_unset=True).items(): setattr(action, field, value) await db.commit() await db.refresh(action) manager.invalidate_action_cache(action_id) - await sio.emit("mutate", {"entity": "actions"}) - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("actions", "handlers") return action @router.delete("/{action_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_action(action_id: int, db: DbSession, manager: HandlerManagerDep, _current_user: CurrentUser): - result = await db.execute(select(Action).where(Action.id == action_id)) - action = result.scalar_one_or_none() - if not action: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Action not found") + action = await get_or_404(db, Action, action_id) await db.delete(action) await db.commit() manager.invalidate_action_cache(action_id) - await sio.emit("mutate", {"entity": "actions"}) - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("actions", "handlers") @router.post("/{action_id}/execute") @@ -75,10 +63,7 @@ async def execute_action( _current_user: CurrentUser, body: ExecuteActionRequest | None = None, ): - result = await db.execute(select(Action).where(Action.id == action_id)) - action = result.scalar_one_or_none() - if not action: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Action not found") + action = await get_or_404(db, Action, action_id) if not action.handler_id: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Action has no handler") message = action.message diff --git a/backend/app/api/attributes.py b/backend/app/api/attributes.py index 4c5ab433..257ee01c 100644 --- a/backend/app/api/attributes.py +++ b/backend/app/api/attributes.py @@ -1,7 +1,7 @@ import datetime import logging -from fastapi import APIRouter, HTTPException, Query, status +from fastapi import APIRouter, Query, status from sqlalchemy import select, text from app.dependencies import CurrentUser, DbSession, HandlerManagerDep @@ -19,7 +19,8 @@ OutOfRangeDeleteRequest, OutOfRangeDeleteResult, ) -from app.socketio_app import sio +from app.socketio_app import emit_mutate +from app.utils.db import get_or_404 logger = logging.getLogger(__name__) router = APIRouter(prefix="/attributes", tags=["attributes"]) @@ -45,10 +46,8 @@ async def _refresh_daily_stats(dates: list[datetime.date], tz_offset: int = 0) - async with engine.connect() as conn: await conn.execution_options(isolation_level="AUTOCOMMIT") await conn.execute( - text( - f"CALL refresh_continuous_aggregate('daily_stats', " - f"'{start.isoformat()}'::timestamptz, '{end.isoformat()}'::timestamptz)" - ) + text("CALL refresh_continuous_aggregate('daily_stats', :start::timestamptz, :end::timestamptz)"), + {"start": start.isoformat(), "end": end.isoformat()}, ) except Exception: logger.warning("Could not refresh daily_stats continuous aggregate", exc_info=True) @@ -109,8 +108,7 @@ async def reorder_attributes(body: AttributeReorderRequest, db: DbSession, _curr if attr := attrs_by_id.get(item.id): attr.order = item.order await db.commit() - await sio.emit("mutate", {"entity": "attributes"}) - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("attributes", "handlers") def _date_range(date: datetime.date, tz_offset: int = 0) -> dict: @@ -134,7 +132,6 @@ async def _count_out_of_range_by_date( tz_offset: int = 0, ) -> list[OutOfRangeByDateItem]: """Count out-of-range values per attribute for a single date using caller-supplied bounds.""" - date_params = _date_range(date, tz_offset) items = [] for bounds in body_attributes: attr = db_attributes.get(bounds.attribute_id) @@ -143,17 +140,10 @@ async def _count_out_of_range_by_date( if bounds.min_value is None and bounds.max_value is None: continue - range_condition = _build_data_range_condition(bounds.min_value, bounds.max_value) - count_result = await db.execute( - text( - "SELECT count(*) FROM data_units " - "WHERE attribute_id = :attr_id " - f"AND timestamp >= :date_start AND timestamp < :date_end AND ({range_condition})" - ), - {"attr_id": bounds.attribute_id, **date_params, "min_val": bounds.min_value, "max_val": bounds.max_value}, + count = await _count_data_out_of_range( + db, bounds.attribute_id, date, bounds.min_value, bounds.max_value, tz_offset ) - count = count_result.scalar() - if count and count > 0: + if count > 0: items.append(OutOfRangeByDateItem(attribute_id=attr.id, attribute_name=attr.name, count=count)) return items @@ -179,7 +169,6 @@ async def delete_out_of_range_by_date(body: OutOfRangeByDateDeleteRequest, db: D result = await db.execute(select(Attribute).where(Attribute.id.in_(ids))) db_attributes = {a.id: a for a in result.scalars()} - date_params = _date_range(body.date, body.tz_offset) total_deleted = 0 for bounds in body.attributes: if bounds.attribute_id not in db_attributes: @@ -187,21 +176,9 @@ async def delete_out_of_range_by_date(body: OutOfRangeByDateDeleteRequest, db: D if bounds.min_value is None and bounds.max_value is None: continue - range_condition = _build_data_range_condition(bounds.min_value, bounds.max_value) - delete_result = await db.execute( - text( - "DELETE FROM data_units " - "WHERE attribute_id = :attr_id " - f"AND timestamp >= :date_start AND timestamp < :date_end AND ({range_condition})" - ), - { - "attr_id": bounds.attribute_id, - **date_params, - "min_val": bounds.min_value, - "max_val": bounds.max_value, - }, + total_deleted += await _delete_data_out_of_range( + db, bounds.attribute_id, body.date, bounds.min_value, bounds.max_value, body.tz_offset ) - total_deleted += delete_result.rowcount await db.commit() if total_deleted > 0: @@ -211,11 +188,7 @@ async def delete_out_of_range_by_date(body: OutOfRangeByDateDeleteRequest, db: D @router.get("/{attribute_id}", response_model=AttributeRead) async def get_attribute(attribute_id: int, db: DbSession, _current_user: CurrentUser): - result = await db.execute(select(Attribute).where(Attribute.id == attribute_id)) - attribute = result.scalar_one_or_none() - if not attribute: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Attribute not found") - return attribute + return await get_or_404(db, Attribute, attribute_id) @router.get("/{attribute_id}/value", response_model=AttributeValue | None) @@ -235,35 +208,25 @@ async def create_attribute( await db.commit() await db.refresh(attribute) await manager.register_attribute(attribute) - await sio.emit("mutate", {"entity": "attributes"}) - await sio.emit("mutate", {"entity": "handlers"}) - await sio.emit("mutate", {"entity": "widgets"}) + await emit_mutate("attributes", "handlers", "widgets") return attribute @router.delete("/{attribute_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_attribute(attribute_id: int, db: DbSession, manager: HandlerManagerDep, _current_user: CurrentUser): - result = await db.execute(select(Attribute).where(Attribute.id == attribute_id)) - attribute = result.scalar_one_or_none() - if not attribute: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Attribute not found") + attribute = await get_or_404(db, Attribute, attribute_id) await manager.unregister_attribute(attribute.handler_id, attribute.name) await db.delete(attribute) await db.commit() - await sio.emit("mutate", {"entity": "attributes"}) - await sio.emit("mutate", {"entity": "handlers"}) - await sio.emit("mutate", {"entity": "widgets"}) + await emit_mutate("attributes", "handlers", "widgets") @router.patch("/{attribute_id}", response_model=AttributeRead) async def update_attribute( attribute_id: int, body: AttributeUpdate, db: DbSession, manager: HandlerManagerDep, _current_user: CurrentUser ): - result = await db.execute(select(Attribute).where(Attribute.id == attribute_id)) - attribute = result.scalar_one_or_none() - if not attribute: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Attribute not found") + attribute = await get_or_404(db, Attribute, attribute_id) for field, value in body.model_dump(exclude_unset=True).items(): setattr(attribute, field, value) @@ -278,30 +241,66 @@ async def update_attribute( tracker.min_value = attribute.min_value tracker.max_value = attribute.max_value - await sio.emit("mutate", {"entity": "attributes"}) - await sio.emit("mutate", {"entity": "handlers"}) - await sio.emit("mutate", {"entity": "widgets"}) + await emit_mutate("attributes", "handlers", "widgets") return attribute -def _build_data_range_condition(min_value: float | None, max_value: float | None) -> str: - """Build SQL condition for out-of-range values in data_units.""" +def _build_range_condition( + min_value: float | None, max_value: float | None, min_col: str = "value", max_col: str = "value" +) -> str: + """Build SQL condition for out-of-range values. + + For data_units: min_col="value", max_col="value" (default). + For daily_stats: min_col="min_value", max_col="max_value". + """ parts = [] if min_value is not None: - parts.append("value < :min_val") + parts.append(f"{min_col} < :min_val") if max_value is not None: - parts.append("value > :max_val") + parts.append(f"{max_col} > :max_val") return " OR ".join(parts) -def _build_stats_range_condition(min_value: float | None, max_value: float | None) -> str: - """Build SQL condition for daily_stats rows that may contain out-of-range values.""" - parts = [] - if min_value is not None: - parts.append("min_value < :min_val") - if max_value is not None: - parts.append("max_value > :max_val") - return " OR ".join(parts) +async def _count_data_out_of_range( + db: DbSession, + attr_id: int, + date: datetime.date, + min_value: float | None, + max_value: float | None, + tz_offset: int = 0, +) -> int: + """Count out-of-range data_units for an attribute on a single date.""" + condition = _build_range_condition(min_value, max_value) + result = await db.execute( + text( + "SELECT count(*) FROM data_units " + "WHERE attribute_id = :attr_id " + f"AND timestamp >= :date_start AND timestamp < :date_end AND ({condition})" + ), + {"attr_id": attr_id, **_date_range(date, tz_offset), "min_val": min_value, "max_val": max_value}, + ) + return result.scalar() or 0 + + +async def _delete_data_out_of_range( + db: DbSession, + attr_id: int, + date: datetime.date, + min_value: float | None, + max_value: float | None, + tz_offset: int = 0, +) -> int: + """Delete out-of-range data_units for an attribute on a single date.""" + condition = _build_range_condition(min_value, max_value) + result = await db.execute( + text( + "DELETE FROM data_units " + "WHERE attribute_id = :attr_id " + f"AND timestamp >= :date_start AND timestamp < :date_end AND ({condition})" + ), + {"attr_id": attr_id, **_date_range(date, tz_offset), "min_val": min_value, "max_val": max_value}, + ) + return result.rowcount @router.get("/{attribute_id}/out-of-range", response_model=list[OutOfRangeDayItem]) @@ -311,16 +310,13 @@ async def check_out_of_range( _current_user: CurrentUser, tz_offset: int = Query(0, ge=-720, le=840, description="Browser timezone offset in minutes (JS getTimezoneOffset)"), ): - result = await db.execute(select(Attribute).where(Attribute.id == attribute_id)) - attribute = result.scalar_one_or_none() - if not attribute: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Attribute not found") + attribute = await get_or_404(db, Attribute, attribute_id) if attribute.min_value is None and attribute.max_value is None: return [] # Step 1: Find candidate days from daily_stats aggregate - stats_condition = _build_stats_range_condition(attribute.min_value, attribute.max_value) + stats_condition = _build_range_condition(attribute.min_value, attribute.max_value, "min_value", "max_value") stats_result = await db.execute( text( "SELECT bucket::date AS date FROM daily_stats " @@ -345,20 +341,12 @@ async def check_out_of_range( candidate_dates.add(d - datetime.timedelta(days=1)) # Step 2: Count exact violations per candidate day (one tight-range query per day for chunk exclusion) - range_condition = _build_data_range_condition(attribute.min_value, attribute.max_value) results = [] for date in sorted(candidate_dates): - date_params = _date_range(date, tz_offset) - count_result = await db.execute( - text( - "SELECT count(*) FROM data_units " - "WHERE attribute_id = :attr_id " - f"AND timestamp >= :date_start AND timestamp < :date_end AND ({range_condition})" - ), - {"attr_id": attribute_id, **date_params, "min_val": attribute.min_value, "max_val": attribute.max_value}, + count = await _count_data_out_of_range( + db, attribute_id, date, attribute.min_value, attribute.max_value, tz_offset ) - count = count_result.scalar() - if count and count > 0: + if count > 0: results.append(OutOfRangeDayItem(date=date, count=count)) return results @@ -367,10 +355,7 @@ async def check_out_of_range( async def delete_out_of_range( attribute_id: int, body: OutOfRangeDeleteRequest, db: DbSession, _current_user: CurrentUser ): - result = await db.execute(select(Attribute).where(Attribute.id == attribute_id)) - attribute = result.scalar_one_or_none() - if not attribute: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Attribute not found") + attribute = await get_or_404(db, Attribute, attribute_id) if attribute.min_value is None and attribute.max_value is None: return OutOfRangeDeleteResult(deleted=0) @@ -378,19 +363,11 @@ async def delete_out_of_range( if not body.dates: return OutOfRangeDeleteResult(deleted=0) - range_condition = _build_data_range_condition(attribute.min_value, attribute.max_value) total_deleted = 0 for date in body.dates: - date_params = _date_range(date, body.tz_offset) - delete_result = await db.execute( - text( - "DELETE FROM data_units " - "WHERE attribute_id = :attr_id " - f"AND timestamp >= :date_start AND timestamp < :date_end AND ({range_condition})" - ), - {"attr_id": attribute_id, **date_params, "min_val": attribute.min_value, "max_val": attribute.max_value}, + total_deleted += await _delete_data_out_of_range( + db, attribute_id, date, attribute.min_value, attribute.max_value, body.tz_offset ) - total_deleted += delete_result.rowcount await db.commit() if total_deleted > 0: await _refresh_daily_stats(body.dates, body.tz_offset) diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index 58e9286c..63587bc1 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -17,7 +17,8 @@ revoke_refresh_token, rotate_refresh_token, ) -from app.socketio_app import sio +from app.socketio_app import emit_mutate +from app.utils.db import get_or_404 from app.utils.security import hash_password, verify_password router = APIRouter(prefix="/auth", tags=["auth"]) @@ -179,7 +180,7 @@ async def create_user(body: UserCreate, db: DbSession, current_user: CurrentUser db.add(user) await db.commit() await db.refresh(user) - await sio.emit("mutate", {"entity": "auth"}) + await emit_mutate("auth") return user @@ -188,10 +189,7 @@ async def update_user(user_id: int, body: UserUpdate, db: DbSession, current_use if current_user.role != "admin": raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin only") - result = await db.execute(select(User).where(User.id == user_id)) - user = result.scalar_one_or_none() - if not user: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + user = await get_or_404(db, User, user_id) update_data = body.model_dump(exclude_unset=True) @@ -216,7 +214,7 @@ async def update_user(user_id: int, body: UserUpdate, db: DbSession, current_use await db.commit() await db.refresh(user) - await sio.emit("mutate", {"entity": "auth"}) + await emit_mutate("auth") return user @@ -228,12 +226,9 @@ async def delete_user(user_id: int, db: DbSession, current_user: CurrentUser): if current_user.id == user_id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Cannot delete yourself") - result = await db.execute(select(User).where(User.id == user_id)) - user = result.scalar_one_or_none() - if not user: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + user = await get_or_404(db, User, user_id) await revoke_all_user_tokens(db, user_id) await db.delete(user) await db.commit() - await sio.emit("mutate", {"entity": "auth"}) + await emit_mutate("auth") diff --git a/backend/app/api/handlers.py b/backend/app/api/handlers.py index 9c0647b7..4542474e 100644 --- a/backend/app/api/handlers.py +++ b/backend/app/api/handlers.py @@ -36,7 +36,8 @@ SerialPortInfo, SetupRequest, ) -from app.socketio_app import sio +from app.socketio_app import emit_mutate +from app.utils.db import get_or_404 from app.utils.linearize import linearize logger = logging.getLogger(__name__) @@ -258,8 +259,7 @@ async def setup_handler(body: SetupRequest, db: DbSession, manager: HandlerManag await db.commit() await manager.start_handler(handler.id) - await sio.emit("mutate", {"entity": "handlers"}) - await sio.emit("mutate", {"entity": "actions"}) + await emit_mutate("handlers", "actions") return _to_handler_read(handler) @@ -272,7 +272,7 @@ async def reorder_handlers(body: HandlerReorderRequest, db: DbSession, _current_ if handler := handlers_by_id.get(item.id): handler.order = item.order await db.commit() - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("handlers") @router.post("/", response_model=HandlerRead, status_code=status.HTTP_201_CREATED) @@ -285,16 +285,13 @@ async def create_handler(body: HandlerCreate, db: DbSession, manager: HandlerMan await db.refresh(handler, attribute_names=["attributes", "actions"]) await db.commit() await manager.start_handler(handler.id) - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("handlers") return _to_handler_read(handler) @router.get("/{handler_id}", response_model=HandlerRead) async def get_handler(handler_id: int, db: DbSession, _current_user: CurrentUser): - result = await db.execute(select(Handler).where(Handler.id == handler_id).options(*_HANDLER_LOAD_OPTIONS)) - handler = result.scalar_one_or_none() - if not handler: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Handler not found") + handler = await get_or_404(db, Handler, handler_id, options=_HANDLER_LOAD_OPTIONS) return _to_handler_read(handler) @@ -302,10 +299,7 @@ async def get_handler(handler_id: int, db: DbSession, _current_user: CurrentUser async def update_handler( handler_id: int, body: HandlerUpdate, db: DbSession, manager: HandlerManagerDep, _current_user: CurrentUser ): - result = await db.execute(select(Handler).where(Handler.id == handler_id).options(*_HANDLER_LOAD_OPTIONS)) - handler = result.scalar_one_or_none() - if not handler: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Handler not found") + handler = await get_or_404(db, Handler, handler_id, options=_HANDLER_LOAD_OPTIONS) update_data = body.model_dump(exclude_unset=True) if "options" in update_data and "config" in update_data["options"]: @@ -325,7 +319,7 @@ async def update_handler( await db.commit() await db.refresh(handler, attribute_names=["attributes", "actions"]) - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("handlers") if needs_restart: await manager.restart_handler(handler_id) @@ -335,41 +329,31 @@ async def update_handler( @router.delete("/{handler_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_handler(handler_id: int, db: DbSession, manager: HandlerManagerDep, _current_user: CurrentUser): - result = await db.execute(select(Handler).where(Handler.id == handler_id)) - handler = result.scalar_one_or_none() - if not handler: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Handler not found") + handler = await get_or_404(db, Handler, handler_id) await manager.remove_handler(handler_id) await db.delete(handler) await db.commit() - await sio.emit("mutate", {"entity": "handlers"}) - await sio.emit("mutate", {"entity": "actions"}) + await emit_mutate("handlers", "actions") @router.post("/{handler_id}/start", status_code=status.HTTP_200_OK) async def start_handler(handler_id: int, db: DbSession, manager: HandlerManagerDep, _current_user: CurrentUser): - result = await db.execute(select(Handler).where(Handler.id == handler_id)) - handler = result.scalar_one_or_none() - if not handler: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Handler not found") + handler = await get_or_404(db, Handler, handler_id) handler.enabled = True await db.commit() await manager.start_handler(handler_id) - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("handlers") return {"ok": True} @router.post("/{handler_id}/stop", status_code=status.HTTP_200_OK) async def stop_handler(handler_id: int, db: DbSession, manager: HandlerManagerDep, _current_user: CurrentUser): - result = await db.execute(select(Handler).where(Handler.id == handler_id)) - handler = result.scalar_one_or_none() - if not handler: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Handler not found") + handler = await get_or_404(db, Handler, handler_id) handler.enabled = False await db.commit() await manager.stop_handler(handler_id) - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("handlers") return {"ok": True} @@ -390,10 +374,7 @@ async def available_attributes(handler_id: int, manager: HandlerManagerDep, _cur @router.get("/{handler_id}/controls", response_model=list[ResolvedControl]) async def handler_controls(handler_id: int, db: DbSession, manager: HandlerManagerDep, _current_user: CurrentUser): - result = await db.execute(select(Handler).where(Handler.id == handler_id).options(*_HANDLER_LOAD_OPTIONS)) - handler = result.scalar_one_or_none() - if not handler: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Handler not found") + handler = await get_or_404(db, Handler, handler_id, options=_HANDLER_LOAD_OPTIONS) cls = get_handler_class(handler.type) if not cls or not cls.known_controls: diff --git a/backend/app/api/logs.py b/backend/app/api/logs.py index 358d33c4..e354d9a5 100644 --- a/backend/app/api/logs.py +++ b/backend/app/api/logs.py @@ -6,7 +6,7 @@ from app.dependencies import CurrentUser, DbSession from app.models.logging_message import LoggingMessage from app.schemas.logging_message import LoggingMessageRead -from app.socketio_app import sio +from app.socketio_app import emit_mutate router = APIRouter(prefix="/logs", tags=["logs"]) @@ -44,4 +44,4 @@ async def list_logs( async def clear_logs(db: DbSession, _current_user: CurrentUser): await db.execute(delete(LoggingMessage)) await db.commit() - await sio.emit("mutate", {"entity": "logs"}) + await emit_mutate("logs") diff --git a/backend/app/api/sub_workflows.py b/backend/app/api/sub_workflows.py index 019703b4..22baa105 100644 --- a/backend/app/api/sub_workflows.py +++ b/backend/app/api/sub_workflows.py @@ -12,7 +12,8 @@ SubWorkflowUpdate, ) from app.schemas.workflow import WorkflowData -from app.socketio_app import sio +from app.socketio_app import emit_mutate +from app.utils.db import get_or_404 logger = logging.getLogger(__name__) @@ -27,11 +28,7 @@ async def list_sub_workflows(db: DbSession, _current_user: CurrentUser): @router.get("/{sub_workflow_id}", response_model=SubWorkflowRead) async def get_sub_workflow(sub_workflow_id: int, db: DbSession, _current_user: CurrentUser): - result = await db.execute(select(SubWorkflow).where(SubWorkflow.id == sub_workflow_id)) - sw = result.scalar_one_or_none() - if not sw: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Sub-workflow not found") - return sw + return await get_or_404(db, SubWorkflow, sub_workflow_id, detail="Sub-workflow not found") @router.post("/", response_model=SubWorkflowRead, status_code=status.HTTP_201_CREATED) @@ -40,21 +37,18 @@ async def create_sub_workflow(body: SubWorkflowCreate, db: DbSession, current_us db.add(sw) await db.commit() await db.refresh(sw) - await sio.emit("mutate", {"entity": "sub_workflows"}) + await emit_mutate("sub_workflows") return sw @router.patch("/{sub_workflow_id}", response_model=SubWorkflowRead) async def update_sub_workflow(sub_workflow_id: int, body: SubWorkflowUpdate, db: DbSession, _current_user: CurrentUser): - result = await db.execute(select(SubWorkflow).where(SubWorkflow.id == sub_workflow_id)) - sw = result.scalar_one_or_none() - if not sw: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Sub-workflow not found") + sw = await get_or_404(db, SubWorkflow, sub_workflow_id, detail="Sub-workflow not found") for field, value in body.model_dump(exclude_unset=True).items(): setattr(sw, field, value) await db.commit() await db.refresh(sw) - await sio.emit("mutate", {"entity": "sub_workflows"}) + await emit_mutate("sub_workflows") return sw @@ -66,10 +60,7 @@ async def save_sub_workflow_graph( _current_user: CurrentUser, manager: HandlerManagerDep, ): - result = await db.execute(select(SubWorkflow).where(SubWorkflow.id == sub_workflow_id)) - sw = result.scalar_one_or_none() - if not sw: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Sub-workflow not found") + sw = await get_or_404(db, SubWorkflow, sub_workflow_id, detail="Sub-workflow not found") graph_dict = body.model_dump(by_alias=True) @@ -102,8 +93,7 @@ async def save_sub_workflow_graph( # Rebuild the main workflow so it picks up updated sub-workflow graphs await manager.rebuild_workflow(db) - await sio.emit("mutate", {"entity": "sub_workflows"}) - await sio.emit("mutate", {"entity": "workflow"}) + await emit_mutate("sub_workflows", "workflow") logger.info("Sub-workflow %d graph saved", sub_workflow_id) return graph_dict @@ -112,10 +102,7 @@ async def save_sub_workflow_graph( async def delete_sub_workflow( sub_workflow_id: int, db: DbSession, _current_user: CurrentUser, manager: HandlerManagerDep ): - result = await db.execute(select(SubWorkflow).where(SubWorkflow.id == sub_workflow_id)) - sw = result.scalar_one_or_none() - if not sw: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Sub-workflow not found") + sw = await get_or_404(db, SubWorkflow, sub_workflow_id, detail="Sub-workflow not found") # Check if any workflow or other sub-workflow references this one refs = await _find_references(db, sub_workflow_id) @@ -130,8 +117,7 @@ async def delete_sub_workflow( await manager.rebuild_workflow(db) - await sio.emit("mutate", {"entity": "sub_workflows"}) - await sio.emit("mutate", {"entity": "workflow"}) + await emit_mutate("sub_workflows", "workflow") async def _find_references(db: DbSession, sub_workflow_id: int) -> list[str]: diff --git a/backend/app/api/widgets.py b/backend/app/api/widgets.py index 64af21b0..cb30cdb2 100644 --- a/backend/app/api/widgets.py +++ b/backend/app/api/widgets.py @@ -18,8 +18,9 @@ WidgetReorderRequest, WidgetUpdate, ) -from app.socketio_app import sio +from app.socketio_app import emit_mutate from app.utils.action_params import merge_params +from app.utils.db import get_or_404 router = APIRouter(prefix="/widgets", tags=["widgets"]) @@ -50,16 +51,13 @@ async def create_widget(body: WidgetCreate, db: DbSession, _current_user: Curren db.add(widget) await db.commit() await db.refresh(widget) - await sio.emit("mutate", {"entity": "widgets"}) + await emit_mutate("widgets") return widget @router.patch("/{widget_id}", response_model=WidgetRead) async def update_widget(widget_id: int, body: WidgetUpdate, db: DbSession, _current_user: CurrentUser): - result = await db.execute(select(Widget).where(Widget.id == widget_id)) - widget = result.scalar_one_or_none() - if not widget: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Widget not found") + widget = await get_or_404(db, Widget, widget_id) widget.name = body.name widget.icon = body.icon if body.attribute_id is not None: @@ -68,19 +66,16 @@ async def update_widget(widget_id: int, body: WidgetUpdate, db: DbSession, _curr widget.config = body.config await db.commit() await db.refresh(widget) - await sio.emit("mutate", {"entity": "widgets"}) + await emit_mutate("widgets") return widget @router.delete("/{widget_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_widget(widget_id: int, db: DbSession, _current_user: CurrentUser): - result = await db.execute(select(Widget).where(Widget.id == widget_id)) - widget = result.scalar_one_or_none() - if not widget: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Widget not found") + widget = await get_or_404(db, Widget, widget_id) await db.delete(widget) await db.commit() - await sio.emit("mutate", {"entity": "widgets"}) + await emit_mutate("widgets") @router.post("/reorder") @@ -93,7 +88,7 @@ async def reorder_widgets(body: WidgetReorderRequest, db: DbSession, _current_us update(Widget).where(Widget.id.in_(ids)).values(order=case(order_map, value=Widget.id, else_=Widget.order)) ) await db.commit() - await sio.emit("mutate", {"entity": "widgets"}) + await emit_mutate("widgets") return {"ok": True} @@ -108,9 +103,8 @@ async def toggle_switch( manager: HandlerManagerDep, _current_user: CurrentUser, ): - result = await db.execute(select(Widget).where(Widget.id == widget_id)) - widget = result.scalar_one_or_none() - if not widget or widget.type != "switch": + widget = await get_or_404(db, Widget, widget_id) + if widget.type != "switch": raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Switch widget not found") config = widget.config @@ -142,9 +136,8 @@ async def set_slider( manager: HandlerManagerDep, _current_user: CurrentUser, ): - result = await db.execute(select(Widget).where(Widget.id == widget_id)) - widget = result.scalar_one_or_none() - if not widget or widget.type != "slider": + widget = await get_or_404(db, Widget, widget_id) + if widget.type != "slider": raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Slider widget not found") config = widget.config @@ -175,9 +168,8 @@ async def execute_button( manager: HandlerManagerDep, _current_user: CurrentUser, ): - result = await db.execute(select(Widget).where(Widget.id == widget_id)) - widget = result.scalar_one_or_none() - if not widget or widget.type != "button": + widget = await get_or_404(db, Widget, widget_id) + if widget.type != "button": raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Button widget not found") config = widget.config @@ -205,14 +197,12 @@ async def execute_button( @router.post("/from-control", response_model=WidgetRead, status_code=status.HTTP_201_CREATED) async def create_widget_from_control(body: CreateWidgetFromControlRequest, db: DbSession, _current_user: CurrentUser): - result = await db.execute( - select(Handler) - .where(Handler.id == body.handler_id) - .options(selectinload(Handler.attributes), selectinload(Handler.actions)) + handler = await get_or_404( + db, + Handler, + body.handler_id, + options=[selectinload(Handler.attributes), selectinload(Handler.actions)], ) - handler = result.scalar_one_or_none() - if not handler: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Handler not found") cls = get_handler_class(handler.type) if not cls: @@ -271,7 +261,7 @@ async def create_widget_from_control(body: CreateWidgetFromControlRequest, db: D db.add(widget) await db.commit() await db.refresh(widget) - await sio.emit("mutate", {"entity": "widgets"}) + await emit_mutate("widgets") return widget diff --git a/backend/app/models/action.py b/backend/app/models/action.py index da8e46ed..7ff8d550 100644 --- a/backend/app/models/action.py +++ b/backend/app/models/action.py @@ -10,6 +10,8 @@ class Action(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(String(150), index=True) message: Mapped[str] = mapped_column(String(500)) - handler_id: Mapped[int | None] = mapped_column(ForeignKey("handlers.id", ondelete="CASCADE"), nullable=True) + handler_id: Mapped[int | None] = mapped_column( + ForeignKey("handlers.id", ondelete="CASCADE"), nullable=True, index=True + ) handler: Mapped["Handler | None"] = relationship(back_populates="actions") # noqa: F821 diff --git a/backend/app/models/attribute.py b/backend/app/models/attribute.py index 9272e097..82f575b2 100644 --- a/backend/app/models/attribute.py +++ b/backend/app/models/attribute.py @@ -9,8 +9,8 @@ class Attribute(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(String(150), index=True) - handler_id: Mapped[int] = mapped_column(ForeignKey("handlers.id")) - enabled: Mapped[bool] = mapped_column(Boolean, default=True) + handler_id: Mapped[int] = mapped_column(ForeignKey("handlers.id"), index=True) + enabled: Mapped[bool] = mapped_column(Boolean, default=True, index=True) unit: Mapped[str | None] = mapped_column(String(50), nullable=True) label: Mapped[str | None] = mapped_column(String(150), nullable=True) icon: Mapped[str | None] = mapped_column(String(100), nullable=True) diff --git a/backend/app/nodes/graph.py b/backend/app/nodes/graph.py index 2bb0c7df..5998e4a1 100644 --- a/backend/app/nodes/graph.py +++ b/backend/app/nodes/graph.py @@ -57,6 +57,34 @@ async def execute(self) -> None: await self.node.execute() +def _find_cycle(adj: dict[str, list[str]], start_nodes: set[str]) -> str | None: + """Iterative DFS cycle detection. Returns the node key involved in a cycle, or None.""" + visited: set[str] = set() + in_stack: set[str] = set() + + for start in start_nodes: + if start in visited: + continue + stack: list[tuple[str, int]] = [(start, 0)] + while stack: + key, idx = stack[-1] + if idx == 0: + visited.add(key) + in_stack.add(key) + neighbors = adj.get(key, []) + if idx < len(neighbors): + stack[-1] = (key, idx + 1) + neighbor = neighbors[idx] + if neighbor in in_stack: + return neighbor + if neighbor not in visited: + stack.append((neighbor, 0)) + else: + in_stack.discard(key) + stack.pop() + return None + + class NodeGraph: """Builds and executes a workflow graph from ReactFlow-compatible JSON.""" @@ -204,29 +232,9 @@ def _detect_cycles(self, edges_data: list[dict]) -> None: if source and target and source in self._nodes and target in self._nodes: adj.setdefault(source, []).append(target) - visited: set[str] = set() - in_stack: set[str] = set() - - for start in self._nodes: - if start in visited: - continue - stack: list[tuple[str, int]] = [(start, 0)] - while stack: - node_id, idx = stack[-1] - if idx == 0: - visited.add(node_id) - in_stack.add(node_id) - neighbors = adj.get(node_id, []) - if idx < len(neighbors): - stack[-1] = (node_id, idx + 1) - neighbor = neighbors[idx] - if neighbor in in_stack: - raise CycleDetectedError(f"Cycle detected involving node '{neighbor}'", node_id=neighbor) - if neighbor not in visited: - stack.append((neighbor, 0)) - else: - in_stack.discard(node_id) - stack.pop() + cycle_node = _find_cycle(adj, set(self._nodes)) + if cycle_node is not None: + raise CycleDetectedError(f"Cycle detected involving node '{cycle_node}'", node_id=cycle_node) async def execute_handler_listeners(self, handler_id: int) -> None: """Execute all HandlerListener nodes matching the given handler_id.""" @@ -276,29 +284,9 @@ def detect_sub_workflow_recursion( for targets in adj.values(): all_keys.update(targets) - visited: set[str] = set() - in_stack: set[str] = set() - - for start in all_keys: - if start in visited: - continue - stack: list[tuple[str, int]] = [(start, 0)] - while stack: - key, idx = stack[-1] - if idx == 0: - visited.add(key) - in_stack.add(key) - neighbors = adj.get(key, []) - if idx < len(neighbors): - stack[-1] = (key, idx + 1) - neighbor = neighbors[idx] - if neighbor in in_stack: - raise SubWorkflowRecursionError(f"Recursive sub-workflow reference detected: {neighbor}") - if neighbor not in visited: - stack.append((neighbor, 0)) - else: - in_stack.discard(key) - stack.pop() + cycle_node = _find_cycle(adj, all_keys) + if cycle_node is not None: + raise SubWorkflowRecursionError(f"Recursive sub-workflow reference detected: {cycle_node}") def _extract_sub_workflow_refs(graph_data: dict) -> list[int]: diff --git a/backend/app/nodes/logger_node.py b/backend/app/nodes/logger_node.py index 8cfc50a5..03be6e0a 100644 --- a/backend/app/nodes/logger_node.py +++ b/backend/app/nodes/logger_node.py @@ -4,7 +4,7 @@ from app.models.logging_message import LoggingMessage from app.nodes.base import AbstractNode from app.nodes.ports import event_port, value_port -from app.socketio_app import sio +from app.socketio_app import emit_mutate logger = logging.getLogger(__name__) @@ -30,4 +30,4 @@ async def execute(self) -> None: async with self.session_factory() as session: session.add(record) await session.commit() - await sio.emit("mutate", {"entity": "logs"}) + await emit_mutate("logs") diff --git a/backend/app/services/handler_manager.py b/backend/app/services/handler_manager.py index 4977f704..70e370da 100644 --- a/backend/app/services/handler_manager.py +++ b/backend/app/services/handler_manager.py @@ -18,7 +18,7 @@ from app.models.settings import Settings from app.nodes.graph import NodeGraph from app.services.attribute_tracker import AttributeTracker -from app.socketio_app import sio +from app.socketio_app import emit_mutate, sio from app.utils.linearize import linearize logger = logging.getLogger(__name__) @@ -400,7 +400,7 @@ async def start_handler(self, handler_id: int) -> None: self._last_connected_state[handler_id] = False await sio.emit("handler_status", self._status_payload(handler_id, running=True, connected=False)) - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("handlers") async def stop_handler(self, handler_id: int) -> None: handler = self._handlers.get(handler_id) @@ -408,7 +408,7 @@ async def stop_handler(self, handler_id: int) -> None: await handler.stop() self._last_connected_state.pop(handler_id, None) await sio.emit("handler_status", self._status_payload(handler_id, running=False, connected=False)) - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("handlers") async def remove_handler(self, handler_id: int) -> None: """Stop a handler and remove all its in-memory state (trackers, name map, etc.).""" @@ -444,7 +444,7 @@ async def restart_handler(self, handler_id: int) -> None: self._last_connected_state[handler_id] = False await sio.emit("handler_status", self._status_payload(handler_id, running=True, connected=False)) - await sio.emit("mutate", {"entity": "handlers"}) + await emit_mutate("handlers") # --- Action execution --- @@ -526,9 +526,9 @@ async def _log_action( async with self._session_factory() as session: session.add(record) await session.commit() - await sio.emit("mutate", {"entity": "logs"}) + await emit_mutate("logs") except Exception: - logger.debug("Failed to persist action log", exc_info=True) + logger.warning("Failed to persist action log", exc_info=True) # --- Value & stats seeding --- @@ -552,7 +552,7 @@ async def _seed_last_values(self, attr_ids: list[int] | None = None) -> None: ) rows = result.mappings().all() except Exception: - logger.debug("Could not seed last values from data_units, skipping") + logger.debug("Could not seed last values from data_units, skipping", exc_info=True) return for row in rows: @@ -574,7 +574,7 @@ async def _refresh_continuous_aggregate(self) -> None: await conn.execute(text("CALL refresh_continuous_aggregate('daily_stats', NULL, NULL)")) logger.info("Refreshed daily_stats continuous aggregate") except Exception: - logger.debug("Could not refresh daily_stats continuous aggregate, skipping") + logger.debug("Could not refresh daily_stats continuous aggregate, skipping", exc_info=True) async def _seed_daily_stats(self, attr_ids: list[int] | None = None) -> None: """Refresh the continuous aggregate, then seed today's stats into trackers.""" @@ -602,7 +602,7 @@ async def _seed_daily_stats(self, attr_ids: list[int] | None = None) -> None: tracker.seed_stats(date=today, min_val=row["min_value"], max_val=row["max_value"]) except Exception: # Continuous aggregate may not exist (e.g. tests with SQLite) - logger.debug("Could not query daily_stats continuous aggregate, skipping seed") + logger.debug("Could not query daily_stats continuous aggregate, skipping seed", exc_info=True) def get_daily_stats(self) -> dict[int, dict[str, float | None]]: """Return daily stats for all tracked attributes {attr_id: {min, max, stale}}.""" diff --git a/backend/app/socketio_app.py b/backend/app/socketio_app.py index 4bcbf435..50099955 100644 --- a/backend/app/socketio_app.py +++ b/backend/app/socketio_app.py @@ -13,6 +13,12 @@ sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=[]) +async def emit_mutate(*entities: str) -> None: + """Emit mutation invalidation events for one or more entity types.""" + for entity in entities: + await sio.emit("mutate", {"entity": entity}) + + @sio.event async def connect(sid, environ, auth): token = auth.get("token") if auth else None @@ -32,13 +38,13 @@ async def connect(sid, environ, auth): await sio.save_session(sid, {"user_id": user_id}) logger.info("Socket.IO client connected: %s (user_id=%s)", sid, user_id) - await sio.emit("mutate", {"entity": "system"}) + await emit_mutate("system") @sio.event async def disconnect(sid): logger.info("Socket.IO client disconnected: %s", sid) - await sio.emit("mutate", {"entity": "system"}) + await emit_mutate("system") @sio.event diff --git a/backend/app/utils/db.py b/backend/app/utils/db.py new file mode 100644 index 00000000..243241fa --- /dev/null +++ b/backend/app/utils/db.py @@ -0,0 +1,28 @@ +from fastapi import HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.base import ExecutableOption + +from app.models.base import Base + + +async def get_or_404[T: Base]( + db: AsyncSession, + model: type[T], + entity_id: int, + *, + options: list[ExecutableOption] | None = None, + detail: str | None = None, +) -> T: + """Fetch a single entity by primary key or raise HTTP 404.""" + stmt = select(model).where(model.id == entity_id) # type: ignore[attr-defined] + if options: + stmt = stmt.options(*options) + result = await db.execute(stmt) + entity = result.scalar_one_or_none() + if not entity: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail or f"{model.__name__} not found", + ) + return entity diff --git a/backend/tests/test_attribute_tracker.py b/backend/tests/test_attribute_tracker.py index 82e5fa52..47544668 100644 --- a/backend/tests/test_attribute_tracker.py +++ b/backend/tests/test_attribute_tracker.py @@ -1,3 +1,5 @@ +import datetime + from app.services.attribute_tracker import AttributeTracker @@ -75,3 +77,412 @@ def test_numeric_string_filtered_by_bounds(self): tracker = AttributeTracker(1, 1, max_value=50.0) result = tracker.process_value("100") assert result.data_units == [] + + +class TestStoreOnChange: + """Test that only changed values produce data units.""" + + def test_first_value_always_stored(self): + tracker = AttributeTracker(1, 1) + result = tracker.process_value(10.0) + assert len(result.data_units) == 1 + + def test_same_value_skipped(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + result = tracker.process_value(10.0) + assert result.data_units == [] + + def test_different_value_stored(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + result = tracker.process_value(20.0) + assert len(result.data_units) == 1 + assert result.data_units[0]["value"] == 20.0 + + def test_current_value_updated_on_change(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + tracker.process_value(20.0) + assert tracker.current_value == 20.0 + + def test_current_value_unchanged_on_skip(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + tracker.process_value(10.0) + assert tracker.current_value == 10.0 + + def test_last_changed_set_on_first_value(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + assert tracker.last_changed is not None + + def test_last_changed_updated_on_change(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + first_changed = tracker.last_changed + tracker.process_value(10.0) # skip + assert tracker.last_changed == first_changed # unchanged + tracker.process_value(20.0) # change + assert tracker.last_changed > first_changed + + def test_string_store_on_change(self): + tracker = AttributeTracker(1, 1) + tracker.process_value("hello") + result = tracker.process_value("hello") + assert result.data_units == [] + result = tracker.process_value("world") + # 2 units: held "hello" at now-1µs + new "world" + assert len(result.data_units) == 2 + + +class TestStepChartHeldValue: + """Test that a held previous value is emitted at now-1µs before a change.""" + + def test_no_held_value_on_first_change(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + # Change immediately — no skip happened, so no held value + result = tracker.process_value(20.0) + assert len(result.data_units) == 1 + assert result.data_units[0]["value"] == 20.0 + + def test_held_value_after_skip(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + tracker.process_value(10.0) # skip — sets _last_value_save_skipped + result = tracker.process_value(20.0) # change + assert len(result.data_units) == 2 + # First unit is the held previous value + assert result.data_units[0]["value"] == 10.0 + # Second unit is the new value + assert result.data_units[1]["value"] == 20.0 + + def test_held_value_timestamp_offset(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + tracker.process_value(10.0) # skip + result = tracker.process_value(20.0) + held_ts = result.data_units[0]["timestamp"] + new_ts = result.data_units[1]["timestamp"] + assert new_ts - held_ts == datetime.timedelta(microseconds=1) + + def test_no_held_value_on_consecutive_changes(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + tracker.process_value(20.0) # change, no skip before + result = tracker.process_value(30.0) # change again, no skip before + assert len(result.data_units) == 1 + + def test_multiple_skips_then_change(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + tracker.process_value(10.0) # skip + tracker.process_value(10.0) # skip + tracker.process_value(10.0) # skip + result = tracker.process_value(20.0) + assert len(result.data_units) == 2 + assert result.data_units[0]["value"] == 10.0 + assert result.data_units[1]["value"] == 20.0 + + +class TestTrend: + """Test trend calculation from history deque.""" + + def test_no_history(self): + tracker = AttributeTracker(1, 1) + assert tracker.trend == 0 + + def test_single_value(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + assert tracker.trend == 0 + + def test_rising(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + tracker.process_value(20.0) + assert tracker.trend == 1 + + def test_falling(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(20.0) + tracker.process_value(10.0) + assert tracker.trend == -1 + + def test_stable(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + tracker.process_value(10.0) # skipped by store-on-change but still added to history + assert tracker.trend == 0 + + def test_trend_uses_last_two(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + tracker.process_value(20.0) + tracker.process_value(15.0) + assert tracker.trend == -1 # 20 -> 15 + + def test_non_numeric_no_history(self): + tracker = AttributeTracker(1, 1) + tracker.process_value("hello") + tracker.process_value("world") + assert tracker.trend == 0 # strings not added to history + + def test_boolean_added_to_history(self): + # bool is subclass of int, but bounds skip it — history still tracks it + tracker = AttributeTracker(1, 1) + tracker.process_value(False) + tracker.process_value(True) + assert tracker.trend == 1 # False(0) -> True(1) + + +class TestDailyStats: + """Test daily min/max tracking and date rollover.""" + + def test_first_value_sets_stats(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(25.0) + stats = tracker.daily_stats + assert stats["min"] == 25.0 + assert stats["max"] == 25.0 + assert stats["stale"] is False + + def test_min_max_update(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(25.0) + tracker.process_value(10.0) + tracker.process_value(40.0) + stats = tracker.daily_stats + assert stats["min"] == 10.0 + assert stats["max"] == 40.0 + + def test_same_value_does_not_break_stats(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(25.0) + tracker.process_value(25.0) # skipped by store-on-change but stats still track + stats = tracker.daily_stats + assert stats["min"] == 25.0 + assert stats["max"] == 25.0 + + def test_stats_reset_flag_on_date_rollover(self): + tracker = AttributeTracker(1, 1) + result = tracker.process_value(25.0) + assert result.stats_reset is True # first value always resets (no previous date) + + result = tracker.process_value(30.0) + assert result.stats_reset is False # same day + + def test_no_stats_without_value(self): + tracker = AttributeTracker(1, 1) + stats = tracker.daily_stats + assert stats["min"] is None + assert stats["max"] is None + assert stats["stale"] is False # no current value, so not stale + + def test_stale_when_has_value_but_no_stats(self): + tracker = AttributeTracker(1, 1) + tracker.seed_value(42.0) # has a value but no stats date + stats = tracker.daily_stats + assert stats["stale"] is True + + def test_string_values_dont_update_stats(self): + tracker = AttributeTracker(1, 1) + tracker.process_value("hello") + stats = tracker.daily_stats + assert stats["min"] is None + assert stats["max"] is None + + +class TestSeedValue: + """Test seeding tracker state from database on startup.""" + + def test_seed_sets_current_value(self): + tracker = AttributeTracker(1, 1) + tracker.seed_value(42.0) + assert tracker.current_value == 42.0 + + def test_seed_sets_last_changed(self): + tracker = AttributeTracker(1, 1) + ts = datetime.datetime(2026, 1, 1, tzinfo=datetime.UTC) + tracker.seed_value(42.0, last_changed=ts) + assert tracker.last_changed == ts + + def test_seed_adds_to_history(self): + tracker = AttributeTracker(1, 1) + tracker.seed_value(10.0) + tracker.seed_value(20.0) # second seed (shouldn't happen but tests history) + assert tracker.trend == 1 + + def test_seed_non_numeric_no_history(self): + tracker = AttributeTracker(1, 1) + tracker.seed_value("hello") + assert tracker.current_value == "hello" + assert tracker.trend == 0 + + def test_seed_then_process_change(self): + tracker = AttributeTracker(1, 1) + tracker.seed_value(10.0) + result = tracker.process_value(20.0) + assert len(result.data_units) == 1 + assert result.data_units[0]["value"] == 20.0 + + def test_seed_then_process_same(self): + tracker = AttributeTracker(1, 1) + tracker.seed_value(10.0) + result = tracker.process_value(10.0) + assert result.data_units == [] + + +class TestSeedStats: + """Test seeding daily stats from continuous aggregate.""" + + def test_seed_stats_sets_values(self): + tracker = AttributeTracker(1, 1) + today = datetime.date.today() + tracker.seed_stats(today, min_val=5.0, max_val=50.0) + stats = tracker.daily_stats + assert stats["min"] == 5.0 + assert stats["max"] == 50.0 + assert stats["stale"] is False + + def test_seed_stats_stale_if_old_date(self): + tracker = AttributeTracker(1, 1) + old_date = datetime.date(2020, 1, 1) + tracker.seed_stats(old_date, min_val=5.0, max_val=50.0) + stats = tracker.daily_stats + assert stats["min"] is None + assert stats["max"] is None + assert stats["stale"] is True + + def test_process_value_updates_seeded_stats(self): + tracker = AttributeTracker(1, 1) + today = datetime.date.today() + tracker.seed_stats(today, min_val=10.0, max_val=20.0) + tracker.process_value(5.0) # new min + stats = tracker.daily_stats + assert stats["min"] == 5.0 + assert stats["max"] == 20.0 + + def test_process_value_updates_seeded_max(self): + tracker = AttributeTracker(1, 1) + today = datetime.date.today() + tracker.seed_stats(today, min_val=10.0, max_val=20.0) + tracker.process_value(30.0) # new max + stats = tracker.daily_stats + assert stats["min"] == 10.0 + assert stats["max"] == 30.0 + + +class TestTypeCorrection: + """Test type mismatch correction (e.g., bool stored as float in DB).""" + + def test_type_corrected_flag(self): + tracker = AttributeTracker(1, 1) + tracker.seed_value(1.0) # seeded as float from DB + result = tracker.process_value(True) # actual value is bool True (== 1.0) + assert result.type_corrected is True + assert result.data_units == [] # no DB write needed + assert tracker.current_value is True # corrected to bool + + def test_no_correction_same_type(self): + tracker = AttributeTracker(1, 1) + tracker.process_value(10.0) + result = tracker.process_value(10.0) + assert result.type_corrected is False + + def test_float_to_bool_correction(self): + tracker = AttributeTracker(1, 1) + tracker.seed_value(0.0) + result = tracker.process_value(False) # False == 0.0 but different type + assert result.type_corrected is True + assert tracker.current_value is False + + +class TestRounding: + """Test value rounding before store-on-change comparison.""" + + def test_rounding_applied(self): + tracker = AttributeTracker(1, 1, rounding=1) + result = tracker.process_value(10.15) + assert result.data_units[0]["value"] == 10.2 + + def test_rounding_prevents_unnecessary_change(self): + tracker = AttributeTracker(1, 1, rounding=0) + tracker.process_value(10.4) # rounds to 10.0 + result = tracker.process_value(10.3) # also rounds to 10.0 + assert result.data_units == [] + + def test_rounding_zero_decimals(self): + tracker = AttributeTracker(1, 1, rounding=0) + result = tracker.process_value(10.7) + assert result.data_units[0]["value"] == 11.0 + + def test_no_rounding_when_none(self): + tracker = AttributeTracker(1, 1, rounding=None) + result = tracker.process_value(10.123456789) + assert result.data_units[0]["value"] == 10.123456789 + + def test_rounding_not_applied_to_bool(self): + tracker = AttributeTracker(1, 1, rounding=0) + tracker.process_value(True) + assert tracker.current_value is True # not rounded to 1.0 + + def test_rounding_not_applied_to_string(self): + tracker = AttributeTracker(1, 1, rounding=0) + tracker.process_value("hello") + assert tracker.current_value == "hello" + + +class TestTryNumeric: + """Test string-to-number conversion.""" + + def test_int_passthrough(self): + assert AttributeTracker._try_numeric(42) == 42 + + def test_float_passthrough(self): + assert AttributeTracker._try_numeric(3.14) == 3.14 + + def test_bool_preserved(self): + assert AttributeTracker._try_numeric(True) is True + assert AttributeTracker._try_numeric(False) is False + + def test_string_to_int(self): + assert AttributeTracker._try_numeric("42") == 42 + assert isinstance(AttributeTracker._try_numeric("42"), int) + + def test_string_to_float(self): + assert AttributeTracker._try_numeric("3.14") == 3.14 + assert isinstance(AttributeTracker._try_numeric("3.14"), float) + + def test_non_numeric_string_passthrough(self): + assert AttributeTracker._try_numeric("hello") == "hello" + + def test_none_passthrough(self): + assert AttributeTracker._try_numeric(None) is None + + def test_empty_string_passthrough(self): + assert AttributeTracker._try_numeric("") == "" + + +class TestMakeUnit: + """Test data unit dict structure.""" + + def test_numeric_value(self): + tracker = AttributeTracker(attribute_id=5, handler_id=3) + result = tracker.process_value(42.0) + unit = result.data_units[0] + assert unit["handler_id"] == 3 + assert unit["attribute_id"] == 5 + assert unit["value"] == 42.0 + assert isinstance(unit["timestamp"], datetime.datetime) + + def test_non_numeric_stored_as_zero(self): + tracker = AttributeTracker(1, 1) + result = tracker.process_value("hello") + assert result.data_units[0]["value"] == 0.0 + + def test_bool_stored_as_float(self): + tracker = AttributeTracker(1, 1) + result = tracker.process_value(True) + assert result.data_units[0]["value"] == 1.0 diff --git a/backend/tests/test_graph.py b/backend/tests/test_graph.py new file mode 100644 index 00000000..6efbaef9 --- /dev/null +++ b/backend/tests/test_graph.py @@ -0,0 +1,317 @@ +"""Tests for workflow graph utility functions. + +Covers: _find_cycle, detect_sub_workflow_recursion, _extract_sub_workflow_refs, +derive_sub_workflow_ports, _detect_port_type. +""" + +import pytest + +from app.nodes.graph import ( + PortTypeMismatchError, + SubWorkflowRecursionError, + _detect_port_type, + _extract_sub_workflow_refs, + _find_cycle, + derive_sub_workflow_ports, + detect_sub_workflow_recursion, +) + +# --------------------------------------------------------------------------- +# _find_cycle +# --------------------------------------------------------------------------- + + +class TestFindCycle: + def test_no_edges(self): + assert _find_cycle({}, {"a", "b"}) is None + + def test_linear_chain(self): + adj = {"a": ["b"], "b": ["c"]} + assert _find_cycle(adj, {"a", "b", "c"}) is None + + def test_simple_cycle(self): + adj = {"a": ["b"], "b": ["a"]} + result = _find_cycle(adj, {"a", "b"}) + assert result in ("a", "b") + + def test_self_loop(self): + adj = {"a": ["a"]} + assert _find_cycle(adj, {"a"}) == "a" + + def test_diamond_no_cycle(self): + adj = {"a": ["b", "c"], "b": ["d"], "c": ["d"]} + assert _find_cycle(adj, {"a", "b", "c", "d"}) is None + + def test_cycle_in_subgraph(self): + adj = {"a": ["b"], "b": ["c"], "c": ["b"], "d": ["e"]} + result = _find_cycle(adj, {"a", "b", "c", "d", "e"}) + assert result in ("b", "c") + + def test_disconnected_components_no_cycle(self): + adj = {"a": ["b"], "c": ["d"]} + assert _find_cycle(adj, {"a", "b", "c", "d"}) is None + + def test_empty_start_nodes(self): + adj = {"a": ["b"]} + assert _find_cycle(adj, set()) is None + + def test_three_node_cycle(self): + adj = {"a": ["b"], "b": ["c"], "c": ["a"]} + result = _find_cycle(adj, {"a", "b", "c"}) + assert result in ("a", "b", "c") + + +# --------------------------------------------------------------------------- +# _extract_sub_workflow_refs +# --------------------------------------------------------------------------- + + +class TestExtractSubWorkflowRefs: + def test_no_nodes(self): + assert _extract_sub_workflow_refs({"nodes": []}) == [] + + def test_no_sub_workflow_nodes(self): + graph = {"nodes": [{"type": "handler_listener"}, {"type": "logger"}]} + assert _extract_sub_workflow_refs(graph) == [] + + def test_extracts_ids(self): + graph = { + "nodes": [ + {"type": "sub_workflow_1"}, + {"type": "handler_listener"}, + {"type": "sub_workflow_42"}, + ] + } + assert _extract_sub_workflow_refs(graph) == [1, 42] + + def test_missing_nodes_key(self): + assert _extract_sub_workflow_refs({}) == [] + + +# --------------------------------------------------------------------------- +# detect_sub_workflow_recursion +# --------------------------------------------------------------------------- + + +class TestDetectSubWorkflowRecursion: + def test_no_sub_workflows(self): + detect_sub_workflow_recursion({}) # should not raise + + def test_no_recursion(self): + registry = { + 1: {"graph": {"nodes": [{"type": "handler_listener"}]}}, + 2: {"graph": {"nodes": [{"type": "sub_workflow_1"}]}}, + } + detect_sub_workflow_recursion(registry) # should not raise + + def test_direct_self_recursion(self): + registry = { + 1: {"graph": {"nodes": [{"type": "sub_workflow_1"}]}}, + } + with pytest.raises(SubWorkflowRecursionError): + detect_sub_workflow_recursion(registry) + + def test_indirect_recursion(self): + registry = { + 1: {"graph": {"nodes": [{"type": "sub_workflow_2"}]}}, + 2: {"graph": {"nodes": [{"type": "sub_workflow_1"}]}}, + } + with pytest.raises(SubWorkflowRecursionError): + detect_sub_workflow_recursion(registry) + + def test_three_level_recursion(self): + registry = { + 1: {"graph": {"nodes": [{"type": "sub_workflow_2"}]}}, + 2: {"graph": {"nodes": [{"type": "sub_workflow_3"}]}}, + 3: {"graph": {"nodes": [{"type": "sub_workflow_1"}]}}, + } + with pytest.raises(SubWorkflowRecursionError): + detect_sub_workflow_recursion(registry) + + def test_main_graph_references_sub_workflow(self): + registry = { + 1: {"graph": {"nodes": [{"type": "logger"}]}}, + } + main_graph = {"nodes": [{"type": "sub_workflow_1"}]} + detect_sub_workflow_recursion(registry, main_graph) # should not raise + + def test_main_graph_causes_cycle(self): + registry = { + 1: {"graph": {"nodes": [{"type": "sub_workflow_2"}]}}, + 2: {"graph": {"nodes": [{"type": "sub_workflow_1"}]}}, + } + main_graph = {"nodes": [{"type": "sub_workflow_1"}]} + with pytest.raises(SubWorkflowRecursionError): + detect_sub_workflow_recursion(registry, main_graph) + + def test_empty_graphs(self): + registry = { + 1: {"graph": {"nodes": []}}, + 2: {"graph": None}, + 3: {}, + } + detect_sub_workflow_recursion(registry) # should not raise + + +# --------------------------------------------------------------------------- +# _detect_port_type +# --------------------------------------------------------------------------- + + +class TestDetectPortType: + def test_no_edges(self): + assert _detect_port_type("n1", "source", []) is None + + def test_no_matching_edges(self): + edges = [{"source": "other", "sourceHandle": "event"}] + assert _detect_port_type("n1", "source", edges) is None + + def test_event_source(self): + edges = [{"source": "n1", "sourceHandle": "event"}] + assert _detect_port_type("n1", "source", edges) == "event" + + def test_value_source(self): + edges = [{"source": "n1", "sourceHandle": "value"}] + assert _detect_port_type("n1", "source", edges) == "value" + + def test_event_target(self): + edges = [{"target": "n1", "targetHandle": "event"}] + assert _detect_port_type("n1", "target", edges) == "event" + + def test_value_target(self): + edges = [{"target": "n1", "targetHandle": "value"}] + assert _detect_port_type("n1", "target", edges) == "value" + + def test_multiple_same_type(self): + edges = [ + {"source": "n1", "sourceHandle": "event"}, + {"source": "n1", "sourceHandle": "event"}, + ] + assert _detect_port_type("n1", "source", edges) == "event" + + def test_mixed_types_raises(self): + edges = [ + {"source": "n1", "sourceHandle": "event"}, + {"source": "n1", "sourceHandle": "value"}, + ] + with pytest.raises(PortTypeMismatchError) as exc_info: + _detect_port_type("n1", "source", edges) + assert exc_info.value.node_id == "n1" + + def test_unknown_handle_ignored(self): + edges = [{"source": "n1", "sourceHandle": "unknown"}] + assert _detect_port_type("n1", "source", edges) is None + + def test_unknown_handle_with_known(self): + edges = [ + {"source": "n1", "sourceHandle": "unknown"}, + {"source": "n1", "sourceHandle": "event"}, + ] + assert _detect_port_type("n1", "source", edges) == "event" + + +# --------------------------------------------------------------------------- +# derive_sub_workflow_ports +# --------------------------------------------------------------------------- + + +def _make_graph(nodes: list[dict], edges: list[dict] | None = None) -> dict: + return {"nodes": nodes, "edges": edges or []} + + +class TestDeriveSubWorkflowPorts: + def test_empty_graph(self): + inputs, outputs = derive_sub_workflow_ports(_make_graph([])) + assert inputs == [] + assert outputs == [] + + def test_input_port_event(self): + graph = _make_graph( + nodes=[{"id": "n1", "type": "sub_workflow_input", "data": {"name": "trigger"}}], + edges=[{"source": "n1", "sourceHandle": "event", "target": "n2", "targetHandle": "event"}], + ) + inputs, outputs = derive_sub_workflow_ports(graph) + assert inputs == [{"name": "trigger", "type": "event"}] + assert outputs == [] + + def test_output_port_value(self): + graph = _make_graph( + nodes=[{"id": "n1", "type": "sub_workflow_output", "data": {"name": "result"}}], + edges=[{"source": "n2", "sourceHandle": "value", "target": "n1", "targetHandle": "value"}], + ) + inputs, outputs = derive_sub_workflow_ports(graph) + assert inputs == [] + assert outputs == [{"name": "result", "type": "value"}] + + def test_mixed_input_output(self): + graph = _make_graph( + nodes=[ + {"id": "in1", "type": "sub_workflow_input", "data": {"name": "trigger"}}, + {"id": "out1", "type": "sub_workflow_output", "data": {"name": "result"}}, + {"id": "n3", "type": "logger", "data": {}}, + ], + edges=[ + {"source": "in1", "sourceHandle": "event", "target": "n3", "targetHandle": "event"}, + {"source": "n3", "sourceHandle": "value", "target": "out1", "targetHandle": "value"}, + ], + ) + inputs, outputs = derive_sub_workflow_ports(graph) + assert inputs == [{"name": "trigger", "type": "event"}] + assert outputs == [{"name": "result", "type": "value"}] + + def test_unconnected_input_excluded(self): + graph = _make_graph( + nodes=[{"id": "n1", "type": "sub_workflow_input", "data": {"name": "unused"}}], + edges=[], + ) + inputs, _outputs = derive_sub_workflow_ports(graph) + assert inputs == [] + + def test_unnamed_input_excluded(self): + graph = _make_graph( + nodes=[{"id": "n1", "type": "sub_workflow_input", "data": {"name": ""}}], + edges=[{"source": "n1", "sourceHandle": "event", "target": "n2", "targetHandle": "event"}], + ) + inputs, _outputs = derive_sub_workflow_ports(graph) + assert inputs == [] + + def test_no_name_key_excluded(self): + graph = _make_graph( + nodes=[{"id": "n1", "type": "sub_workflow_input", "data": {}}], + edges=[{"source": "n1", "sourceHandle": "event", "target": "n2", "targetHandle": "event"}], + ) + inputs, _outputs = derive_sub_workflow_ports(graph) + assert inputs == [] + + def test_regular_nodes_ignored(self): + graph = _make_graph( + nodes=[ + {"id": "n1", "type": "handler_listener", "data": {"name": "should_ignore"}}, + {"id": "n2", "type": "logger", "data": {"name": "also_ignore"}}, + ], + edges=[{"source": "n1", "sourceHandle": "event", "target": "n2", "targetHandle": "event"}], + ) + inputs, outputs = derive_sub_workflow_ports(graph) + assert inputs == [] + assert outputs == [] + + def test_explicit_edges_data_parameter(self): + graph = _make_graph( + nodes=[{"id": "n1", "type": "sub_workflow_input", "data": {"name": "in"}}], + edges=[], # graph has no edges + ) + # But we pass edges explicitly + explicit_edges = [{"source": "n1", "sourceHandle": "value", "target": "n2", "targetHandle": "value"}] + inputs, _outputs = derive_sub_workflow_ports(graph, edges_data=explicit_edges) + assert inputs == [{"name": "in", "type": "value"}] + + def test_mixed_types_raises(self): + graph = _make_graph( + nodes=[{"id": "n1", "type": "sub_workflow_input", "data": {"name": "bad"}}], + edges=[ + {"source": "n1", "sourceHandle": "event", "target": "n2", "targetHandle": "event"}, + {"source": "n1", "sourceHandle": "value", "target": "n3", "targetHandle": "value"}, + ], + ) + with pytest.raises(PortTypeMismatchError): + derive_sub_workflow_ports(graph) diff --git a/frontend/src/components/admin/user-management.tsx b/frontend/src/components/admin/user-management.tsx index 85cc32fc..b55ed8eb 100644 --- a/frontend/src/components/admin/user-management.tsx +++ b/frontend/src/components/admin/user-management.tsx @@ -28,6 +28,7 @@ import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; import { Switch } from "@/components/ui/switch"; +import { useConfirmDialog } from "@/hooks/use-confirm-dialog"; import { useAuth } from "@/providers/auth-provider"; const MIN_PASSWORD_LENGTH = 8; @@ -39,7 +40,7 @@ export function UserManagement() { const { data } = useListUsersApiAuthUsersGet(); const deleteUser = useDeleteUserApiAuthUsersUserIdDelete(); const [editingUser, setEditingUser] = useState(null); - const [pendingDeleteUser, setPendingDeleteUser] = useState(null); + const deleteConfirm = useConfirmDialog(); const users = (data?.data ?? []) as UserRead[]; @@ -88,7 +89,7 @@ export function UserManagement() { - )} - {onRemove && ( - - )} - - )} + void; + onRemove?: () => void; +} + +export function WidgetOverlayActions({ onEdit, onRemove }: WidgetOverlayActionsProps) { + const { t } = useTranslation(); + + if (!onEdit && !onRemove) return null; + + return ( +
+ {onEdit && ( + + )} + {onRemove && ( + + )} +
+ ); +} diff --git a/frontend/src/components/dashboard/widget-slider.tsx b/frontend/src/components/dashboard/widget-slider.tsx index f6bc9d06..c36817bf 100644 --- a/frontend/src/components/dashboard/widget-slider.tsx +++ b/frontend/src/components/dashboard/widget-slider.tsx @@ -1,10 +1,10 @@ -import { Pencil, X } from "lucide-react"; import { useCallback, useEffect, useRef, useState } from "react"; import { useTranslation } from "react-i18next"; import { toast } from "sonner"; import type { DashboardWidget } from "@/api/generated/contWatchAPI.schemas"; import { useSetSliderApiWidgetsWidgetIdSetPost } from "@/api/generated/widgets/widgets"; import { ConfirmDialog } from "@/components/confirm-dialog"; +import { WidgetOverlayActions } from "@/components/dashboard/widget-overlay-actions"; import { Card, CardContent } from "@/components/ui/card"; import { Slider } from "@/components/ui/slider"; import { formatValue } from "@/lib/format-value"; @@ -184,36 +184,7 @@ export function WidgetSlider({ widget: slider, status = "online", onRemove, onEd - {(onEdit || onRemove) && ( -
- {onEdit && ( - - )} - {onRemove && ( - - )} -
- )} + { diff --git a/frontend/src/components/dashboard/widget-sparkline.tsx b/frontend/src/components/dashboard/widget-sparkline.tsx index 5b80b8ae..a3b83f81 100644 --- a/frontend/src/components/dashboard/widget-sparkline.tsx +++ b/frontend/src/components/dashboard/widget-sparkline.tsx @@ -1,9 +1,9 @@ import { format, subDays } from "date-fns"; -import { Pencil, X } from "lucide-react"; import { useMemo } from "react"; import { useTranslation } from "react-i18next"; import type { DailyStatRead, DashboardWidget } from "@/api/generated/contWatchAPI.schemas"; import { useListDataStatsApiDataStatsGet } from "@/api/generated/data-stats/data-stats"; +import { WidgetOverlayActions } from "@/components/dashboard/widget-overlay-actions"; import { Card, CardContent } from "@/components/ui/card"; import { Skeleton } from "@/components/ui/skeleton"; import { formatStat } from "@/lib/format-value"; @@ -170,36 +170,7 @@ export function WidgetSparkline({ )} - {(onEdit || onRemove) && ( -
- {onEdit && ( - - )} - {onRemove && ( - - )} -
- )} + ); diff --git a/frontend/src/components/dashboard/widget-switch.tsx b/frontend/src/components/dashboard/widget-switch.tsx index b76e445f..c022905b 100644 --- a/frontend/src/components/dashboard/widget-switch.tsx +++ b/frontend/src/components/dashboard/widget-switch.tsx @@ -1,12 +1,13 @@ -import { History, Pencil, X } from "lucide-react"; -import { useState } from "react"; +import { History } from "lucide-react"; import { useTranslation } from "react-i18next"; import { toast } from "sonner"; import type { DashboardWidget } from "@/api/generated/contWatchAPI.schemas"; import { useToggleSwitchApiWidgetsWidgetIdTogglePost } from "@/api/generated/widgets/widgets"; import { ConfirmDialog } from "@/components/confirm-dialog"; +import { WidgetOverlayActions } from "@/components/dashboard/widget-overlay-actions"; import { Card, CardContent } from "@/components/ui/card"; import { Switch } from "@/components/ui/switch"; +import { useConfirmDialog } from "@/hooks/use-confirm-dialog"; import { localizeAttributeLabel } from "@/lib/localize-attribute"; import { boldName } from "@/lib/utils"; import type { WidgetStatus } from "@/lib/widget-status"; @@ -33,7 +34,7 @@ export function WidgetSwitch({ widget: switch_, status = "online", onRemove, onE const toggle = useToggleSwitchApiWidgetsWidgetIdTogglePost({ mutation: { meta: { skipGlobalErrorToast: true } }, }); - const [confirmOpen, setConfirmOpen] = useState(false); + const toggleConfirm = useConfirmDialog(); const currentValue = liveVal?.value ?? rawVal?.value ?? switch_.value; @@ -55,7 +56,7 @@ export function WidgetSwitch({ widget: switch_, status = "online", onRemove, onE toggle.mutate( { widgetId: switch_.id, data: { value: !isOn } }, { - onSettled: () => setConfirmOpen(false), + onSettled: () => toggleConfirm.close(), onSuccess: () => name && toast.success(boldName(t, "toast.actionExecuted", name)), onError: () => toast.error(boldName(t, "toast.actionFailed", name ?? "")), }, @@ -64,7 +65,7 @@ export function WidgetSwitch({ widget: switch_, status = "online", onRemove, onE function handleToggle() { if (switch_.confirm_actions) { - setConfirmOpen(true); + toggleConfirm.open(); } else { doToggle(); } @@ -134,39 +135,9 @@ export function WidgetSwitch({ widget: switch_, status = "online", onRemove, onE - {(onEdit || onRemove) && ( -
- {onEdit && ( - - )} - {onRemove && ( - - )} -
- )} + { diff --git a/frontend/src/components/dashboard/widget-tile.tsx b/frontend/src/components/dashboard/widget-tile.tsx index a3cd8a65..521a469f 100644 --- a/frontend/src/components/dashboard/widget-tile.tsx +++ b/frontend/src/components/dashboard/widget-tile.tsx @@ -1,6 +1,7 @@ -import { History, Minus, Pencil, TrendingDown, TrendingUp, X } from "lucide-react"; +import { History, Minus, TrendingDown, TrendingUp } from "lucide-react"; import { useTranslation } from "react-i18next"; import type { DashboardWidget } from "@/api/generated/contWatchAPI.schemas"; +import { WidgetOverlayActions } from "@/components/dashboard/widget-overlay-actions"; import { Card, CardContent } from "@/components/ui/card"; import { formatStat, formatValue } from "@/lib/format-value"; import { localizeAttributeLabel } from "@/lib/localize-attribute"; @@ -127,36 +128,7 @@ export function WidgetTile({ widget: tile, status = "online", onRemove, onEdit, )} - {(onEdit || onRemove) && ( -
- {onEdit && ( - - )} - {onRemove && ( - - )} -
- )} + ); diff --git a/frontend/src/components/handlers/action-edit-dialog.tsx b/frontend/src/components/handlers/action-edit-dialog.tsx index 56d30758..0b7aeab4 100644 --- a/frontend/src/components/handlers/action-edit-dialog.tsx +++ b/frontend/src/components/handlers/action-edit-dialog.tsx @@ -13,6 +13,7 @@ import { Button } from "@/components/ui/button"; import { Dialog, DialogContent, DialogFooter, DialogHeader, DialogTitle } from "@/components/ui/dialog"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; +import { useConfirmDialog } from "@/hooks/use-confirm-dialog"; import { ActionMessageEditor } from "./action-message-editor"; interface ActionEditDialogProps { @@ -25,7 +26,7 @@ export function ActionEditDialog({ action, onClose }: ActionEditDialogProps) { const queryClient = useQueryClient(); const [name, setName] = useState(""); const [message, setMessage] = useState(""); - const [confirmDeleteOpen, setConfirmDeleteOpen] = useState(false); + const deleteConfirm = useConfirmDialog(); const updateAction = useUpdateActionApiActionsActionIdPatch(); const deleteAction = useDeleteActionApiActionsActionIdDelete(); @@ -83,7 +84,7 @@ export function ActionEditDialog({ action, onClose }: ActionEditDialogProps) { type="button" variant="destructive" disabled={deleteAction.isPending} - onClick={() => setConfirmDeleteOpen(true)} + onClick={() => deleteConfirm.open()} > {t("common.delete")} @@ -94,8 +95,7 @@ export function ActionEditDialog({ action, onClose }: ActionEditDialogProps) { { - setConfirmDeleteOpen(false); + deleteConfirm.close(); queryClient.invalidateQueries({ queryKey: getListHandlersApiHandlersGetQueryKey(), }); diff --git a/frontend/src/components/handlers/action-param-dialog.tsx b/frontend/src/components/handlers/action-param-dialog.tsx index 4c71e337..05db870a 100644 --- a/frontend/src/components/handlers/action-param-dialog.tsx +++ b/frontend/src/components/handlers/action-param-dialog.tsx @@ -16,6 +16,7 @@ import { Button } from "@/components/ui/button"; import { Label } from "@/components/ui/label"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Slider } from "@/components/ui/slider"; +import { useConfirmDialog } from "@/hooks/use-confirm-dialog"; import { boldName } from "@/lib/utils"; import { useLiveValuesStore } from "@/stores/live-values"; @@ -77,7 +78,7 @@ export function ActionParamPopover({ action, params, handler, children }: Action } }, [open, paramAttrIds, params]); - const [confirmOpen, setConfirmOpen] = useState(false); + const executeConfirm = useConfirmDialog(); const localizedName = t(`knownActions.${action.name.replaceAll(" ", "_")}`, action.name); function doExecute() { @@ -96,7 +97,7 @@ export function ActionParamPopover({ action, params, handler, children }: Action function handleSubmit() { if (handler.confirm_actions) { setOpen(false); - setConfirmOpen(true); + executeConfirm.open(); } else { doExecute(); } @@ -177,8 +178,7 @@ export function ActionParamPopover({ action, params, handler, children }: Action setConfirmDeleteOpen(true)} + onClick={() => deleteConfirm.open()} > {t("common.delete")} @@ -167,8 +168,7 @@ export function AttributeEditDialog({ attribute, onClose }: AttributeEditDialogP { - setConfirmDeleteOpen(false); + deleteConfirm.close(); queryClient.invalidateQueries({ queryKey: getListAttributesApiAttributesGetQueryKey(), }); diff --git a/frontend/src/components/handlers/handler-controls.tsx b/frontend/src/components/handlers/handler-controls.tsx index aff0eae1..feb09935 100644 --- a/frontend/src/components/handlers/handler-controls.tsx +++ b/frontend/src/components/handlers/handler-controls.tsx @@ -8,6 +8,7 @@ import { ConfirmDialog } from "@/components/confirm-dialog"; import { SafeIcon } from "@/components/safe-icon"; import { Slider } from "@/components/ui/slider"; import { Switch } from "@/components/ui/switch"; +import { useConfirmDialog } from "@/hooks/use-confirm-dialog"; import { boldName, cn } from "@/lib/utils"; import { useLiveValuesStore } from "@/stores/live-values"; @@ -67,7 +68,7 @@ function ControlSwitch({ const executeAction = useExecuteActionApiActionsActionIdExecutePost({ mutation: { meta: { skipGlobalErrorToast: true } }, }); - const [pendingChecked, setPendingChecked] = useState(null); + const toggleConfirm = useConfirmDialog(); const liveValue = useLiveValuesStore((s) => control.attribute_id ? s.values[control.attribute_id] : undefined, ); @@ -89,7 +90,7 @@ function ControlSwitch({ executeAction.mutate( { actionId, data: null }, { - onSettled: () => setPendingChecked(null), + onSettled: () => toggleConfirm.close(), onSuccess: () => toast.success(boldName(t, "toast.actionExecuted", actionName ?? "")), onError: () => toast.error(boldName(t, "toast.actionFailed", actionName ?? "")), }, @@ -98,7 +99,7 @@ function ControlSwitch({ function handleToggle(checked: boolean) { if (confirmActions) { - setPendingChecked(checked); + toggleConfirm.open(checked); } else { doToggle(checked); } @@ -152,13 +153,14 @@ function ControlSwitch({ )} !open && setPendingChecked(null)} + {...toggleConfirm.dialogProps} onConfirm={() => { - if (pendingChecked !== null) doToggle(pendingChecked); + if (toggleConfirm.pending !== null) doToggle(toggleConfirm.pending); }} description={t("confirm.executeAction", { - action: localizeAction(pendingChecked ? control.action_on_name : control.action_off_name), + action: localizeAction( + toggleConfirm.pending ? control.action_on_name : control.action_off_name, + ), device: handlerLabel, })} /> diff --git a/frontend/src/components/handlers/handler-detail.tsx b/frontend/src/components/handlers/handler-detail.tsx index 40bbf7d1..4bed36cf 100644 --- a/frontend/src/components/handlers/handler-detail.tsx +++ b/frontend/src/components/handlers/handler-detail.tsx @@ -22,6 +22,7 @@ import type { ResolvedControl, } from "@/api/generated/contWatchAPI.schemas"; import { + getListHandlersApiHandlersGetQueryKey, useAvailableAttributesApiHandlersHandlerIdAvailableAttributesGet, useDeleteHandlerApiHandlersHandlerIdDelete, useHandlerControlsApiHandlersHandlerIdControlsGet, @@ -34,6 +35,8 @@ import { SafeIcon } from "@/components/safe-icon"; import { Button } from "@/components/ui/button"; import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; import { Sheet, SheetContent, SheetHeader, SheetTitle } from "@/components/ui/sheet"; +import { useConfirmDialog } from "@/hooks/use-confirm-dialog"; +import { useOptimisticReorder } from "@/hooks/use-optimistic-reorder"; import { dateFnsLocales } from "@/lib/date-locale"; import { formatValue } from "@/lib/format-value"; import { localizeAttributeLabel } from "@/lib/localize-attribute"; @@ -105,7 +108,8 @@ function HandlerInfo({ onClose: () => void; }) { const { t, i18n } = useTranslation(); - const [confirmDeleteOpen, setConfirmDeleteOpen] = useState(false); + const queryClient = useQueryClient(); + const deleteConfirm = useConfirmDialog(); const status = useHandlerStatusStore((s) => s.statuses[handler.id]); const { data: typesData } = useListHandlerTypesApiHandlersTypesGet(); const startHandler = useStartHandlerApiHandlersHandlerIdStartPost(); @@ -222,7 +226,14 @@ function HandlerInfo({ onClick={() => stopHandler.mutate( { handlerId: handler.id }, - { onSuccess: () => toast.success(t("toast.handlerStopped")) }, + { + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: getListHandlersApiHandlersGetQueryKey(), + }); + toast.success(t("toast.handlerStopped")); + }, + }, ) } disabled={stopHandler.isPending} @@ -238,7 +249,14 @@ function HandlerInfo({ onClick={() => startHandler.mutate( { handlerId: handler.id }, - { onSuccess: () => toast.success(t("toast.handlerStarted")) }, + { + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: getListHandlersApiHandlersGetQueryKey(), + }); + toast.success(t("toast.handlerStarted")); + }, + }, ) } disabled={startHandler.isPending} @@ -252,7 +270,7 @@ function HandlerInfo({ + + + ); +} diff --git a/frontend/src/components/workflow/workflow-editor.tsx b/frontend/src/components/workflow/workflow-editor.tsx index b077302f..86501038 100644 --- a/frontend/src/components/workflow/workflow-editor.tsx +++ b/frontend/src/components/workflow/workflow-editor.tsx @@ -555,7 +555,7 @@ export const WorkflowEditor = forwardRef }, ); } - }, [isSubWorkflow, subWorkflowId, queryClient, onSaveError]); // eslint-disable-line react-hooks/exhaustive-deps + }, [isSubWorkflow, subWorkflowId, queryClient, onSaveError]); const clearErrorNodes = useCallback(() => { if (nodesRef.current.some((n) => n.data._error)) { diff --git a/frontend/src/hooks/use-confirm-dialog.ts b/frontend/src/hooks/use-confirm-dialog.ts new file mode 100644 index 00000000..520dcf16 --- /dev/null +++ b/frontend/src/hooks/use-confirm-dialog.ts @@ -0,0 +1,44 @@ +import { useCallback, useState } from "react"; + +interface ConfirmDialogState { + /** The pending value, or null when closed */ + pending: T | null; + /** Whether the dialog is open */ + isOpen: boolean; + /** Open the dialog, optionally with a value (defaults to `true` for boolean usage) */ + open: (value?: T) => void; + /** Close the dialog */ + close: () => void; + /** Props to spread onto ConfirmDialog: `open` and `onOpenChange` */ + dialogProps: { + open: boolean; + onOpenChange: (open: boolean) => void; + }; +} + +export function useConfirmDialog(): ConfirmDialogState { + const [pending, setPending] = useState(null); + + const open = useCallback((value?: T) => { + setPending((value ?? true) as T); + }, []); + + const close = useCallback(() => { + setPending(null); + }, []); + + const onOpenChange = useCallback((isOpen: boolean) => { + if (!isOpen) setPending(null); + }, []); + + return { + pending, + isOpen: pending !== null, + open, + close, + dialogProps: { + open: pending !== null, + onOpenChange, + }, + }; +} diff --git a/frontend/src/hooks/use-optimistic-reorder.ts b/frontend/src/hooks/use-optimistic-reorder.ts new file mode 100644 index 00000000..59273909 --- /dev/null +++ b/frontend/src/hooks/use-optimistic-reorder.ts @@ -0,0 +1,33 @@ +import type { QueryKey } from "@tanstack/react-query"; +import { useQueryClient } from "@tanstack/react-query"; +import { useCallback } from "react"; + +/** + * Hook that builds a reorder handler with optimistic cache update. + * + * Handles the common pattern of: + * 1. Building `{ id, order }` pairs from reordered items + * 2. Optimistically updating the React Query cache + * 3. Firing the reorder mutation + */ +export function useOptimisticReorder( + queryKey: QueryKey, + mutate: (variables: { data: { items: { id: number; order: number }[] } }) => void, +): (reordered: T[]) => void { + const queryClient = useQueryClient(); + + return useCallback( + (reordered: T[]) => { + const items = reordered.map((item, i) => ({ id: item.id, order: i })); + + queryClient.cancelQueries({ queryKey }); + queryClient.setQueryData(queryKey, (old: unknown) => { + if (!old || typeof old !== "object" || !("data" in old)) return old; + return { ...old, data: reordered.map((item, i) => ({ ...item, order: i })) }; + }); + + mutate({ data: { items } }); + }, + [queryClient, queryKey, mutate], + ); +} diff --git a/frontend/src/i18n/locales/cs/translation.json b/frontend/src/i18n/locales/cs/translation.json index 4bdf3d83..2d47e381 100644 --- a/frontend/src/i18n/locales/cs/translation.json +++ b/frontend/src/i18n/locales/cs/translation.json @@ -570,6 +570,11 @@ "registerAttributeTitle": "Registrovat atribut", "registerAttribute": "Registrovat \"{{name}}\" jako sledovaný atribut?" }, + "error": { + "title": "Něco se pokazilo", + "description": "Došlo k neočekávané chybě. Můžete to zkusit znovu nebo se vrátit na přehled.", + "retry": "Zkusit znovu" + }, "notFound": { "title": "Stránka nenalezena", "description": "Stránka, kterou hledáte, neexistuje nebo byla přesunuta.", diff --git a/frontend/src/i18n/locales/en/translation.json b/frontend/src/i18n/locales/en/translation.json index 92ae3847..cc768d0e 100644 --- a/frontend/src/i18n/locales/en/translation.json +++ b/frontend/src/i18n/locales/en/translation.json @@ -570,6 +570,11 @@ "registerAttributeTitle": "Register attribute", "registerAttribute": "Register \"{{name}}\" as a tracked attribute?" }, + "error": { + "title": "Something went wrong", + "description": "An unexpected error occurred. You can try again or go back to the dashboard.", + "retry": "Try again" + }, "notFound": { "title": "Page not found", "description": "The page you're looking for doesn't exist or has been moved.", diff --git a/frontend/src/providers/query-provider.tsx b/frontend/src/providers/query-provider.tsx index eb9618f5..9226f50e 100644 --- a/frontend/src/providers/query-provider.tsx +++ b/frontend/src/providers/query-provider.tsx @@ -34,12 +34,6 @@ export function QueryProvider({ children }: { children: ReactNode }) { }, }, mutationCache: new MutationCache({ - onSuccess: () => { - // Invalidate all queries on any mutation success. - // Only currently observed (mounted) queries refetch; - // unobserved ones just get marked stale for next use. - queryClient.invalidateQueries(); - }, onError: (error, _variables, _context, mutation) => { // Skip global toast when the mutation handles errors locally if (mutation.meta?.skipGlobalErrorToast) return; diff --git a/frontend/src/routes/__root.tsx b/frontend/src/routes/__root.tsx index c4a528d5..5b762bbc 100644 --- a/frontend/src/routes/__root.tsx +++ b/frontend/src/routes/__root.tsx @@ -2,6 +2,7 @@ import { createRootRouteWithContext, Outlet } from "@tanstack/react-router"; import { TanStackRouterDevtools } from "@tanstack/react-router-devtools"; import { Toaster } from "sonner"; import { NotFound } from "@/components/not-found"; +import { RouteError } from "@/components/route-error"; import { useSettingsStore } from "@/stores/settings"; interface RouterContext { @@ -26,4 +27,5 @@ function RootComponent() { export const Route = createRootRouteWithContext()({ component: RootComponent, notFoundComponent: NotFound, + errorComponent: RouteError, }); diff --git a/frontend/src/routes/_authenticated.tsx b/frontend/src/routes/_authenticated.tsx index d13aa6b8..7c4966fc 100644 --- a/frontend/src/routes/_authenticated.tsx +++ b/frontend/src/routes/_authenticated.tsx @@ -1,5 +1,6 @@ import { createFileRoute, Outlet, redirect } from "@tanstack/react-router"; import { AppLayout } from "@/components/layout/app-layout"; +import { RouteError } from "@/components/route-error"; import { useSeedHandlerStatuses } from "@/hooks/use-seed-handler-statuses"; import { useSeedLiveValues } from "@/hooks/use-seed-live-values"; import { useSocketConnection } from "@/providers/socket-provider"; @@ -18,6 +19,7 @@ export const Route = createFileRoute("/_authenticated")({ } }, component: AuthenticatedLayout, + errorComponent: RouteError, }); function AuthenticatedLayout() {