diff --git a/ferris/plugins/__init__.py b/ferris/plugins/__init__.py deleted file mode 100644 index 9f215bd..0000000 --- a/ferris/plugins/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import commands diff --git a/ferris/plugins/commands/__init__.py b/ferris/plugins/commands/__init__.py index 78af14e..a671124 100644 --- a/ferris/plugins/commands/__init__.py +++ b/ferris/plugins/commands/__init__.py @@ -1,2 +1,15 @@ -from .core import Bot, CommandSink +from . import errors, parser +from .core import CaseInsensitiveDict, CommandSink, Bot +from .converters import * +from .errors import * from .models import Command, Context +from .parser import ( + Argument, + ConsumeType, + Converter, + Greedy, + Not, + Quotes, + StringReader, + converter, +) diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py new file mode 100644 index 0000000..452a447 --- /dev/null +++ b/ferris/plugins/commands/converters.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ...errors import NotFound +from ...utils import INVITE_REGEX, find +from .errors import BadArgument + + +__all__ = ('ChannelConverter', + 'GuildConverter', + 'UserConverter', + 'MemberConverter', + 'MessageConverter', + 'RoleConverter', + 'InviteConverter') + +if TYPE_CHECKING: + from ... import Channel, Guild, Invite, Member, Message, Role, User + from .models import Context + from .parser import Converter + + +class ChannelConverter(Converter[Channel]): + async def convert(self, ctx: Context, argument: str): + c = None + + if argument.isdigit(): + id_ = int(argument) + + c = ctx.bot.get_channel(id_) + + if not c: + try: + c = await ctx.bot.fetch_channel(id_) + except NotFound: + if not c: + raise BadArgument(f'Channel {id_!r} not found') + else: + raise BadArgument(f'Argument must be an id.') + + return c + + +class GuildConverter(Converter[Guild]): + async def convert(self, ctx: Context, argument: str): + g = None + + if argument.isdigit(): + id_ = int(argument) + + g = ctx.bot.get_guild(id_) + + if not g: + try: + g = await ctx.bot.fetch_guild(id_) + except NotFound: + if not g: + raise BadArgument(f'Guild {id_!r} not found') + else: + raise BadArgument(f'Argument must be an id.') + + return g + + +class UserConverter(Converter[User]): + async def convert(self, ctx: Context, argument: str): + u = None + + if argument.isdigit(): + id_ = int(argument) + + u = ctx.bot.get_user(id_) + + if not u: + try: + u = await ctx.bot.fetch_user(id_) + except NotFound: + if not u: + raise BadArgument(f'User {id_!r} not found') + else: + raise BadArgument(f'Argument must be an id.') + + return u + + +class MemberConverter(Converter[Member]): + async def convert(self, ctx: Context, argument: str): + m = None + + if argument.isdigit(): + id_ = int(argument) + + m = ctx.guild.get_member(id_) + + if not m: + try: + m = await ctx.guild.fetch_member(id_) + except NotFound: + raise BadArgument(f'Member {id_!r} not found') + else: + raise BadArgument('Argument must be an id.') + + return m + + +class MessageConverter(Converter[Message]): + async def convert(self, ctx: Context, argument: str): + m = None + + if argument.isdigit(): + id_ = int(argument) + + m = ctx.bot.get_message(id_) + + if not m: + try: + m = await ctx.bot.fetch_message(id_) + except NotFound: + raise BadArgument(f'Message {id_!r} not found') + else: # TODO: Convert from message url after webclient rewrite + raise BadArgument('Argument must be an id.') + + return m + + +class RoleConverter(Converter[Role]): + async def convert(self, ctx: Context, argument: str): + r = None + + if argument.isdigit(): + id_ = int(argument) + + r = ctx.guild.get_role(id_) + + if not r: + try: + r = await ctx.guild.fetch_role(id_) + except NotFound: + raise BadArgument(f'Role {id_!r} not found') + else: + r = find(lambda r: r.name == argument, ctx.guild.roles) + + if not r: + raise BadArgument('Argument must be role id or name.') + + return r + + +class InviteConverter(Converter[Invite]): + async def convert(self, ctx: Context, argument: str): + if match := INVITE_REGEX.match(argument): + argument = match.group(4) + + try: + return await ctx.bot.fetch_invite(argument) + except NotFound: + raise BadArgument('Argument must be a valid invite url or code.') diff --git a/ferris/plugins/commands/core.py b/ferris/plugins/commands/core.py index 6647319..086ba17 100644 --- a/ferris/plugins/commands/core.py +++ b/ferris/plugins/commands/core.py @@ -1,70 +1,399 @@ from __future__ import annotations -import asyncio -from typing import Awaitable, Callable, Dict, Generator, TYPE_CHECKING, Sequence, Union +import functools -from ferris.client import Client -from .models import Command +from typing import ( + Any, + Awaitable, + Callable, + Collection, + Dict, + Generator, + List, + Optional, + overload, + Type, + TYPE_CHECKING, + TypeVar, + Union, +) + +from .errors import CommandBasedError, CommandNotFound +from .models import Command, Context +from .parser import StringReader +from ...message import Message +from ...client import Client +from ...utils import ensure_async + +V = TypeVar('V') if TYPE_CHECKING: - from ferris.message import Message + from .models import CommandCallback + + DefaultT = TypeVar('DefaultT') + + BasePrefixT = Union[str, Collection[str]] + FunctionPrefixT = Callable[['Bot', Message], Awaitable[BasePrefixT]] + PrefixT = Union[BasePrefixT, FunctionPrefixT] + +__all__ = ( + 'CaseInsensitiveDict', + 'CommandSink', + 'Bot', +) + + +class CaseInsensitiveDict(Dict[str, V]): + """Represents a case-insensitive dictionary.""" + + def __getitem__(self, key: str) -> V: + return super().__getitem__(key.casefold()) + + def __setitem__(self, key: str, value: V) -> None: + super().__setitem__(key.casefold(), value) + + def __delitem__(self, key: str) -> None: + return super().__delitem__(key.casefold()) - _BasePrefixT = Union[str, Sequence[str]] - PrefixT = Union[ - _BasePrefixT, - Callable[['Bot', Message], Union[_BasePrefixT, Awaitable[_BasePrefixT]]], - ] + def __contains__(self, key: str) -> bool: + return super().__contains__(key.casefold()) + + def get(self, key: str, default: DefaultT = None) -> Union[V, DefaultT]: + return super().get(key.casefold(), default) + + def pop(self, key: str, default: DefaultT = None) -> Union[V, DefaultT]: + return super().pop(key.casefold(), default) + + get.__doc__ = dict.get.__doc__ + pop.__doc__ = dict.pop.__doc__ class CommandSink: """Represents a sink of commands. - This will both mixin to :class:`.Bot` and :class:`.CommandGroup`. + This will both mixin to :class:`~.Bot` and :class:`~.CommandGroup`. Attributes ---------- - mapping: Dict[str, :class:`.Command`] + command_mapping: Dict[str, :class:`.Command`] A full mapping of command names to commands. """ - def __init__(self) -> None: - self.mapping: Dict[str, Command] = {} + if TYPE_CHECKING: + command_mapping: Union[Dict[str, Command], CaseInsensitiveDict[Command]] - def walk_commands(self) -> Generator[Command]: + def __init__(self, *, case_insensitive: bool = False) -> None: + mapping_factory = CaseInsensitiveDict if case_insensitive else dict + self.command_mapping = mapping_factory() + + def walk_commands(self) -> Generator[Command, None, None]: """Returns a generator that walks through all of the commands this sink holds. Returns ------- - Generator[:class:`.Command`] + Generator[:class:`.Command`, None, None] """ seen = set() - for command in self.mapping.values(): + + for command in self.command_mapping.values(): if command not in seen: seen.add(command) yield command + def get_command(self, name: str, /, default: DefaultT = None) -> Union[Command, DefaultT]: + """Tries to get a command by it's name. + Aliases are supported. + + Parameters + ---------- + name: str + The name of the command you want to lookup. + default + What to return instead if the command was not found. + Defaults to ``None``. + + Returns + ------- + Optional[:class:`~.Command`] + The command found. + """ + return self.command_mapping.get(name, default) + + @overload + def command( + self, + name: str, + *, + alias: str = None, + brief: str = None, + description: str = None, + usage: str = None, + cls: Type[Command] = Command, + **kwargs + ) -> Callable[[CommandCallback], Command]: + ... + + @overload + def command( + self, + name: str, + *, + aliases: Collection[str] = None, + brief: str = None, + description: str = None, + usage: str = None, + cls: Type[Command] = Command, + **kwargs + ) -> Callable[[CommandCallback], Command]: + ... + + def command( + self, + name: str, + *, + alias: str = None, + aliases: Collection[str] = None, + brief: str = None, + description: str = None, + usage: str = None, + cls: Type[Command] = Command, + **kwargs + ) -> Callable[[CommandCallback], Command]: + """Returns a decorator that adds a command to this command sink.""" + + if alias and aliases: + raise ValueError('Only one of alias or aliases can be set.') + + if alias: + aliases = [alias] + elif not aliases: + aliases = [] + + def decorator(callback: CommandCallback, /) -> Command: + command = cls( + callback, + name=str(name), + aliases=aliases, + brief=brief, + description=description, + usage=usage, + **kwargs + ) + + self.command_mapping[name] = command + for alias in aliases: + self.command_mapping[alias] = command + + return command + + return decorator + @property - def commands(self) -> None: + def commands(self) -> List[Command]: """List[:class:`.Command`]: A list of the commands this sink holds.""" return list(self.walk_commands()) + def remove_command(self, name: str) -> Optional[Command]: + """ + Remove a command or alias from this command sink. + + Parameters + ----------- + name: :class:`str` + The name of the command or alias to remove. + + Returns + -------- + Optional[:class:`~.Command`] + The command that was removed. + If the name is invalid, ``None`` is returned instead. + """ + command = self.command_mapping.pop(name) + + if command: + if name in command.aliases: # This is an alias, so don't remove the command, only this alias. + return command + + for alias in command.aliases: + self.command_mapping.pop(alias) + + return command + class Bot(Client, CommandSink): - """Represents a client connection to Discord with extra command handling support. + """Represents a bot with extra command handling support. Parameters ---------- + prefix + The prefix this bot will listen for. This is required. + case_insensitive: bool + Whether or not commands should be case-insensitive. loop: Optional[:class:`asyncio.AbstractEventLoop`] The event loop to use for the client. If not passed, then the default event loop is used. - prefix - The prefix the bot will listen for. This is required. max_messages_count: Optional[int] The maximum number of messages to store in the internal message buffer. Defaults to ``1000``. + + max_heartbeat_timeout: Optional[int] + The maximum timeout in seconds between sending a heartbeat to the server. + If heartbeat took longer than this timeout, the client will attempt to reconnect. """ - def __init__(self, prefix: PrefixT, **kwargs) -> None: + def __init__( + self, + prefix: PrefixT, + *, + case_insensitive: bool = False, + prefix_case_insensitive: bool = False, + strip_after_prefix: bool = False, + **kwargs + ) -> None: super().__init__(**kwargs) - CommandSink.__init__(self) + CommandSink.__init__(self, case_insensitive=case_insensitive) + + self.prefix: PrefixT = self._sanitize_prefix(prefix, case_insensitive=prefix_case_insensitive) + self.strip_after_prefix: bool = bool(strip_after_prefix) + + self._prefix_case_insensitive: bool = bool(prefix_case_insensitive) + self._default_case_insensitive: bool = bool(case_insensitive) + + @staticmethod + def _sanitize_prefix(prefix: Any, *, case_insensitive: bool = False, allow_callable: bool = True) -> PrefixT: + if prefix is None: + return '' + + if isinstance(prefix, str): + if case_insensitive: + return prefix.casefold() + + return prefix + + if isinstance(prefix, Collection): + try: + invalid = next(filter(lambda item: not isinstance(item, str), prefix)) + except StopIteration: + pass + else: + raise TypeError(f'Prefix collection must only consist of strings, not {type(invalid)!r}.') + + res = list(map(str.casefold, prefix)) if case_insensitive else list(prefix) + return sorted(res, key=len, reverse=True) + + if callable(prefix) and allow_callable: + return _ensure_casefold(ensure_async(prefix), case_insensitive=case_insensitive) + + raise TypeError( + f'Invalid prefix {prefix!r}. Only strings, collections of strings, ' + f'or functions that return them are allowed.' + ) + + async def get_prefix(self, message: Message, *, prefix: BasePrefixT = None) -> Optional[BasePrefixT]: + """Gets the prefix, or list of prefixes from the given message. + If the message does not start with a prefix, ``None`` is returned. + + Parameters + ---------- + message: :class:`~.Message` + The message to get the prefix from. + + Returns + ------- + Union[str, List[str]] + """ + prefix = self.prefix if prefix is None else prefix + content = message.content + + if self._prefix_case_insensitive: + content = content.casefold() + + if isinstance(prefix, str): + if content.startswith(prefix): + return prefix + else: + return None + + if isinstance(prefix, list): + try: + return next(filter(lambda pf: content.startswith(pf), prefix)) + except StopIteration: + return None + + if callable(prefix): + prefix = await prefix(self, message) + return self.get_prefix(message, prefix=prefix) + + return None + + async def get_context(self, message: Message, *, cls: Type[Context] = Context) -> Context: + """Parses a :class:`~.Context` out of a message. + + If the message is not a command, + partial context with only the ``message`` parameter is returned. + + If an error occurs during parsing, + attributes of the returned context may remain as ``None``. + + Parameters + ---------- + message: :class:`~.Message` + The message to get the context from. + cls: Type[:class:`~.Context`] + The context subclass to use. Defaults to :class:`~.Context`. + + Returns + ------- + :class:`~.Context` + """ + ctx = cls(self, message) + prefix = await self.get_prefix(message) + + if prefix is None: + return ctx + + ctx.prefix = prefix + content = message.content[len(prefix):] + + if self.strip_after_prefix: + content = content.strip() + + reader = ctx.reader = StringReader(content) + ctx.invoked_with = word = reader.next_word(skip_first=False) + ctx.command = self.command_mapping.get(word) + return ctx + + async def invoke(self, ctx: Context) -> None: + """Parses and invokes the given context. + + Checks, cooldowns, hooks, etc. are ran here. + See :meth:`~.Context.reinvoke` for a version that bypasses these. + + Parameters + ---------- + ctx: :class:`~.Context` + The context to invoke. + """ + try: + if not ctx.command: + if ctx.invoked_with is not None: + raise CommandNotFound(ctx) + else: + return + + rest = ctx.reader.rest.strip() + await ctx.command.execute(ctx, rest) + + except CommandBasedError as exc: + self.dispatch('command_error', exc) + + +def _ensure_casefold(func: FunctionPrefixT, /, *, case_insensitive: bool = False) -> FunctionPrefixT: + @functools.wraps(func) + async def wrapper(bot: Bot, message: Message) -> BasePrefixT: + return bot._sanitize_prefix( + await func(bot, message), + case_insensitive=case_insensitive, + allow_callable=False + ) - self.prefix: PrefixT = prefix + return wrapper diff --git a/ferris/plugins/commands/errors.py b/ferris/plugins/commands/errors.py new file mode 100644 index 0000000..669fc30 --- /dev/null +++ b/ferris/plugins/commands/errors.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from typing import Any, Optional, Tuple, TYPE_CHECKING +from ...utils import to_error_string +from ferris.errors import FerrisException + +if TYPE_CHECKING: + from .models import Context + from .parser import Argument + +__all__ = ( + 'CommandBasedError', + 'ArgumentParsingError', + 'ArgumentPreparationError', + 'ArgumentValidationError', + 'ConversionError', + 'ConversionFailure', + 'MissingArgumentError', + 'BadBooleanArgument', + 'BadLiteralArgument', + 'BlacklistedArgument', + 'BadArgument', + 'CommandNotFound', +) + + +class CommandBasedError(FerrisException): + """The base exception raised for errors related to the commands plugin.""" + + +class ArgumentParsingError(CommandBasedError): + """Raised when an error occurs during argument parsing.""" + + +class ArgumentValidationError(ArgumentParsingError): + """Raised when an argument fails validation. + + Attributes + ---------- + ctx: :class:`~.Context` + The context that raised this error. + argument: :class:`~.Argument` + The argument that failed validation. + word: str + The literal string that did not make it past validation. + """ + + def __init__(self, ctx: Context, argument: Argument, word: str) -> None: + self.ctx: Context = ctx + self.argument: Argument = argument + self.word: str = word + super().__init__(f'Argument {argument.name!r} did not pass validation.') + + +class ArgumentPreparationError(ArgumentParsingError): + """Raised when an argument fails preparation. + + In other words, this exception is raised when an error + occurs during :attr:`~.Argument.prepare`. + + Attributes + ---------- + ctx: :class:`~.Context` + The context that raised this error. + argument: :class:`~.Argument` + The argument that failed preparation. + word: str + The string of the word that could not be prepared. + error: :exc:`Exception` + The error that was raised. + """ + + def __init__(self, ctx: Context, argument: Argument, word: str, exc: Exception) -> None: + self.ctx: Context = ctx + self.argument: Argument = argument + self.word: str = word + self.error: Exception = exc + + super().__init__(f'Argument {argument.name!r} failed preparation: {to_error_string(exc)}') + + +class MissingArgumentError(ArgumentParsingError): + """Raised when an argument is missing. + + Attributes + ---------- + ctx: :class:`~.Context` + The context that raised this error. + argument: :class:`~.Argument` + The argument that was missing. + """ + + def __init__(self, ctx: Context, argument: Argument) -> None: + self.ctx: Context = ctx + self.argument: Argument = argument + super().__init__(f'Missing required argument {argument.name!r}.') + + +class ConversionError(ArgumentParsingError): + """Raised when an error occurs during argument conversion.""" + + +class ConversionFailure(ConversionError): + """Raised when an unhandled error occurs during argument conversion. + + Attributes + ---------- + ctx: :class:`~.Context` + The context that raised this error. + argument: :class:`~.Argument` + The argument that failed conversion. + word: str + The literal argument string that failed conversion. + errors: Tuple[:exc:`Exception`, ...] + The errors that were raised. + """ + + def __init__(self, ctx: Context, argument: Argument, word: str, *errors: Exception) -> None: + self.ctx: Context = ctx + self.argument: Argument = argument + self.word: str = word + + self.errors: Tuple[Exception, ...] = errors + super().__init__(f'Argument {argument.name!r} failed conversion: {to_error_string(self.error)}') + + @property + def error(self) -> Optional[Exception]: + """The most recent error raised, or ``None``.""" + try: + return self.errors[-1] + except IndexError: + return None + + +class BadArgument(ConversionError): + """A user defined error that should be raised if an argument cannot + be converted during argument conversion.""" + + +class BadBooleanArgument(BadArgument): + """Raised when converting an argument to a boolean fails.""" + + def __init__(self, argument: str) -> None: + super().__init__(f'{argument!r} cannot be represented as a boolean.') + + +class BadLiteralArgument(BadArgument): + """Raised when converting to a ``Literal`` generic fails.""" + + def __init__(self, argument: str, choices: Tuple[Any, ...]) -> None: + self.argument: str = argument + self.choices: Tuple[Any, ...] = choices + super().__init__(f'{argument!r} is not a valid choice.') + + +class BlacklistedArgument(BadArgument): + """Raised when converting to a ``Not`` generic is successful.""" + + def __init__(self, argument: str, blacklist: type) -> None: + self.argument: str = argument + self.blacklist: type = blacklist + super().__init__(f'{argument!r} should not be castable into {blacklist!r}.') + + +class CommandNotFound(CommandBasedError): + """Raised when a command is not found. + + Attributes + ---------- + ctx: :class:`~.Context` + The context that raised this error. + invoked_with: str + The string that did not correspond to a command. + """ + + def __init__(self, ctx: Context) -> None: + self.ctx: Context = ctx + self.invoked_with: str = ctx.invoked_with + super().__init__(f'Command {self.invoked_with!r} is not found.') diff --git a/ferris/plugins/commands/models.py b/ferris/plugins/commands/models.py index 07e118a..9ad98a4 100644 --- a/ferris/plugins/commands/models.py +++ b/ferris/plugins/commands/models.py @@ -1,34 +1,34 @@ from __future__ import annotations -from typing import ( - Any, - Awaitable, - Generic, - List, - Optional, - TYPE_CHECKING, - TypeVar, - Union, -) +import inspect +from typing import (TYPE_CHECKING, Any, Awaitable, Callable, Collection, Dict, + Generic, List, Optional, Tuple, TypeVar) + +from ...message import Message +from .parser import Parser R = TypeVar('R') if TYPE_CHECKING: + from ferris import Channel, Guild, Message, User from typing_extensions import Concatenate, ParamSpec - from ferris.channel import Channel - from ferris.guild import Guild - from ferris.message import Message - from ferris.member import Member - from ferris.user import User + from .core import Bot + from .parser import StringReader P = ParamSpec('P') - CommandCallbackT = Callable[Concatenate['Context', P], Awaitable[R]] + CommandCallback = Callable[Concatenate['Context', P], Awaitable[R]] + ErrorCallback = Callable[['Context', Exception], Awaitable[Any]] else: P = TypeVar('P') +__all__ = ( + 'Command', + 'Context', +) + -class Command(Generic[P, R]): +class Command(Parser, Generic[P, R]): """Represents a command. Attributes @@ -37,97 +37,254 @@ class Command(Generic[P, R]): The name of this command. aliases: List[str] A list of aliases for this command. - callback: Callable[Concatenate[:class:`.Context`, P], Awaitable[R]] - The command callback for this command. + usage: str + The custom usage string for this command, or ``None``. + See :attr:`~.Command.signature` for an auto-generated version. + error_callback: Optional[Callable[[:class:`~.Context`, Exception], Any]] + The callback for when an error is raise during command invokation. + This could be ``None``. """ def __init__( - self, name: str, aliases: List[str], callback: CommandCallbackT + self, + callback: CommandCallback, + *, + name: str, + aliases: Collection[str], + brief: Optional[str] = None, + description: Optional[str] = None, + usage: Optional[str] = None, + **attrs ) -> None: + self.error_callback: Optional[ErrorCallback] = None + self.name: str = name - self.aliases: List[str] = aliases - self.callback: CommandCallbackT = callback - self.on_error: Optional[Callable[[Context, Exception], Any]] = None + self.aliases: List[str] = list(aliases) + self.usage: Optional[str] = usage + + self._brief: Optional[str] = brief + self._description: Optional[str] = description + + self._metadata: Dict[str, Any] = attrs + + super().__init__() + self.overload(callback) + + def __str__(self) -> str: + return self.qualified_name + + def __repr__(self) -> str: + return f'' + + def __hash__(self) -> int: + return hash(self.qualified_name) + + @property + def brief(self) -> str: + """str: A short description of this command. + + If no brief is set, the first line of the :attr:`~.Command.description` + is used instead. + """ + if self._brief is not None: + return self._brief + + try: + return self.description.splitlines()[0] + except IndexError: + return '' + + @property + def description(self) -> str: + """str: A detailed description of this command. + + If no description is set, the command's callback docstring is used instead. + If no docstring exists, an empty string is returned. + """ + if self._description is not None: + return self._description + + doc = self.callback.__doc__ + if not doc: + return '' + + return inspect.cleandoc(doc) + + @property + def qualified_name(self) -> str: + """str: The qualified name for this command.""" + return self.name + + @property + def signature(self) -> str: + """str: The signature, or "usage string" of this command. + + If :attr:`~.Command.usage` is set, it will be returned. + Else, it is automatically generated. + """ + return self.usage or super().signature + + def error(self, func: ErrorCallback) -> None: + """Registers an error handler for this command. + If no errors are raised here, the error is suppressed. + Else, `on_command_error` is raised. + + Function signature should be of ``async (ctx: Context, error: Exception) -> Any``. + """ + self.error_callback = func - def error(self, func: Callable[[Context, Exception], Any]) -> None: - """Adds an error handler to this command. + async def execute(self, ctx: Context, content: str) -> None: + """|coro| - Example - ------- - .. code:: python3 + Parses and executes this command with the given context and argument content. - @bot.command() - async def raise_error(ctx: Context) -> None: - int('a') # Will raise ValueError + .. note:: + The prefix and command must be removed from content before-hand. - @raise_error.error - async def error_handler(ctx: Context, exc: Exception) -> None: - if isinstance(exc, ValueError): - await ctx.send('Got ValueError!') + Parameters + ---------- + ctx: :class:`~.Context` + The context to invoke this command with. + content: str + The content of the arguments to parse. """ - self.on_error = func + try: + ctx.callback, _, _ = await self._parse(content, ctx=ctx) + except Exception as exc: + if self.error_callback: + try: + await self.error_callback(ctx, exc) + except Exception as new_exc: + exc = new_exc + else: + return + + ctx.bot.dispatch('command_error', ctx, exc) + else: + await ctx.reinvoke() + + async def invoke(self, ctx: Context, /, *args: P.args, **kwargs: P.kwargs) -> None: + """|coro| - async def invoke(self, ctx: Context, *args: P.args, **kwargs: P.kwargs) -> None: - """Invokes this command. + Invokes this command with the given context. Parameters ---------- - ctx: :class:`.Context` - The context to invoke the command with. - *args: P.args - The positional arguments to pass into the command callback. - **kwargs: P.kwargs - The keyword arguments to pass into the command callback. + ctx: :class:`~.Context` + The context to invoke this command with. + *args + The positional arguments to pass into the callback. + **kwargs + The keyword arguments to pass into the callback. """ try: - await self.callback(ctx, *args, **kwargs) - # dispatch `on_command` here... + ctx.bot.dispatch('command', ctx) + await ctx.callback(ctx, *args, **kwargs) except Exception as exc: - if self.on_error: + if self.error_callback: try: - await self.on_error(ctx, exc) + await self.error_callback(ctx, exc) except Exception as new_exc: exc = new_exc else: return - # dispatch `on_command_error` here... + + ctx.bot.dispatch('command_error', ctx, exc) else: - # dispatch `on_command_success` here... - ... + ctx.bot.dispatch('command_success', ctx) finally: - # dispatch `on_command_complete` here... - ... + ctx.bot.dispatch('command_complete', ctx) class Context: - """Represents the context for a command. + """Represents the context for when a command is invoked. Attributes ---------- - command: Optional[:class:`.Command`] - The command invoked. May be ``None``. + bot: :class:`~.Bot` + The bot that created this context. + message: :class:`~.Message` + The message that invoked this context. + prefix: str + The prefix used to invoke this command. + Could be, but rarely is ``None``. + invoked_with: str + The command name used to invoke this command. + This could be used to determine which alias invoked this command. + Could be ``None``. + command: :class:`~.Command` + The command invoked. Could be ``None``. + callback: Callable + The parsed command callback that will be used. Could be ``None``. + args: tuple + The arguments used to invoke this command. Could be ``None``. + kwargs: Dict[str, Any] + The keyword arguments used to invoke this command. Could be ``None``. + reader: :class:`~.StringReader` + The string-reader that was used to parse this command. Could be ``None``. """ - def __init__(self, *, message: Message) -> None: - self._message: Message = message - self.command: Command = None + def __init__(self, bot: Bot, message: Message) -> None: + self.bot: Bot = bot + self.message: Message = message - @property - def message(self) -> Message: - """:class:`~.Message`: The invocation message for this context.""" - return self._message + self.prefix: Optional[str] = None + self.invoked_with: Optional[str] = None - @property - def author(self) -> Union[User, Member]: - """Union[:class:`~.User`, :class:`~.Member`]: The author of this context.""" - return self._message.author + self.command: Optional[Command] = None + self.callback: Optional[CommandCallback] = None + self.args: Optional[tuple] = None + self.kwargs: Optional[Dict[str, Any]] = None + self.reader: Optional[StringReader] = None - @property - def channel(self) -> Channel: - """:class:`~.Channel`: The channel of this context.""" - return self._message.channel + self.send: Optional[Callable[[str], Message]] = getattr(self.channel, 'send', None) + + def __repr__(self) -> str: + return f'' + + async def invoke(self, command: Command[P, Any], /, *args: P.args, **kwargs: P.kwargs) -> None: + """|coro| + + Invokes the given command with this context. + .. note:: No checks will be called here. + + Parameters + ---------- + command: :class:`~.Command` + The context to invoke this command with. + *args + The positional arguments to pass into the callback. + **kwargs + The keyword arguments to pass into the callback. + """ + self.command = command + self.args = args + self.kwargs = kwargs + + await command.invoke(self, *args, **kwargs) + + async def reinvoke(self) -> None: + """|coro| + + Re-invokes this command with the same arguments. + + .. note:: No checks will be called here. + """ + await self.invoke(self.command, *self.args, **self.kwargs) + + @property + def channel(self) -> Optional[Channel]: + """Optional[:class:`~.Channel`]: The channel that this context was invoked in.""" + return self.message.channel + + @property + def author(self) -> Optional[User]: + """Optional[:class:`~.User`]: The user that this context was invoked by.""" + return self.message.author + @property - def guild(self) -> Guild: - """:class:`~.Guild`: The guild of this context.""" - return self._message.guild + def guild(self) -> Optional[Guild]: + """Optional[:class:`~.Guild`]: The guild that this context was invoked in.""" + return self.message.guild + diff --git a/ferris/plugins/commands/parser.py b/ferris/plugins/commands/parser.py new file mode 100644 index 0000000..bc0e4b8 --- /dev/null +++ b/ferris/plugins/commands/parser.py @@ -0,0 +1,960 @@ +# Improved version of https://github.com/wumpus-py/argument-parsing/blob/master/argument_parser/core.py + +from __future__ import annotations + +import inspect +from abc import ABC, abstractmethod +from enum import Enum +from itertools import chain +from typing import (TYPE_CHECKING, Any, Awaitable, Callable, Dict, Generic, + Iterable, List, Literal, Optional, Tuple, Type, TypeVar, + Union, overload) + +from ...utils import P, ensure_async +from .converters import ChannelConverter, GuildConverter, UserConverter +from .errors import * + +from ferris import Channel, Guild, User + +ConverterOutputT = TypeVar('ConverterOutputT') +GreedyT = TypeVar('GreedyT') +LiteralT = TypeVar('LiteralT') +NotT = TypeVar('NotT') + +if TYPE_CHECKING: + from ferris.types.plugins.commands import ParserCallbackProto + + from .models import Context + + ArgumentPrepareT = Callable[[str], str] + ConverterT = Union['Converter', Type['Converter'], Callable[[str], ConverterOutputT]] + ParserCallback = ParserCallbackProto + + ArgumentT = TypeVar('ArgumentT', bound='Argument') + BlacklistT = TypeVar('BlacklistT', bound=ConverterT) # type: ignore + ParserT = TypeVar('ParserT', bound=Union['_Subparser', 'Parser']) + +CONVERTERS_MAPPING = { + Channel: ChannelConverter, + Guild: GuildConverter, + User: UserConverter, +} + +_NoneType: Type[None] = type(None) + +__all__ = ( + 'Parser', + 'StringReader', + 'Greedy', + 'Not', + 'ConsumeType', + 'Quotes', + 'Argument', + 'Converter', + 'converter', +) + + +class _NullType: + def __bool__(self) -> bool: + return False + + def __repr__(self) -> str: + return 'NULL' + + __str__ = __repr__ + + +_NULL = _NullType() + + +class ConsumeType(Enum): + """|enum| + + An enumeration of argument consumption types. + + Attributes + ---------- + default + The default and "normal" consumption type. + consume_rest + Consumes the string, including quotes, until the end. + list + Consumes in a similar fashion to the default consumption type, + but will consume like this all the way until the end. + If an error occurs, an error will be raised. + tuple + :attr:`~.ConsumeType.list` except that the result is a tuple, + rather than a list. + greedy + Like :attr:`~.ConsumeType.list`, but it stops consuming + when an argument fails to convert, rather than raising an error. + """ + + default = 'default' + consume_rest = 'consume_rest' + list = 'list' + tuple = 'tuple' + greedy = 'greedy' + + +class Quotes: + """Builtin quote mappings. All attributes are instances of ``Dict[str, str]``. + + Attributes + ---------- + default + The default quote mapping. + + ``"`` -> ``"`` + ``'`` -> ``'`` + extended + An extended quote mapping that supports + quotes from other languages/locales. + + ``'"'`` -> ``'"'`` + ``'`` -> ``'`` + ``'\u2018'`` -> ``'\u2019'`` + ``'\u201a'`` -> ``'\u201b'`` + ``'\u201c'`` -> ``'\u201d'`` + ``'\u201e'`` -> ``'\u201f'`` + ``'\u2e42'`` -> ``'\u2e42'`` + ``'\u300c'`` -> ``'\u300d'`` + ``'\u300e'`` -> ``'\u300f'`` + ``'\u301d'`` -> ``'\u301e'`` + ``'\ufe41'`` -> ``'\ufe42'`` + ``'\ufe43'`` -> ``'\ufe44'`` + ``'\uff02'`` -> ``'\uff02'`` + ``'\uff62'`` -> ``'\uff63'`` + ``'\xab'`` -> ``'\xbb'`` + ``'\u2039'`` -> ``'\u203a'`` + ``'\u300a'`` -> ``'\u300b'`` + ``'\u3008'`` -> ``'\u3009'`` + """ + + default = { + '"': '"', + "'": "'" + } + + extended = { + '"': '"', + "'": "'", + "\u2018": "\u2019", + "\u201a": "\u201b", + "\u201c": "\u201d", + "\u201e": "\u201f", + "\u2e42": "\u2e42", + "\u300c": "\u300d", + "\u300e": "\u300f", + "\u301d": "\u301e", + "\ufe41": "\ufe42", + "\ufe43": "\ufe44", + "\uff02": "\uff02", + "\uff62": "\uff63", + "\u2039": "\u203a", + "\u300a": "\u300b", + "\u3008": "\u3009", + "\xab": "\xbb", + } + + +class Greedy(Generic[GreedyT]): + """Represents a generic type annotation that sets the + consume-type of the argument to :attr:`~.ConsumeType.greedy`. + + Examples + -------- + .. code:: python3 + + @command() + async def massban(ctx, members: Greedy[Member], *, reason=None): + \"""Bans multiple people at once.\""" + for member in members: + await member.ban(reason=reason) + """ + + +class Not(Generic[NotT]): + """A special generic type annotation that fails conversion + if the given converter converts successfully. + + Examples + -------- + .. code:: python3 + + @command() + async def purchase(ctx, item: Not[int], quantity: int): + \"""Purchases an item.\""" + + @purchase.overload + async def purchase(ctx, quantity: int, *, item: str): + return await purchase.callback(ctx, item, quantity) + """ + + +class Converter(Generic[ConverterOutputT], ABC): + """A class that aids in making class-based converters.""" + + __is_converter__: bool = True + + @abstractmethod + async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: + """|coro| + + The core conversion of the argument. + This must be implemented, or :exc:`NotImplementedError` will be raised. + + Parameters + ---------- + ctx: :class:`~.Context` + The command context being parsed. + Note that some attributes could be ``None``. + argument: str + The argument to convert. + + Returns + ------- + Any + """ + raise NotImplementedError + + # noinspection PyUnusedLocal + async def validate(self, ctx: Context, argument: str) -> bool: + """|coro| + + The argument validation check to use. + This will be called before convert and raise a :exc:`~.ValidationError` + if it fails. + + This exists to encourage cleaner code. + + Parameters + ---------- + ctx: :class:`~.Context` + The command context being parsed. + Note that some attributes could be ``None``. + argument: str + The argument to validate. + + Returns + ------- + bool + """ + return True + + +class LiteralConverter(Converter[LiteralT]): + def __init__(self, *choices: LiteralT) -> None: + self._valid: Tuple[LiteralT, ...] = choices + + def __repr__(self) -> None: + return f'<{self.__class__.__name__} valid={self._valid!r}>' + + async def convert(self, ctx: Context, argument: str) -> LiteralT: + for possible in self._valid: + p_type = type(possible) + try: + casted = await _convert_one(ctx, argument, p_type) + except ConversionError: + continue + else: + if casted not in self._valid: + raise BadLiteralArgument(argument, self._valid) + + return casted + + raise BadLiteralArgument(argument, self._valid) + + +class _Not(Converter[str]): + def __init__(self, *entities: BlacklistT) -> None: + self._blacklist: Tuple[BlacklistT, ...] = entities + + def __repr__(self) -> None: + return f'<{self.__class__.__name__} blacklist={self._blacklist!r}>' + + async def convert(self, ctx: Context, argument: str) -> str: + for entity in self._blacklist: + try: + await _convert_one(ctx, argument, entity) + except ConversionError: + return argument + else: + raise BlacklistedArgument(argument, entity) + + +_SC = 'Tuple[Tuple[ConverterT], Optional[bool], Optional[ConsumeType]]' + + +def _sanitize_converter(converter: ConverterT, /, optional: bool = None, consume: ConsumeType = None) -> _SC: + origin = getattr(converter, '__origin__', False) + args = getattr(converter, '__args__', False) + + if origin and args: + if origin is Union: + if _NoneType in args or None in args: + # This is an optional type. + optional = True + args = tuple(arg for arg in args if arg is not _NoneType) + + return tuple( + chain.from_iterable(_sanitize_converter(arg)[0] for arg in args) + ), optional, consume + + if origin is Literal: + converter = LiteralConverter(*args) + + if origin is List: + consume = ConsumeType.list + return _sanitize_converter(args[0], optional, consume) + + if origin is Greedy: + consume = ConsumeType.greedy + + if origin is Not: + converter = _Not(*args) + + if inspect.isclass(converter) and issubclass(converter, Converter): + converter = converter() + + return (converter,), optional, consume + + +def _convert_bool(argument: str) -> bool: + if isinstance(argument, bool): + return argument + + argument = argument.lower() + if argument in {'true', 't', 'yes', 'y', 'on', 'enable', 'enabled', '1'}: + return True + if argument in {'false', 'f', 'no', 'n', 'off', 'disable', 'disabled', '0'}: + return False + + raise BadBooleanArgument(argument) + + +async def _convert_one(ctx: Context, argument: str, converter: ConverterT) -> ConverterOutputT: + if converter is bool: + return _convert_bool(converter) + + try: + module = converter.__module__ + except AttributeError: + pass + else: + if module is not None and module.startswith('ferris') and not module.endswith('Converter'): + converter = CONVERTERS_MAPPING.get(converter, converter) + + try: + if inspect.isclass(converter) and issubclass(converter, Converter): + if not inspect.ismethod(converter.convert): + converter = converter() + + if getattr(converter, '__is_converter__', False): + return await converter.convert(ctx, argument) + + return converter(argument) + + except Exception as exc: + raise ConversionError(exc) + + +async def _convert(ctx: Context, argument: Argument, word: str, converters: Iterable[ConverterT]) -> ConverterOutputT: + errors = [] + for converter in converters: + try: + result = await _convert_one(ctx, word, converter) + + if getattr(converter, '__is_converter__', False) and not await converter.validate(ctx, result): + raise ArgumentValidationError(ctx, argument, word) + + return result + + except ConversionError as exc: + errors.append(exc) + + raise ConversionFailure(ctx, argument, word, *errors) + + +def _prepare(ctx: Context, argument: Argument, word: str) -> str: + callback = argument.prepare + + if callback is None: + return word + + try: + return str(callback(word)) + except Exception as exc: + raise ArgumentPreparationError(ctx, argument, word, exc) + + +class Argument: + """Represents a positional argument. + + Although these are automatically constructed, you can + explicitly construct these and use them as type-annotations + in your command signatures. + + Parameters + ---------- + converter + The converter for this argument. This cannot be mixed with `converters`. + *converters + A list of converters for this argument. This cannot be mixed with `converter`. + name: str + The name of this argument. + signature: str + A custom signature for this argument. Will be auto-generated if not given. + default + The default value of this argument. + optional: bool + Whether or not this argument is optional. + description: str + A description about this argument. + consume_type: :class:`~.ConsumeType` + The consumption type of this argument. + quoted: bool + Whether or not this argument can be pass in with quotes. + quotes: Dict[str, str] + A mapping of start-quotes to end-quotes + of all supported quotes for this argument. + **kwargs + Extra kwargs to pass in for metadata. + + Attributes + ---------- + name: str + The name of this argument. + default + The default value of this argument. + optional: bool + Whether or not this argument is optional. + description: str + A description about this argument. + consume_type: :class:`~.ConsumeType` + The consumption type of this argument. + quoted: bool + Whether or not this argument can be pass in with quotes. + quotes: Dict[str, str] + A mapping of start-quotes to end-quotes + of all supported quotes for this argument. + prepare: Callable[[str], str] + A callable that takes a string, which will + prepare the argument before conversion. + + Examples + -------- + .. code:: python3 + + @command() + async def translate( + ctx, + language: Argument( + Literal['spanish', 'french', 'chinese'], + prepare=str.lower + ), + text: Argument(str, consume_type=ConsumeType.consume_rest) + ): + ... + """ + + if TYPE_CHECKING: + @overload + def __init__( + self, + /, + converter: ConverterT, + *, + name: str = None, + signature: str = None, + default: Any = _NULL, + optional: bool = None, + description: str = None, + consume_type: Union[ConsumeType, str] = ConsumeType.default, + quoted: bool = None, + quotes: Dict[str, str] = None, + prepare: ArgumentPrepareT = None, + **kwargs + ) -> None: + ... + + @overload + def __init__( + self, + /, + *converters: ConverterT, + name: str = None, + alias: str = None, + signature: str = None, + default: Any = _NULL, + optional: bool = None, + description: str = None, + consume_type: Union[ConsumeType, str] = ConsumeType.default, + quoted: bool = None, + quotes: Dict[str, str] = None, + prepare: ArgumentPrepareT = None, + **kwargs + ) -> None: + ... + + def __init__( + self, + /, + *converters: ConverterT, + name: str = None, + signature: str = None, + default: Any = _NULL, + optional: bool = None, + description: str = None, + converter: ConverterT = None, + consume_type: Union[ConsumeType, str] = ConsumeType.default, + quoted: bool = None, + quotes: Dict[str, str] = None, + prepare: ArgumentPrepareT = None, + **kwargs + ) -> None: + actual_converters = str, + + if converters and converter is not None: + raise ValueError('Converter kwarg cannot be used when they are already passed as positional arguments.') + + if len(converters) == 1: + converter = converters[0] + converters = () + + if converters or converter: + if converters: + actual_converters, optional_, consume = _sanitize_converter(converters[0]) + actual_converters += _sanitize_converter(converters[1:])[0] + elif converter: + actual_converters, optional_, consume = _sanitize_converter(converter) + else: + raise ValueError('Parameter mismatch') + + optional = optional if optional is not None else optional_ + consume_type = consume_type if consume_type is not None else consume + + self._param_key: Optional[str] = None + self._param_kwarg_only: bool = False + self._param_var_positional: bool = False + + self.name: str = name + self.description: str = description + self.default: Any = default + self.prepare: ArgumentPrepareT = prepare + + self.consume_type: ConsumeType = ( + consume_type if isinstance(consume_type, ConsumeType) else ConsumeType(consume_type) + ) + + self.optional: bool = optional if optional is not None else False + self.quoted: bool = quoted if quoted is not None else consume_type is not ConsumeType.consume_rest + self.quotes: Dict[str, str] = quotes if quotes is not None else Quotes.default + + self._signature: str = signature + self._converters: Tuple[ConverterT, ...] = actual_converters + + self._kwargs: Dict[str, Any] = kwargs + + def __repr__(self) -> str: + return f'' + + def __hash__(self) -> int: + return hash(id(self)) + + @property + def converters(self) -> List[ConverterT]: + """List[Union[type, :class:`~.Converter`]]: A list of this argument's converters.""" + return list(self._converters) + + @property + def signature(self) -> str: + """str: The signature of this argument.""" + if self._signature is not None: + return self._signature + + start, end = '[]' if self.optional or self.default is not _NULL else '<>' + + suffix = '...' if self.consume_type in ( + ConsumeType.list, ConsumeType.tuple, ConsumeType.greedy + ) else '' + + default = f'={self.default}' if self.default is not _NULL else '' + return start + str(self.name) + default + suffix + end + + @signature.setter + def signature(self, value: str, /) -> str: + self._signature = value + + @classmethod + def _from_parameter(cls: Type[ArgumentT], param: inspect.Parameter, /) -> ArgumentT: + def finalize(argument: ArgumentT) -> ArgumentT: + if param.kind is param.KEYWORD_ONLY: + argument._param_kwarg_only = True + + if param.kind is param.VAR_POSITIONAL: + argument._param_var_positional = True + + argument._param_key = store = param.name + if not argument.name: + argument.name = store + + return argument + + kwargs = {'name': param.name} + + if param.annotation is not param.empty: + if isinstance(param.annotation, cls): + return finalize(param.annotation) + + kwargs['converter'] = param.annotation + + if param.default is not param.empty: + kwargs['default'] = param.default + + if param.kind is param.KEYWORD_ONLY: + kwargs['consume_type'] = ConsumeType.consume_rest + + elif param.kind is param.VAR_POSITIONAL: + kwargs['consume_type'] = ConsumeType.tuple + + return finalize(cls(**kwargs)) + + +class StringReader: + """Helper class to aid with parsing strings. + + Parameters + ---------- + string: str + The string to parse. + quotes: Dict[str, str] + A mapping of start-quotes to end-quotes. + Defaults to :attr:`~.Quotes.default` + """ + + class EOF: + """The pointer has gone past the end of the string.""" + + def __init__(self, string: str, /, *, quotes: Dict[str, str] = None) -> None: + self.quotes: Dict[str, str] = quotes or Quotes.default + self.buffer: str = string + self.index: int = -1 + + def seek(self, index: int, /) -> str: + """Seek to an index in the string. + + Parameters + ---------- + index: int + The index to seek to. + + Returns + ------- + str + """ + self.index = index + return self.current + + @property + def current(self) -> Union[str, Type[EOF]]: + """str: The current character this reader is pointing to.""" + try: + return self.buffer[self.index] + except IndexError: + return self.EOF + + @property + def eof(self) -> bool: + """bool: Whether or not this reader has reached the end of the string.""" + return self.current is self.EOF + + def previous_character(self) -> str: + """Seeks to the previous character in the string.""" + return self.seek(self.index - 1) + + def next_character(self) -> str: + """Seeks to the next character in the string.""" + return self.seek(self.index + 1) + + @property + def rest(self) -> str: + """str: The rest of the characters in the string.""" + result = self.buffer[self.index:] if self.index != -1 else self.buffer + self.index = len(self.buffer) # Force an EOF + return result + + @staticmethod + def _is_whitespace(char: str, /) -> bool: + if char is ...: + return False + return char.isspace() + + def skip_to_word(self) -> None: + """Skips to the beginning of the next word.""" + char = ... + while self._is_whitespace(char): + char = self.next_character() + + def next_word(self, *, skip_first: bool = True) -> str: + """Returns the next word in the string. + + Parameters + ---------- + skip_first: bool + Whether or not to call :meth:`~.StringReader.skip_to_word` + beforehand. Defaults to ``True``. + + Returns + ------- + str + """ + char = ... + buffer = '' + + if skip_first: + self.skip_to_word() + + while not self._is_whitespace(char): + char = self.next_character() + if self.eof: + return buffer + + buffer += char + + buffer = buffer[:-1] + return buffer + + def next_quoted_word(self, *, skip_first: bool = True) -> str: + """Returns the next quoted word in the string. + + Parameters + ---------- + skip_first: bool + Whether or not to call :meth:`~.StringReader.skip_to_word` + beforehand. Defaults to ``True``. + + Returns + ------- + str + """ + if skip_first: + self.skip_to_word() + + first_char = self.next_character() + + if first_char not in self.quotes: + self.previous_character() + return self.next_word(skip_first=False) + + end_quote = self.quotes[first_char] + + char = ... + buffer = '' + + while char != end_quote or self.buffer[self.index - 1] == '\\': + char = self.next_character() + if self.eof: + return buffer + + buffer += char + + self.next_character() + buffer = buffer[:-1] + return buffer + + +class _Subparser: + """Parses one specific overload.""" + + def __init__(self, arguments: List[Argument] = None, *, callback: ParserCallback = None): + self._arguments: List[Argument] = arguments or [] + self.callback: Optional[ParserCallback] = callback + + def __hash__(self) -> int: + return hash(id(self)) + + def add_argument(self, argument: Argument, /) -> None: + self._arguments.append(argument) + + @property + def signature(self) -> str: + """str: The signature (or "usage string") for this overload.""" + return ' '.join(arg.signature for arg in self._arguments) + + @classmethod + def from_function(cls: Type[_Subparser], func: ParserCallback, /) -> _Subparser: + """Creates a new :class:`~.Parser` from a function.""" + params = list(inspect.signature(func).parameters.values()) + + if len(params) < 1: + raise TypeError('Command callback must have at least one context parameter.') + + self = cls(callback=func) + for param in params[1:]: + self.add_argument(Argument._from_parameter(param)) + + return self + + async def parse(self, text: str, /, ctx: Context) -> Tuple[List[Any], Dict[str, Any]]: + # Return a tuple (args: list, kwargs: dict) + # Execute as callback(ctx, *args, **kwargs) + + args = ctx.args = [] + kwargs = ctx.kwargs = {} + reader = ctx.reader = StringReader(text) + + def append_value(argument: Argument, value: Any) -> None: + if argument._param_kwarg_only: + kwargs[argument._param_key] = value + elif argument._param_var_positional: + args.extend(value) + else: + args.append(value) + + i = 0 + for i, argument in enumerate(self._arguments, start=1): + if reader.eof: + i -= 1 + break + + start = reader.index + + if argument.consume_type not in (ConsumeType.consume_rest, ConsumeType.default): + # Either list, tuple, or greedy + result = [] + + while not reader.eof: + word = reader.next_quoted_word() if argument.quoted else reader.next_word() + word = _prepare(ctx, argument, word) + try: + word = await _convert(ctx, argument, word, argument.converters) + except ConversionError as exc: + if argument.consume_type is not ConsumeType.greedy: + raise exc + break + else: + result.append(word) + + if argument.consume_type is ConsumeType.tuple: + result = tuple(result) + + append_value(argument, result) + continue + + if argument.consume_type is ConsumeType.consume_rest: + word = reader.rest.strip() + else: + word = reader.next_quoted_word() if argument.quoted else reader.next_word() + + word = _prepare(ctx, argument, word) + + try: + word = await _convert(ctx, argument, word, argument.converters) + except ConversionError as exc: + if argument.optional: + default = argument.default if argument.default is not _NULL else None + append_value(argument, default) + reader.seek(start) + continue + + raise exc + else: + append_value(argument, word) + + for argument in self._arguments[i:]: + if argument.default is not _NULL: + append_value(argument, argument.default) + elif argument.optional: + append_value(argument, None) + else: + raise MissingArgumentError(ctx, argument) + + return args, kwargs + + +class Parser: + """The main class that parses arguments.""" + + def __init__(self, *, overloads: List[_Subparser] = None) -> None: + self._overloads: List[_Subparser] = overloads or [] + + @property + def _main_parser(self) -> _Subparser: + if not len(self._overloads): + self._overloads.append(_Subparser()) + + return self._overloads[0] + + @property + def callback(self) -> ParserCallback: + """Callable[[:class:`~.Context`, Any, ...], Any]: The callback of this command.""" + return self._main_parser.callback + + @callback.setter + def callback(self, func: ParserCallback) -> None: + func = ensure_async(func) + overload = _Subparser.from_function(func) + try: + self._overloads[0] = overload + except IndexError: + self._overloads = [overload] + + @property + def signature(self) -> str: + """str: The usage string of this command.""" + return self._main_parser.signature + + @property + def arguments(self) -> List[Argument]: + """List[:class:`~.Argument`]: A list of arguments this command takes.""" + return self._main_parser._arguments + + def overload(self: ParserT, func: ParserCallback, /) -> ParserT: + """Adds a command overload to this function.""" + func = ensure_async(func) + result = _Subparser.from_function(func) + self._overloads.append(result) + return self + + @classmethod + def from_function(cls: Type[ParserT], func: ParserCallback, /) -> ParserT: + func = ensure_async(func) + parser = _Subparser.from_function(func) + return cls(overloads=[parser]) + + async def _parse(self, text: str, /, ctx: Context) -> Tuple[ParserCallback, List[Any], Dict[str, Any]]: + errors = [] + for overload in self._overloads: + try: + return overload.callback, *await overload.parse(text, ctx=ctx) + except Exception as exc: + errors.append(exc) + + # Maybe do something with the other errors? + raise errors[0] + + +def converter(func: Callable[[Context, str], Awaitable[ConverterOutputT]]) -> Type[Converter[ConverterOutputT]]: + """A decorator that helps convert a function into a converter. + + Examples + -------- + .. code:: python3 + + @converter + async def reverse_string(ctx, argument): + return argument[::-1] + + @command('reverse') + async def reverse_command(ctx, *, text: reverse_string): + \"""Reverses a string of text.\""" + await ctx.send(text) + """ + func = ensure_async(func) + + async def convert(_, ctx: Context, argument: str) -> ConverterOutputT: + return await func(ctx, argument) + + return type('FunctionBasedConverter', (Converter,), {'convert': convert}) diff --git a/ferris/types/__init__.py b/ferris/types/__init__.py index 4b4c132..215a43e 100644 --- a/ferris/types/__init__.py +++ b/ferris/types/__init__.py @@ -9,6 +9,7 @@ from .role import * from .user import * from .ws import * +from .plugins import * Data = Union[ AuthResponse, diff --git a/ferris/types/base.py b/ferris/types/base.py index 523f5f4..14c648e 100644 --- a/ferris/types/base.py +++ b/ferris/types/base.py @@ -1,4 +1,5 @@ -from typing import Protocol, Union, runtime_checkable, Optional +from typing import Protocol, Type, Union, runtime_checkable, Optional +from typing_extensions import TypeAlias __all__ = ('SupportsStr', 'SupportsId', 'Id', 'Snowflake') @@ -8,7 +9,7 @@ def __str__(self) -> str: ... -Snowflake = int +Snowflake: TypeAlias = int @runtime_checkable @@ -18,4 +19,4 @@ class SupportsId(Protocol): id: Snowflake -Id = Optional[Union[SupportsId, Snowflake]] +Id: TypeAlias = Optional[Union[SupportsId, Snowflake]] diff --git a/ferris/types/plugins/__init__.py b/ferris/types/plugins/__init__.py new file mode 100644 index 0000000..7042230 --- /dev/null +++ b/ferris/types/plugins/__init__.py @@ -0,0 +1,3 @@ +from . import commands + +__all__ = ('commands',) \ No newline at end of file diff --git a/ferris/types/plugins/commands/__init__.py b/ferris/types/plugins/commands/__init__.py new file mode 100644 index 0000000..0c1a534 --- /dev/null +++ b/ferris/types/plugins/commands/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, Tuple, Any, Dict + +__all__ = ('ParserCallbackProto',) + +if TYPE_CHECKING: + from ferris.plugins.commands import Context + + +class ParserCallbackProto(Protocol): + def __call__(self, ctx: Context, *args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any: + ... diff --git a/ferris/utils.py b/ferris/utils.py index c0d0618..18b20e3 100644 --- a/ferris/utils.py +++ b/ferris/utils.py @@ -1,5 +1,13 @@ +from __future__ import annotations + +import functools +import inspect +import json +import re +import sys from datetime import datetime -from typing import Any, Callable, Iterable, Optional, TypeVar +from typing import (TYPE_CHECKING, Any, Awaitable, Callable, Iterable, + Optional, TypeVar, Union, overload) from .types import Id, Snowflake @@ -9,10 +17,28 @@ 'get_snowflake_creation_date', 'find', 'dt_to_snowflake', + 'ensure_async', + 'to_error_string', ) + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + P = ParamSpec('P') +else: + P = TypeVar('P') + +R = TypeVar('R') + T = TypeVar('T', covariant=True) +if sys.version_info[:3] >= (3, 9, 0): + A = Callable[P, Awaitable[R]] # type: ignore + F = Callable[P, R] # type: ignore +else: + A = Callable[[P], Awaitable[R]] # type: ignore + F = Callable[[P], R] # type: ignore + # try: # import orjson @@ -21,15 +47,13 @@ # import json -import json - HAS_ORJSON = False FERRIS_EPOCH_MS: int = 1_577_836_800_000 - FERRIS_EPOCH: int = 1_577_836_800 +INVITE_REGEX: re.Pattern = re.compile(r'(https?:\/\/)?(www\.)?(ferris\.sh|ferris\.chat\/invite)\/([A-Za-z0-9]+)') if HAS_ORJSON: @@ -124,3 +148,41 @@ def find(predicate: Callable[[T], Any], iterable: Iterable[T]) -> Optional[T]: if predicate(element): return element return None + +@overload +def ensure_async(func: Callable[P, Awaitable[R]], /) -> Callable[P, Awaitable[R]]: + ... + + +@overload +def ensure_async(func: Callable[P, R], /) -> Callable[P, Awaitable[R]]: + ... + + +def ensure_async(func: Union[A, F], /) -> A: + """Ensures that the given function is asynchronous. + In other terms, if the function is already async, it will stay the same. + Else, it will be converted into an async function. (Note that it will still be ran synchronously.) + """ + @functools.wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + maybe_coro = func(*args, **kwargs) + + if inspect.isawaitable(maybe_coro): + return await maybe_coro + + return maybe_coro + + return wrapper + + +def to_error_string(exc: Exception, /) -> str: + """Formats the given error into ``{error name}: {error text}``, + e.g. ``ValueError: invalid literal for int() with base 10: 'a'``. + If no error text exists, only the error name will be returned, + e.g. just ``ValueError``. + """ + if str(exc).strip(): + return '{0.__class__.__name__}: {0}'.format(exc) + else: + return exc.__class__.__name__ diff --git a/setup.py b/setup.py index 919d1b6..569f66d 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ "Issue tracker": "https://github.com/FerrisChat/ferriswheel/issues/new", }, version=version, - packages=["ferris", "ferris.types", "ferris.plugins", "ferris.plugins.commands"], + packages=["ferris", "ferris.types", "ferris.plugins", "ferris.plugins.commands", "ferris.types.plugins", "ferris.types.plugins.commands"], license="MIT", description="An asynchronous Python wrapper around FerrisChat's API", long_description=readme,