Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion bot/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ def __init__(self, bot: Bot, update: Update) -> None:
else None
)
self.variables = Variables(
bot=bot, chat=chat, user=user, message=update.effective_message
bot=bot,
chat=chat,
user=user,
message=update.effective_message,
user_storage=self.user_storage,
)

def copy(self) -> HandlerContext:
Expand Down
14 changes: 11 additions & 3 deletions bot/handlers/temporary_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from service.models import Connection, TemporaryVariable

from ..context import HandlerContext
from ..storage import Storage
from ..utils.variables import replace_text_variables
from .base import BaseHandler

Expand All @@ -12,9 +13,16 @@
class TemporaryVariableHandler(BaseHandler[TemporaryVariable]):
async def handle(
self, update: Update, variable: TemporaryVariable, context: HandlerContext
) -> list[Connection]:
temporary: dict[str, Any] = context.variables.store.setdefault('TEMPORARY', {})
temporary[variable.name] = await replace_text_variables(
) -> list[Connection] | None:
user_storage: Storage | None = context.user_storage

if not user_storage:
return None

variables: dict[str, Any] = await user_storage.get('temporary_variables', {})
variables[variable.name] = await replace_text_variables(
variable.value, context.variables
)
await user_storage.set('temporary_variables', variables)

return variable.source_connections
10 changes: 10 additions & 0 deletions bot/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from service.models import DatabaseRecord, Variable

from .storage import Storage
from .utils.html import process_html_text

from typing import TYPE_CHECKING, Any
Expand All @@ -24,8 +25,11 @@ def __init__(
chat: Chat | None = None,
user: User | None = None,
message: Message | None = None,
user_storage: Storage | None = None,
):
self.bot = bot
self._user_storage = user_storage

self.store: dict[str, Any] = {}
self.system_store: dict[str, Any] = {
'BOT_ID': bot.me.id,
Expand Down Expand Up @@ -145,6 +149,12 @@ async def get(self, key: str) -> Any | None:
return self.system_store.get(nested_key)
elif prefix == 'USER':
return await self._resolve_user_value(nested_key)
elif (
prefix == 'TEMPORARY'
and self._user_storage
and (variables := await self._user_storage.get('temporary_variables'))
):
return self._resolve_value(variables, nested_key)
elif prefix == 'DATABASE':
return await self._resolve_database_value(nested_key)
elif (value := self.store.get(prefix)) and isinstance(
Expand Down
Loading