From 55e10216d263b0df652fd091798dcd937c27b055 Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Wed, 24 Nov 2021 16:45:50 -0800 Subject: [PATCH 01/22] Merge from https://github.com/jay3332/command-handler --- ferris/plugins/commands/__init__.py | 14 +- ferris/plugins/commands/core.py | 351 ++++++++++- ferris/plugins/commands/errors.py | 178 ++++++ ferris/plugins/commands/models.py | 282 +++++++-- ferris/plugins/commands/parser.py | 937 ++++++++++++++++++++++++++++ ferris/types/base.py | 7 +- ferris/utils.py | 51 +- 7 files changed, 1725 insertions(+), 95 deletions(-) create mode 100644 ferris/plugins/commands/errors.py create mode 100644 ferris/plugins/commands/parser.py diff --git a/ferris/plugins/commands/__init__.py b/ferris/plugins/commands/__init__.py index 78af14e..ede9dd9 100644 --- a/ferris/plugins/commands/__init__.py +++ b/ferris/plugins/commands/__init__.py @@ -1,2 +1,14 @@ -from .core import Bot, CommandSink +from . import errors, parser, utils +from .core import CaseInsensitiveDict, CommandSink, Bot +from .errors import * from .models import Command, Context +from .parser import ( + Argument, + ConsumeType, + Converter, + Greedy, + Not, + Quotes, + StringReader, + converter, +) diff --git a/ferris/plugins/commands/core.py b/ferris/plugins/commands/core.py index 6647319..1dfcb91 100644 --- a/ferris/plugins/commands/core.py +++ b/ferris/plugins/commands/core.py @@ -1,70 +1,373 @@ from __future__ import annotations -import asyncio -from typing import Awaitable, Callable, Dict, Generator, TYPE_CHECKING, Sequence, Union +import functools -from ferris.client import Client -from .models import Command +from typing import ( + Any, + Awaitable, + Callable, + Collection, + Dict, + Generator, + List, + Optional, + overload, + Type, + TYPE_CHECKING, + TypeVar, + Union, +) + +from .errors import CommandBasedError, CommandNotFound +from .models import Command, Context +from .parser import StringReader +from ...message import Message +from ...client import Client +from ...utils import ensure_async + +V = TypeVar('V') if TYPE_CHECKING: - from ferris.message import Message + from .models import CommandCallback + + DefaultT = TypeVar('DefaultT') + + BasePrefixT = Union[str, Collection[str]] + FunctionPrefixT = Callable[['Bot', Message], Awaitable[BasePrefixT]] + PrefixT = Union[BasePrefixT, FunctionPrefixT] + +__all__ = ( + 'CaseInsensitiveDict', + 'CommandSink', + 'Bot', +) + + +class CaseInsensitiveDict(Dict[str, V]): + """Represents a case-insensitive dictionary.""" + + def __getitem__(self, key: str) -> V: + return super().__getitem__(key.casefold()) + + def __setitem__(self, key: str, value: V) -> None: + super().__setitem__(key.casefold(), value) - _BasePrefixT = Union[str, Sequence[str]] - PrefixT = Union[ - _BasePrefixT, - Callable[['Bot', Message], Union[_BasePrefixT, Awaitable[_BasePrefixT]]], - ] + def __delitem__(self, key: str) -> None: + return super().__delitem__(key.casefold()) + + def __contains__(self, key: str) -> bool: + return super().__contains__(key.casefold()) + + def get(self, key: str, default: DefaultT = None) -> Union[V, DefaultT]: + return super().get(key.casefold(), default) + + def pop(self, key: str, default: DefaultT = None) -> Union[V, DefaultT]: + return super().pop(key.casefold(), default) + + get.__doc__ = dict.get.__doc__ + pop.__doc__ = dict.pop.__doc__ class CommandSink: """Represents a sink of commands. - This will both mixin to :class:`.Bot` and :class:`.CommandGroup`. + This will both mixin to :class:`~.Bot` and :class:`~.CommandGroup`. Attributes ---------- - mapping: Dict[str, :class:`.Command`] + command_mapping: Dict[str, :class:`.Command`] A full mapping of command names to commands. """ - def __init__(self) -> None: - self.mapping: Dict[str, Command] = {} + if TYPE_CHECKING: + command_mapping: Union[Dict[str, Command], CaseInsensitiveDict[Command]] + + def __init__(self, *, case_insensitive: bool = False) -> None: + mapping_factory = CaseInsensitiveDict if case_insensitive else dict + self.command_mapping = mapping_factory() - def walk_commands(self) -> Generator[Command]: + def walk_commands(self) -> Generator[Command, None, None]: """Returns a generator that walks through all of the commands this sink holds. Returns ------- - Generator[:class:`.Command`] + Generator[:class:`.Command`, None, None] """ seen = set() - for command in self.mapping.values(): + + for command in self.command_mapping.values(): if command not in seen: seen.add(command) yield command + def get_command(self, name: str, /, default: DefaultT = None) -> Union[Command, DefaultT]: + """Tries to get a command by it's name. + Aliases are supported. + + Parameters + ---------- + name: str + The name of the command you want to lookup. + default + What to return instead if the command was not found. + Defaults to ``None``. + + Returns + ------- + Optional[:class:`~.Command`] + The command found. + """ + return self.command_mapping.get(name, default) + + @overload + def command( + self, + name: str, + *, + alias: str = None, + brief: str = None, + description: str = None, + usage: str = None, + cls: Type[Command] = Command, + **kwargs + ) -> Callable[[CommandCallback], Command]: + ... + + @overload + def command( + self, + name: str, + *, + aliases: Collection[str] = None, + brief: str = None, + description: str = None, + usage: str = None, + cls: Type[Command] = Command, + **kwargs + ) -> Callable[[CommandCallback], Command]: + ... + + def command( + self, + name: str, + *, + alias: str = None, + aliases: Collection[str] = None, + brief: str = None, + description: str = None, + usage: str = None, + cls: Type[Command] = Command, + **kwargs + ) -> Callable[[CommandCallback], Command]: + """Returns a decorator that adds a command to this command sink.""" + + if alias and aliases: + raise ValueError('Only one of alias or aliases can be set.') + + if alias: + aliases = [alias] + elif not aliases: + aliases = [] + + def decorator(callback: CommandCallback, /) -> Command: + command = cls( + callback, + name=str(name), + aliases=aliases, + brief=brief, + description=description, + usage=usage, + **kwargs + ) + + self.command_mapping[name] = command + for alias in aliases: + self.command_mapping[alias] = command + + return command + + return decorator + @property - def commands(self) -> None: + def commands(self) -> List[Command]: """List[:class:`.Command`]: A list of the commands this sink holds.""" return list(self.walk_commands()) class Bot(Client, CommandSink): - """Represents a client connection to Discord with extra command handling support. + """Represents a bot with extra command handling support. Parameters ---------- + prefix + The prefix this bot will listen for. This is required. + case_insensitive: bool + Whether or not commands should be case-insensitive. loop: Optional[:class:`asyncio.AbstractEventLoop`] The event loop to use for the client. If not passed, then the default event loop is used. - prefix - The prefix the bot will listen for. This is required. max_messages_count: Optional[int] The maximum number of messages to store in the internal message buffer. Defaults to ``1000``. + + max_heartbeat_timeout: Optional[int] + The maximum timeout in seconds between sending a heartbeat to the server. + If heartbeat took longer than this timeout, the client will attempt to reconnect. """ - def __init__(self, prefix: PrefixT, **kwargs) -> None: + def __init__( + self, + prefix: PrefixT, + *, + case_insensitive: bool = False, + prefix_case_insensitive: bool = False, + strip_after_prefix: bool = False, + **kwargs + ) -> None: super().__init__(**kwargs) - CommandSink.__init__(self) + CommandSink.__init__(self, case_insensitive=case_insensitive) + + self.prefix: PrefixT = self._sanitize_prefix(prefix, case_insensitive=prefix_case_insensitive) + self.strip_after_prefix: bool = bool(strip_after_prefix) + + self._prefix_case_insensitive: bool = bool(prefix_case_insensitive) + self._default_case_insensitive: bool = bool(case_insensitive) + + @staticmethod + def _sanitize_prefix(prefix: Any, *, case_insensitive: bool = False, allow_callable: bool = True) -> PrefixT: + if prefix is None: + return '' + + if isinstance(prefix, str): + if case_insensitive: + return prefix.casefold() + + return prefix + + if isinstance(prefix, Collection): + try: + invalid = next(filter(lambda item: not isinstance(item, str), prefix)) + except StopIteration: + pass + else: + raise TypeError(f'Prefix collection must only consist of strings, not {type(invalid)!r}.') + + res = list(map(str.casefold, prefix)) if case_insensitive else list(prefix) + return sorted(res, key=len, reverse=True) + + if callable(prefix) and allow_callable: + return _ensure_casefold(ensure_async(prefix), case_insensitive=case_insensitive) + + raise TypeError( + f'Invalid prefix {prefix!r}. Only strings, collections of strings, ' + f'or functions that return them are allowed.' + ) + + async def get_prefix(self, message: Message, *, prefix: BasePrefixT = None) -> Optional[BasePrefixT]: + """Gets the prefix, or list of prefixes from the given message. + If the message does not start with a prefix, ``None`` is returned. + + Parameters + ---------- + message: :class:`~.Message` + The message to get the prefix from. + + Returns + ------- + Union[str, List[str]] + """ + prefix = self.prefix if prefix is None else prefix + content = message.content + + if self._prefix_case_insensitive: + content = content.casefold() + + if isinstance(prefix, str): + if content.startswith(prefix): + return prefix + else: + return None + + if isinstance(prefix, list): + try: + return next(filter(lambda pf: content.startswith(pf), prefix)) + except StopIteration: + return None + + if isinstance(prefix, callable): + prefix = await prefix(self, message) + return self.get_prefix(message, prefix=prefix) + + return None + + async def get_context(self, message: Message, *, cls: Type[Context] = Context) -> Context: + """Parses a :class:`~.Context` out of a message. + + If the message is not a command, + partial context with only the ``message`` parameter is returned. + + If an error occurs during parsing, + attributes of the returned context may remain as ``None``. + + Parameters + ---------- + message: :class:`~.Message` + The message to get the context from. + cls: Type[:class:`~.Context`] + The context subclass to use. Defaults to :class:`~.Context`. + + Returns + ------- + :class:`~.Context` + """ + ctx = cls(self, message) + prefix = await self.get_prefix(message) + + if prefix is None: + return ctx + + ctx.prefix = prefix + content = message.content[len(prefix):] + + if self.strip_after_prefix: + content = content.strip() + + reader = ctx.reader = StringReader(content) + ctx.invoked_with = word = reader.next_word(skip_first=False) + ctx.command = self.command_mapping.get(word) + return ctx + + async def invoke(self, ctx: Context) -> None: + """Parses and invokes the given context. + + Checks, cooldowns, hooks, etc. are ran here. + See :meth:`~.Context.reinvoke` for a version that bypasses these. + + Parameters + ---------- + ctx: :class:`~.Context` + The context to invoke. + """ + try: + if not ctx.command: + if ctx.invoked_with is not None: + raise CommandNotFound(ctx) + else: + return + + rest = ctx.reader.rest.strip() + await ctx.command.execute(ctx, rest) + + except CommandBasedError as exc: + self.dispatch('command_error', exc) + + +def _ensure_casefold(func: FunctionPrefixT, /, *, case_insensitive: bool = False) -> FunctionPrefixT: + @functools.wraps(func) + async def wrapper(bot: Bot, message: Message) -> BasePrefixT: + return bot._sanitize_prefix( + await func(bot, message), + case_insensitive=case_insensitive, + allow_callable=False + ) - self.prefix: PrefixT = prefix + return wrapper diff --git a/ferris/plugins/commands/errors.py b/ferris/plugins/commands/errors.py new file mode 100644 index 0000000..37cb877 --- /dev/null +++ b/ferris/plugins/commands/errors.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +from typing import Any, Optional, Tuple, TYPE_CHECKING +from ...utils import to_error_string + +if TYPE_CHECKING: + from .models import Context + from .parser import Argument + +__all__ = ( + 'CommandBasedError', + 'ArgumentParsingError', + 'ArgumentPreparationError', + 'ArgumentValidationError', + 'ConversionError', + 'ConversionFailure', + 'MissingArgumentError', + 'BadBooleanArgument', + 'BadLiteralArgument', + 'BlacklistedArgument', + 'BadArgument', + 'CommandNotFound', +) + + +class CommandBasedError(Exception): # Should inherit from your lib's base exception + """The base exception raised for errors related to the commands plugin.""" + + +class ArgumentParsingError(CommandBasedError): + """Raised when an error occurs during argument parsing.""" + + +class ArgumentValidationError(ArgumentParsingError): + """Raised when an argument fails validation. + + Attributes + ---------- + ctx: :class:`~.Context` + The context that raised this error. + argument: :class:`~.Argument` + The argument that failed validation. + word: str + The literal string that did not make it past validation. + """ + + def __init__(self, ctx: Context, argument: Argument, word: str) -> None: + self.ctx: Context = ctx + self.argument: Argument = argument + self.word: str = word + super().__init__(f'Argument {argument.name!r} did not pass validation.') + + +class ArgumentPreparationError(ArgumentParsingError): + """Raised when an argument fails preparation. + + In other words, this exception is raised when an error + occurs during :attr:`~.Argument.prepare`. + + Attributes + ---------- + ctx: :class:`~.Context` + The context that raised this error. + argument: :class:`~.Argument` + The argument that failed preparation. + word: str + The string of the word that could not be prepared. + error: :exc:`Exception` + The error that was raised. + """ + + def __init__(self, ctx: Context, argument: Argument, word: str, exc: Exception) -> None: + self.ctx: Context = ctx + self.argument: Argument = argument + self.word: str = word + self.error: Exception = exc + + super().__init__(f'Argument {argument.name!r} failed preparation: {to_error_string(exc)}') + + +class MissingArgumentError(ArgumentParsingError): + """Raised when an argument is missing. + + Attributes + ---------- + ctx: :class:`~.Context` + The context that raised this error. + argument: :class:`~.Argument` + The argument that was missing. + """ + + def __init__(self, ctx: Context, argument: Argument) -> None: + self.ctx: Context = ctx + self.argument: Argument = argument + super().__init__(f'Missing required argument {argument.name!r}.') + + +class ConversionError(ArgumentParsingError): + """Raised when an error occurs during argument conversion.""" + + +class ConversionFailure(ConversionError): + """Raised when an unhandled error occurs during argument conversion. + + Attributes + ---------- + ctx: :class:`~.Context` + The context that raised this error. + argument: :class:`~.Argument` + The argument that failed conversion. + word: str + The literal argument string that failed conversion. + errors: Tuple[:exc:`Exception`, ...] + The errors that were raised. + """ + + def __init__(self, ctx: Context, argument: Argument, word: str, *errors: Exception) -> None: + self.ctx: Context = ctx + self.argument: Argument = argument + self.word: str = word + + self.errors: Tuple[Exception, ...] = errors + super().__init__(f'Argument {argument.name!r} failed conversion: {to_error_string(self.error)}') + + @property + def error(self) -> Optional[Exception]: + """The most recent error raised, or ``None``.""" + try: + return self.errors[-1] + except IndexError: + return None + + +class BadArgument(ConversionError): + """A user defined error that should be raised if an argument cannot + be converted during argument conversion.""" + + +class BadBooleanArgument(BadArgument): + """Raised when converting an argument to a boolean fails.""" + + def __init__(self, argument: str) -> None: + super().__init__(f'{argument!r} cannot be represented as a boolean.') + + +class BadLiteralArgument(BadArgument): + """Raised when converting to a ``Literal`` generic fails.""" + + def __init__(self, argument: str, choices: Tuple[Any, ...]) -> None: + self.argument: str = argument + self.choices: Tuple[Any, ...] = choices + super().__init__(f'{argument!r} is not a valid choice.') + + +class BlacklistedArgument(BadArgument): + """Raised when converting to a ``Not`` generic is successful.""" + + def __init__(self, argument: str, blacklist: type) -> None: + self.argument: str = argument + self.blacklist: type = blacklist + super().__init__(f'{argument!r} should not be castable into {blacklist!r}.') + + +class CommandNotFound(CommandBasedError): + """Raised when a command is not found. + + Attributes + ---------- + ctx: :class:`~.Context` + The context that raised this error. + invoked_with: str + The string that did not correspond to a command. + """ + + def __init__(self, ctx: Context) -> None: + self.ctx: Context = ctx + self.invoked_with: str = ctx.invoked_with + super().__init__(f'Command {self.invoked_with!r} is not found.') diff --git a/ferris/plugins/commands/models.py b/ferris/plugins/commands/models.py index 07e118a..b0aaa06 100644 --- a/ferris/plugins/commands/models.py +++ b/ferris/plugins/commands/models.py @@ -1,34 +1,45 @@ from __future__ import annotations +import inspect + from typing import ( Any, Awaitable, + Callable, + Collection, + Dict, Generic, List, Optional, + Tuple, TYPE_CHECKING, TypeVar, - Union, ) +from .parser import Parser +from ...message import Message + R = TypeVar('R') if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec - from ferris.channel import Channel - from ferris.guild import Guild - from ferris.message import Message - from ferris.member import Member - from ferris.user import User + from .core import Bot + from .parser import StringReader P = ParamSpec('P') - CommandCallbackT = Callable[Concatenate['Context', P], Awaitable[R]] + CommandCallback = Callable[Concatenate['Context', P], Awaitable[R]] + ErrorCallback = Callable[['Context', Exception], Awaitable[Any]] else: P = TypeVar('P') +__all__ = ( + 'Command', + 'Context', +) + -class Command(Generic[P, R]): +class Command(Parser, Generic[P, R]): """Represents a command. Attributes @@ -37,97 +48,236 @@ class Command(Generic[P, R]): The name of this command. aliases: List[str] A list of aliases for this command. - callback: Callable[Concatenate[:class:`.Context`, P], Awaitable[R]] - The command callback for this command. + usage: str + The custom usage string for this command, or ``None``. + See :attr:`~.Command.signature` for an auto-generated version. + error_callback: Optional[Callable[[:class:`~.Context`, Exception], Any]] + The callback for when an error is raise during command invokation. + This could be ``None``. """ def __init__( - self, name: str, aliases: List[str], callback: CommandCallbackT + self, + callback: CommandCallback, + *, + name: str, + aliases: Collection[str], + brief: Optional[str] = None, + description: Optional[str] = None, + usage: Optional[str] = None, + **attrs ) -> None: + self.error_callback: Optional[ErrorCallback] = None + self.name: str = name - self.aliases: List[str] = aliases - self.callback: CommandCallbackT = callback - self.on_error: Optional[Callable[[Context, Exception], Any]] = None + self.aliases: List[str] = list(aliases) + self.usage: Optional[str] = usage + + self._brief: Optional[str] = brief + self._description: Optional[str] = description + + self._metadata: Dict[str, Any] = attrs + + super().__init__() + self.overload(callback) + + def __str__(self) -> str: + return self.qualified_name + + def __repr__(self) -> str: + return f'' + + def __hash__(self) -> int: + return hash(self.qualified_name) + + @property + def brief(self) -> str: + """str: A short description of this command. + + If no brief is set, the first line of the :attr:`~.Command.description` + is used instead. + """ + if self._brief is not None: + return self._brief + + try: + return self.description.splitlines()[0] + except IndexError: + return '' + + @property + def description(self) -> str: + """str: A detailed description of this command. + + If no description is set, the command's callback docstring is used instead. + If no docstring exists, an empty string is returned. + """ + if self._description is not None: + return self._description + + doc = self.callback.__doc__ + if not doc: + return '' + + return inspect.cleandoc(doc) + + @property + def qualified_name(self) -> str: + """str: The qualified name for this command.""" + return self.name + + @property + def signature(self) -> str: + """str: The signature, or "usage string" of this command. + + If :attr:`~.Command.usage` is set, it will be returned. + Else, it is automatically generated. + """ + return self.usage or super().signature + + def error(self, func: ErrorCallback) -> None: + """Registers an error handler for this command. + If no errors are raised here, the error is suppressed. + Else, `on_command_error` is raised. + + Function signature should be of ``async (ctx: Context, error: Exception) -> Any``. + """ + self.error_callback = func - def error(self, func: Callable[[Context, Exception], Any]) -> None: - """Adds an error handler to this command. + async def execute(self, ctx: Context, content: str) -> None: + """|coro| - Example - ------- - .. code:: python3 + Parses and executes this command with the given context and argument content. - @bot.command() - async def raise_error(ctx: Context) -> None: - int('a') # Will raise ValueError + .. note:: + The prefix and command must be removed from content before-hand. - @raise_error.error - async def error_handler(ctx: Context, exc: Exception) -> None: - if isinstance(exc, ValueError): - await ctx.send('Got ValueError!') + Parameters + ---------- + ctx: :class:`~.Context` + The context to invoke this command with. + content: str + The content of the arguments to parse. """ - self.on_error = func + try: + ctx.callback, _, _ = await self._parse(content, ctx=ctx) + except Exception as exc: + if self.error_callback: + try: + await self.error_callback(ctx, exc) + except Exception as new_exc: + exc = new_exc + else: + return + + ctx.bot.dispatch('command_error', ctx, exc) + else: + await ctx.reinvoke() - async def invoke(self, ctx: Context, *args: P.args, **kwargs: P.kwargs) -> None: - """Invokes this command. + async def invoke(self, ctx: Context, /, *args: P.args, **kwargs: P.kwargs) -> None: + """|coro| + + Invokes this command with the given context. Parameters ---------- - ctx: :class:`.Context` - The context to invoke the command with. - *args: P.args - The positional arguments to pass into the command callback. - **kwargs: P.kwargs - The keyword arguments to pass into the command callback. + ctx: :class:`~.Context` + The context to invoke this command with. + *args + The positional arguments to pass into the callback. + **kwargs + The keyword arguments to pass into the callback. """ try: - await self.callback(ctx, *args, **kwargs) - # dispatch `on_command` here... + ctx.bot.dispatch('command', ctx) + await ctx.callback(ctx, *args, **kwargs) except Exception as exc: - if self.on_error: + if self.error_callback: try: - await self.on_error(ctx, exc) + await self.error_callback(ctx, exc) except Exception as new_exc: exc = new_exc else: return - # dispatch `on_command_error` here... + + ctx.bot.dispatch('command_error', ctx, exc) else: - # dispatch `on_command_success` here... - ... + ctx.bot.dispatch('command_success', ctx) finally: - # dispatch `on_command_complete` here... - ... + ctx.bot.dispatch('command_complete', ctx) class Context: - """Represents the context for a command. + """Represents the context for when a command is invoked. Attributes ---------- - command: Optional[:class:`.Command`] - The command invoked. May be ``None``. + bot: :class:`~.Bot` + The bot that created this context. + message: :class:`~.Message` + The message that invoked this context. + prefix: str + The prefix used to invoke this command. + Could be, but rarely is ``None``. + invoked_with: str + The command name used to invoke this command. + This could be used to determine which alias invoked this command. + Could be ``None``. + command: :class:`~.Command` + The command invoked. Could be ``None``. + callback: Callable + The parsed command callback that will be used. Could be ``None``. + args: tuple + The arguments used to invoke this command. Could be ``None``. + kwargs: Dict[str, Any] + The keyword arguments used to invoke this command. Could be ``None``. + reader: :class:`~.StringReader` + The string-reader that was used to parse this command. Could be ``None``. """ - def __init__(self, *, message: Message) -> None: - self._message: Message = message - self.command: Command = None + def __init__(self, bot: Bot, message: Message) -> None: + self.bot: Bot = bot + self.message: Message = message - @property - def message(self) -> Message: - """:class:`~.Message`: The invocation message for this context.""" - return self._message + self.prefix: Optional[str] = None + self.invoked_with: Optional[str] = None - @property - def author(self) -> Union[User, Member]: - """Union[:class:`~.User`, :class:`~.Member`]: The author of this context.""" - return self._message.author + self.command: Optional[Command] = None + self.callback: Optional[CommandCallback] = None + self.args: Optional[tuple] = None + self.kwargs: Optional[Dict[str, Any]] = None + self.reader: Optional[StringReader] = None - @property - def channel(self) -> Channel: - """:class:`~.Channel`: The channel of this context.""" - return self._message.channel + def __repr__(self) -> str: + return f'' - @property - def guild(self) -> Guild: - """:class:`~.Guild`: The guild of this context.""" - return self._message.guild + async def invoke(self, command: Command[P, Any], /, *args: P.args, **kwargs: P.kwargs) -> None: + """|coro| + + Invokes the given command with this context. + + .. note:: No checks will be called here. + + Parameters + ---------- + command: :class:`~.Command` + The context to invoke this command with. + *args + The positional arguments to pass into the callback. + **kwargs + The keyword arguments to pass into the callback. + """ + self.command = command + self.args = args + self.kwargs = kwargs + + await command.invoke(self, *args, **kwargs) + + async def reinvoke(self) -> None: + """|coro| + + Re-invokes this command with the same arguments. + + .. note:: No checks will be called here. + """ + await self.invoke(self.command, *self.args, **self.kwargs) diff --git a/ferris/plugins/commands/parser.py b/ferris/plugins/commands/parser.py new file mode 100644 index 0000000..6b8ca5d --- /dev/null +++ b/ferris/plugins/commands/parser.py @@ -0,0 +1,937 @@ +# Improved version of https://github.com/wumpus-py/argument-parsing/blob/master/argument_parser/core.py + +from __future__ import annotations + +import inspect +from abc import ABC, abstractmethod +from enum import Enum +from itertools import chain +from typing import (TYPE_CHECKING, Any, Awaitable, Callable, Dict, Generic, + Iterable, List, Literal, Optional, Tuple, Type, TypeVar, + Union, overload) + +from ...utils import ensure_async +from .errors import * + +ConverterOutputT = TypeVar('ConverterOutputT') +GreedyT = TypeVar('GreedyT') +LiteralT = TypeVar('LiteralT') +NotT = TypeVar('NotT') + +if TYPE_CHECKING: + from .models import Context + + ArgumentPrepareT = Callable[[str], str] + ConverterT = Union['Converter', Type['Converter'], Callable[[str], ConverterOutputT]] + ParserCallback = Callable[[Context, Any, ...], Any] + + ArgumentT = TypeVar('ArgumentT', bound='Argument') + BlacklistT = TypeVar('BlacklistT', bound=ConverterT) + ParserT = TypeVar('ParserT', bound=Union['_Subparser', 'Parser']) + +_NoneType: Type[None] = type(None) + +__all__ = ( + 'Parser', + 'StringReader', + 'Greedy', + 'Not', + 'ConsumeType', + 'Quotes', + 'Argument', + 'Converter', + 'converter', +) + + +class _NullType: + def __bool__(self) -> bool: + return False + + def __repr__(self) -> str: + return 'NULL' + + __str__ = __repr__ + + +_NULL = _NullType() + + +class ConsumeType(Enum): + """|enum| + + An enumeration of argument consumption types. + + Attributes + ---------- + default + The default and "normal" consumption type. + consume_rest + Consumes the string, including quotes, until the end. + list + Consumes in a similar fashion to the default consumption type, + but will consume like this all the way until the end. + If an error occurs, an error will be raised. + tuple + :attr:`~.ConsumeType.list` except that the result is a tuple, + rather than a list. + greedy + Like :attr:`~.ConsumeType.list`, but it stops consuming + when an argument fails to convert, rather than raising an error. + """ + + default = 'default' + consume_rest = 'consume_rest' + list = 'list' + tuple = 'tuple' + greedy = 'greedy' + + +class Quotes: + """Builtin quote mappings. All attributes are instances of ``Dict[str, str]``. + + Attributes + ---------- + default + The default quote mapping. + + ``"`` -> ``"`` + ``'`` -> ``'`` + extended + An extended quote mapping that supports + quotes from other languages/locales. + + ``'"'`` -> ``'"'`` + ``'`` -> ``'`` + ``'\u2018'`` -> ``'\u2019'`` + ``'\u201a'`` -> ``'\u201b'`` + ``'\u201c'`` -> ``'\u201d'`` + ``'\u201e'`` -> ``'\u201f'`` + ``'\u2e42'`` -> ``'\u2e42'`` + ``'\u300c'`` -> ``'\u300d'`` + ``'\u300e'`` -> ``'\u300f'`` + ``'\u301d'`` -> ``'\u301e'`` + ``'\ufe41'`` -> ``'\ufe42'`` + ``'\ufe43'`` -> ``'\ufe44'`` + ``'\uff02'`` -> ``'\uff02'`` + ``'\uff62'`` -> ``'\uff63'`` + ``'\xab'`` -> ``'\xbb'`` + ``'\u2039'`` -> ``'\u203a'`` + ``'\u300a'`` -> ``'\u300b'`` + ``'\u3008'`` -> ``'\u3009'`` + """ + + default = { + '"': '"', + "'": "'" + } + + extended = { + '"': '"', + "'": "'", + "\u2018": "\u2019", + "\u201a": "\u201b", + "\u201c": "\u201d", + "\u201e": "\u201f", + "\u2e42": "\u2e42", + "\u300c": "\u300d", + "\u300e": "\u300f", + "\u301d": "\u301e", + "\ufe41": "\ufe42", + "\ufe43": "\ufe44", + "\uff02": "\uff02", + "\uff62": "\uff63", + "\u2039": "\u203a", + "\u300a": "\u300b", + "\u3008": "\u3009", + "\xab": "\xbb", + } + + +class Greedy(Generic[GreedyT]): + """Represents a generic type annotation that sets the + consume-type of the argument to :attr:`~.ConsumeType.greedy`. + + Examples + -------- + .. code:: python3 + + @command() + async def massban(ctx, members: Greedy[Member], *, reason=None): + \"""Bans multiple people at once.\""" + for member in members: + await member.ban(reason=reason) + """ + + +class Not(Generic[NotT]): + """A special generic type annotation that fails conversion + if the given converter converts successfully. + + Examples + -------- + .. code:: python3 + + @command() + async def purchase(ctx, item: Not[int], quantity: int): + \"""Purchases an item.\""" + + @purchase.overload + async def purchase(ctx, quantity: int, *, item: str): + return await purchase.callback(ctx, item, quantity) + """ + + +class Converter(Generic[ConverterOutputT], ABC): + """A class that aids in making class-based converters.""" + + __is_converter__: bool = True + + @abstractmethod + async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: + """|coro| + + The core conversion of the argument. + This must be implemented, or :exc:`NotImplementedError` will be raised. + + Parameters + ---------- + ctx: :class:`~.Context` + The command context being parsed. + Note that some attributes could be ``None``. + argument: str + The argument to convert. + + Returns + ------- + Any + """ + raise NotImplementedError + + # noinspection PyUnusedLocal + async def validate(self, ctx: Context, argument: str) -> bool: + """|coro| + + The argument validation check to use. + This will be called before convert and raise a :exc:`~.ValidationError` + if it fails. + + This exists to encourage cleaner code. + + Parameters + ---------- + ctx: :class:`~.Context` + The command context being parsed. + Note that some attributes could be ``None``. + argument: str + The argument to validate. + + Returns + ------- + bool + """ + return True + + +class LiteralConverter(Converter[LiteralT]): + def __init__(self, *choices: LiteralT) -> None: + self._valid: Tuple[LiteralT, ...] = choices + + def __repr__(self) -> None: + return f'<{self.__class__.__name__} valid={self._valid!r}>' + + async def convert(self, ctx: Context, argument: str) -> LiteralT: + for possible in self._valid: + p_type = type(possible) + try: + casted = await _convert_one(ctx, argument, p_type) + except ConversionError: + continue + else: + if casted not in self._valid: + raise BadLiteralArgument(argument, self._valid) + + return casted + + raise BadLiteralArgument(argument, self._valid) + + +class _Not(Converter[str]): + def __init__(self, *entities: BlacklistT) -> None: + self._blacklist: Tuple[BlacklistT, ...] = entities + + def __repr__(self) -> None: + return f'<{self.__class__.__name__} blacklist={self._blacklist!r}>' + + async def convert(self, ctx: Context, argument: str) -> str: + for entity in self._blacklist: + try: + await _convert_one(ctx, argument, entity) + except ConversionError: + return argument + else: + raise BlacklistedArgument(argument, entity) + + +_SC = 'Tuple[Tuple[ConverterT], Optional[bool], Optional[ConsumeType]]' + + +def _sanitize_converter(converter: ConverterT, /, optional: bool = None, consume: ConsumeType = None) -> _SC: + origin = getattr(converter, '__origin__', False) + args = getattr(converter, '__args__', False) + + if origin and args: + if origin is Union: + if _NoneType in args or None in args: + # This is an optional type. + optional = True + args = tuple(arg for arg in args if arg is not _NoneType) + + return tuple( + chain.from_iterable(_sanitize_converter(arg)[0] for arg in args) + ), optional, consume + + if origin is Literal: + converter = LiteralConverter(*args) + + if origin is List: + consume = ConsumeType.list + return _sanitize_converter(args[0], optional, consume) + + if origin is Greedy: + consume = ConsumeType.greedy + + if origin is Not: + converter = _Not(*args) + + if inspect.isclass(converter) and issubclass(converter, Converter): + converter = converter() + + return (converter,), optional, consume + + +def _convert_bool(argument: str) -> bool: + if isinstance(argument, bool): + return argument + + argument = argument.lower() + if argument in {'true', 't', 'yes', 'y', 'on', 'enable', 'enabled', '1'}: + return True + if argument in {'false', 'f', 'no', 'n', 'off', 'disable', 'disabled', '0'}: + return False + + raise BadBooleanArgument(argument) + + +async def _convert_one(ctx: Context, argument: str, converter: ConverterT) -> ConverterOutputT: + if converter is bool: + return _convert_bool(converter) + + try: + if getattr(converter, '__is_converter__', False): + return await converter.convert(ctx, argument) + + return converter(argument) + + except Exception as exc: + raise ConversionError(exc) + + +async def _convert(ctx: Context, argument: Argument, word: str, converters: Iterable[ConverterT]) -> ConverterOutputT: + errors = [] + for converter in converters: + try: + result = await _convert_one(ctx, word, converter) + + if getattr(converter, '__is_converter__', False) and not await converter.validate(ctx, result): + raise ArgumentValidationError(ctx, argument, word) + + return result + + except ConversionError as exc: + errors.append(exc) + + raise ConversionFailure(ctx, argument, word, *errors) + + +def _prepare(ctx: Context, argument: Argument, word: str) -> str: + callback = argument.prepare + + if callback is None: + return word + + try: + return str(callback(word)) + except Exception as exc: + raise ArgumentPreparationError(ctx, argument, word, exc) + + +class Argument: + """Represents a positional argument. + + Although these are automatically constructed, you can + explicitly construct these and use them as type-annotations + in your command signatures. + + Parameters + ---------- + converter + The converter for this argument. This cannot be mixed with `converters`. + *converters + A list of converters for this argument. This cannot be mixed with `converter`. + name: str + The name of this argument. + signature: str + A custom signature for this argument. Will be auto-generated if not given. + default + The default value of this argument. + optional: bool + Whether or not this argument is optional. + description: str + A description about this argument. + consume_type: :class:`~.ConsumeType` + The consumption type of this argument. + quoted: bool + Whether or not this argument can be pass in with quotes. + quotes: Dict[str, str] + A mapping of start-quotes to end-quotes + of all supported quotes for this argument. + **kwargs + Extra kwargs to pass in for metadata. + + Attributes + ---------- + name: str + The name of this argument. + default + The default value of this argument. + optional: bool + Whether or not this argument is optional. + description: str + A description about this argument. + consume_type: :class:`~.ConsumeType` + The consumption type of this argument. + quoted: bool + Whether or not this argument can be pass in with quotes. + quotes: Dict[str, str] + A mapping of start-quotes to end-quotes + of all supported quotes for this argument. + prepare: Callable[[str], str] + A callable that takes a string, which will + prepare the argument before conversion. + + Examples + -------- + .. code:: python3 + + @command() + async def translate( + ctx, + language: Argument( + Literal['spanish', 'french', 'chinese'], + prepare=str.lower + ), + text: Argument(str, consume_type=ConsumeType.consume_rest) + ): + ... + """ + + if TYPE_CHECKING: + @overload + def __init__( + self, + /, + converter: ConverterT, + *, + name: str = None, + signature: str = None, + default: Any = _NULL, + optional: bool = None, + description: str = None, + consume_type: Union[ConsumeType, str] = ConsumeType.default, + quoted: bool = None, + quotes: Dict[str, str] = None, + prepare: ArgumentPrepareT = None, + **kwargs + ) -> None: + ... + + @overload + def __init__( + self, + /, + *converters: ConverterT, + name: str = None, + alias: str = None, + signature: str = None, + default: Any = _NULL, + optional: bool = None, + description: str = None, + consume_type: Union[ConsumeType, str] = ConsumeType.default, + quoted: bool = None, + quotes: Dict[str, str] = None, + prepare: ArgumentPrepareT = None, + **kwargs + ) -> None: + ... + + def __init__( + self, + /, + *converters: ConverterT, + name: str = None, + signature: str = None, + default: Any = _NULL, + optional: bool = None, + description: str = None, + converter: ConverterT = None, + consume_type: Union[ConsumeType, str] = ConsumeType.default, + quoted: bool = None, + quotes: Dict[str, str] = None, + prepare: ArgumentPrepareT = None, + **kwargs + ) -> None: + actual_converters = str, + + if converters and converter is not None: + raise ValueError('Converter kwarg cannot be used when they are already passed as positional arguments.') + + if len(converters) == 1: + converter = converters[0] + converters = () + + if converters or converter: + if converters: + actual_converters, optional_, consume = _sanitize_converter(converters[0]) + actual_converters += _sanitize_converter(converters[1:])[0] + elif converter: + actual_converters, optional_, consume = _sanitize_converter(converter) + else: + raise ValueError('Parameter mismatch') + + optional = optional if optional is not None else optional_ + consume_type = consume_type if consume_type is not None else consume + + self._param_key: Optional[str] = None + self._param_kwarg_only: bool = False + self._param_var_positional: bool = False + + self.name: str = name + self.description: str = description + self.default: Any = default + self.prepare: ArgumentPrepareT = prepare + + self.consume_type: ConsumeType = ( + consume_type if isinstance(consume_type, ConsumeType) else ConsumeType(consume_type) + ) + + self.optional: bool = optional if optional is not None else False + self.quoted: bool = quoted if quoted is not None else consume_type is not ConsumeType.consume_rest + self.quotes: Dict[str, str] = quotes if quotes is not None else Quotes.default + + self._signature: str = signature + self._converters: Tuple[ConverterT, ...] = actual_converters + + self._kwargs: Dict[str, Any] = kwargs + + def __repr__(self) -> str: + return f'' + + def __hash__(self) -> int: + return hash(id(self)) + + @property + def converters(self) -> List[ConverterT]: + """List[Union[type, :class:`~.Converter`]]: A list of this argument's converters.""" + return list(self._converters) + + @property + def signature(self) -> str: + """str: The signature of this argument.""" + if self._signature is not None: + return self._signature + + start, end = '[]' if self.optional or self.default is not _NULL else '<>' + + suffix = '...' if self.consume_type in ( + ConsumeType.list, ConsumeType.tuple, ConsumeType.greedy + ) else '' + + default = f'={self.default}' if self.default is not _NULL else '' + return start + str(self.name) + default + suffix + end + + @signature.setter + def signature(self, value: str, /) -> str: + self._signature = value + + @classmethod + def _from_parameter(cls: Type[ArgumentT], param: inspect.Parameter, /) -> ArgumentT: + def finalize(argument: ArgumentT) -> ArgumentT: + if param.kind is param.KEYWORD_ONLY: + argument._param_kwarg_only = True + + if param.kind is param.VAR_POSITIONAL: + argument._param_var_positional = True + + argument._param_key = store = param.name + if not argument.name: + argument.name = store + + return argument + + kwargs = {'name': param.name} + + if param.annotation is not param.empty: + if isinstance(param.annotation, cls): + return finalize(param.annotation) + + kwargs['converter'] = param.annotation + + if param.default is not param.empty: + kwargs['default'] = param.default + + if param.kind is param.KEYWORD_ONLY: + kwargs['consume_type'] = ConsumeType.consume_rest + + elif param.kind is param.VAR_POSITIONAL: + kwargs['consume_type'] = ConsumeType.tuple + + return finalize(cls(**kwargs)) + + +class StringReader: + """Helper class to aid with parsing strings. + + Parameters + ---------- + string: str + The string to parse. + quotes: Dict[str, str] + A mapping of start-quotes to end-quotes. + Defaults to :attr:`~.Quotes.default` + """ + + class EOF: + """The pointer has gone past the end of the string.""" + + def __init__(self, string: str, /, *, quotes: Dict[str, str] = None) -> None: + self.quotes: Dict[str, str] = quotes or Quotes.default + self.buffer: str = string + self.index: int = -1 + + def seek(self, index: int, /) -> str: + """Seek to an index in the string. + + Parameters + ---------- + index: int + The index to seek to. + + Returns + ------- + str + """ + self.index = index + return self.current + + @property + def current(self) -> Union[str, Type[EOF]]: + """str: The current character this reader is pointing to.""" + try: + return self.buffer[self.index] + except IndexError: + return self.EOF + + @property + def eof(self) -> bool: + """bool: Whether or not this reader has reached the end of the string.""" + return self.current is self.EOF + + def previous_character(self) -> str: + """Seeks to the previous character in the string.""" + return self.seek(self.index - 1) + + def next_character(self) -> str: + """Seeks to the next character in the string.""" + return self.seek(self.index + 1) + + @property + def rest(self) -> str: + """str: The rest of the characters in the string.""" + result = self.buffer[self.index:] if self.index != -1 else self.buffer + self.index = len(self.buffer) # Force an EOF + return result + + @staticmethod + def _is_whitespace(char: str, /) -> bool: + if char is ...: + return False + return char.isspace() + + def skip_to_word(self) -> None: + """Skips to the beginning of the next word.""" + char = ... + while self._is_whitespace(char): + char = self.next_character() + + def next_word(self, *, skip_first: bool = True) -> str: + """Returns the next word in the string. + + Parameters + ---------- + skip_first: bool + Whether or not to call :meth:`~.StringReader.skip_to_word` + beforehand. Defaults to ``True``. + + Returns + ------- + str + """ + char = ... + buffer = '' + + if skip_first: + self.skip_to_word() + + while not self._is_whitespace(char): + char = self.next_character() + if self.eof: + return buffer + + buffer += char + + buffer = buffer[:-1] + return buffer + + def next_quoted_word(self, *, skip_first: bool = True) -> str: + """Returns the next quoted word in the string. + + Parameters + ---------- + skip_first: bool + Whether or not to call :meth:`~.StringReader.skip_to_word` + beforehand. Defaults to ``True``. + + Returns + ------- + str + """ + if skip_first: + self.skip_to_word() + + first_char = self.next_character() + + if first_char not in self.quotes: + self.previous_character() + return self.next_word(skip_first=False) + + end_quote = self.quotes[first_char] + + char = ... + buffer = '' + + while char != end_quote or self.buffer[self.index - 1] == '\\': + char = self.next_character() + if self.eof: + return buffer + + buffer += char + + self.next_character() + buffer = buffer[:-1] + return buffer + + +class _Subparser: + """Parses one specific overload.""" + + def __init__(self, arguments: List[Argument] = None, *, callback: ParserCallback = None): + self._arguments: List[Argument] = arguments or [] + self.callback: Optional[ParserCallback] = callback + + def __hash__(self) -> int: + return hash(id(self)) + + def add_argument(self, argument: Argument, /) -> None: + self._arguments.append(argument) + + @property + def signature(self) -> str: + """str: The signature (or "usage string") for this overload.""" + return ' '.join(arg.signature for arg in self._arguments) + + @classmethod + def from_function(cls: Type[P], func: ParserCallback, /) -> P: + """Creates a new :class:`~.Parser` from a function.""" + params = list(inspect.signature(func).parameters.values()) + + if len(params) < 1: + raise TypeError('Command callback must have at least one context parameter.') + + self = cls(callback=func) + for param in params[1:]: + self.add_argument(Argument._from_parameter(param)) + + return self + + async def parse(self, text: str, /, ctx: Context) -> Tuple[List[Any], Dict[str, Any]]: + # Return a tuple (args: list, kwargs: dict) + # Execute as callback(ctx, *args, **kwargs) + + args = ctx.args = [] + kwargs = ctx.kwargs = {} + reader = ctx.reader = StringReader(text) + + def append_value(argument: Argument, value: Any) -> None: + if argument._param_kwarg_only: + kwargs[argument._param_key] = value + elif argument._param_var_positional: + args.extend(value) + else: + args.append(value) + + i = 0 + for i, argument in enumerate(self._arguments, start=1): + if reader.eof: + i -= 1 + break + + start = reader.index + + if argument.consume_type not in (ConsumeType.consume_rest, ConsumeType.default): + # Either list, tuple, or greedy + result = [] + + while not reader.eof: + word = reader.next_quoted_word() if argument.quoted else reader.next_word() + word = _prepare(ctx, argument, word) + try: + word = await _convert(ctx, argument, word, argument.converters) + except ConversionError as exc: + if argument.consume_type is not ConsumeType.greedy: + raise exc + break + else: + result.append(word) + + if argument.consume_type is ConsumeType.tuple: + result = tuple(result) + + append_value(argument, result) + continue + + if argument.consume_type is ConsumeType.consume_rest: + word = reader.rest.strip() + else: + word = reader.next_quoted_word() if argument.quoted else reader.next_word() + + word = _prepare(ctx, argument, word) + + try: + word = await _convert(ctx, argument, word, argument.converters) + except ConversionError as exc: + if argument.optional: + default = argument.default if argument.default is not _NULL else None + append_value(argument, default) + reader.seek(start) + continue + + raise exc + else: + append_value(argument, word) + + for argument in self._arguments[i:]: + if argument.default is not _NULL: + append_value(argument, argument.default) + elif argument.optional: + append_value(argument, None) + else: + raise MissingArgumentError(ctx, argument) + + return args, kwargs + + +class Parser: + """The main class that parses arguments.""" + + def __init__(self, *, overloads: List[_Subparser] = None) -> None: + self._overloads: List[_Subparser] = overloads or [] + + @property + def _main_parser(self) -> _Subparser: + if not len(self._overloads): + self._overloads.append(_Subparser()) + + return self._overloads[0] + + @property + def callback(self) -> ParserCallback: + """Callable[[:class:`~.Context`, Any, ...], Any]: The callback of this command.""" + return self._main_parser.callback + + @callback.setter + def callback(self, func: ParserCallback) -> None: + func = ensure_async(func) + overload = _Subparser.from_function(func) + try: + self._overloads[0] = overload + except IndexError: + self._overloads = [overload] + + @property + def signature(self) -> str: + """str: The usage string of this command.""" + return self._main_parser.signature + + @property + def arguments(self) -> List[Argument]: + """List[:class:`~.Argument`]: A list of arguments this command takes.""" + return self._main_parser._arguments + + def overload(self: ParserT, func: ParserCallback, /) -> ParserT: + """Adds a command overload to this function.""" + func = ensure_async(func) + result = _Subparser.from_function(func) + self._overloads.append(result) + return self + + @classmethod + def from_function(cls: Type[ParserT], func: ParserCallback, /) -> ParserT: + func = ensure_async(func) + parser = _Subparser.from_function(func) + return cls(overloads=[parser]) + + async def _parse(self, text: str, /, ctx: Context) -> Tuple[ParserCallback, List[Any], Dict[str, Any]]: + errors = [] + for overload in self._overloads: + try: + return overload.callback, *await overload.parse(text, ctx=ctx) + except Exception as exc: + errors.append(exc) + + # Maybe do something with the other errors? + raise errors[0] + + +def converter(func: Callable[[Context, str], Awaitable[ConverterOutputT]]) -> Type[Converter[ConverterOutputT]]: + """A decorator that helps convert a function into a converter. + + Examples + -------- + .. code:: python3 + + @converter + async def reverse_string(ctx, argument): + return argument[::-1] + + @command('reverse') + async def reverse_command(ctx, *, text: reverse_string): + \"""Reverses a string of text.\""" + await ctx.send(text) + """ + func = ensure_async(func) + + async def convert(_, ctx: Context, argument: str) -> ConverterOutputT: + return await func(ctx, argument) + + return type('FunctionBasedConverter', (Converter,), {'convert': convert}) diff --git a/ferris/types/base.py b/ferris/types/base.py index 523f5f4..14c648e 100644 --- a/ferris/types/base.py +++ b/ferris/types/base.py @@ -1,4 +1,5 @@ -from typing import Protocol, Union, runtime_checkable, Optional +from typing import Protocol, Type, Union, runtime_checkable, Optional +from typing_extensions import TypeAlias __all__ = ('SupportsStr', 'SupportsId', 'Id', 'Snowflake') @@ -8,7 +9,7 @@ def __str__(self) -> str: ... -Snowflake = int +Snowflake: TypeAlias = int @runtime_checkable @@ -18,4 +19,4 @@ class SupportsId(Protocol): id: Snowflake -Id = Optional[Union[SupportsId, Snowflake]] +Id: TypeAlias = Optional[Union[SupportsId, Snowflake]] diff --git a/ferris/utils.py b/ferris/utils.py index c0d0618..0e05cd8 100644 --- a/ferris/utils.py +++ b/ferris/utils.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Callable, Iterable, Optional, TypeVar +from typing import Any, Callable, Iterable, Optional, TypeVar, TYPE_CHECKING, Awaitable, Union, overload from .types import Id, Snowflake @@ -9,8 +9,19 @@ 'get_snowflake_creation_date', 'find', 'dt_to_snowflake', + 'ensure_async', + 'to_error_string', ) + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + P = ParamSpec('P') +else: + P = TypeVar('P') + +R = TypeVar('R') + T = TypeVar('T', covariant=True) # try: @@ -124,3 +135,41 @@ def find(predicate: Callable[[T], Any], iterable: Iterable[T]) -> Optional[T]: if predicate(element): return element return None + +@overload +def ensure_async(func: Callable[P, Awaitable[R]], /) -> Callable[P, Awaitable[R]]: + ... + + +@overload +def ensure_async(func: Callable[P, R], /) -> Callable[P, Awaitable[R]]: + ... + + +def ensure_async(func: Union[Callable[P, Awaitable[R]], Callable[P, R]], /) -> Callable[P, Awaitable[R]]: + """Ensures that the given function is asynchronous. + In other terms, if the function is already async, it will stay the same. + Else, it will be converted into an async function. (Note that it will still be ran synchronously.) + """ + @functools.wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + maybe_coro = func(*args, **kwargs) + + if inspect.isawaitable(maybe_coro): + return await maybe_coro + + return maybe_coro + + return wrapper + + +def to_error_string(exc: Exception, /) -> str: + """Formats the given error into ``{error name}: {error text}``, + e.g. ``ValueError: invalid literal for int() with base 10: 'a'``. + If no error text exists, only the error name will be returned, + e.g. just ``ValueError``. + """ + if str(exc).strip(): + return '{0.__class__.__name__}: {0}'.format(exc) + else: + return exc.__class__.__name__ \ No newline at end of file From 1c7c3f3a9b9fb2055ef29ccc66ceb4b7c3857176 Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Wed, 24 Nov 2021 17:13:20 -0800 Subject: [PATCH 02/22] Fix python3.8 check --- ferris/utils.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/ferris/utils.py b/ferris/utils.py index 0e05cd8..fbd20aa 100644 --- a/ferris/utils.py +++ b/ferris/utils.py @@ -1,5 +1,11 @@ +from __future__ import annotations + +import functools +import sys +import inspect from datetime import datetime -from typing import Any, Callable, Iterable, Optional, TypeVar, TYPE_CHECKING, Awaitable, Union, overload +from typing import (TYPE_CHECKING, Any, Awaitable, Callable, Iterable, + Optional, TypeVar, Union, overload) from .types import Id, Snowflake @@ -24,6 +30,13 @@ T = TypeVar('T', covariant=True) +if sys.version_info[:3] >= (3, 9, 0): + A = Callable[P, Awaitable[R]] # type: ignore + F = Callable[P, R] # type: ignore +else: + A = Callable[[P], Awaitable[R]] # type: ignore + F = Callable[[P], R] # type: ignore + # try: # import orjson @@ -146,7 +159,7 @@ def ensure_async(func: Callable[P, R], /) -> Callable[P, Awaitable[R]]: ... -def ensure_async(func: Union[Callable[P, Awaitable[R]], Callable[P, R]], /) -> Callable[P, Awaitable[R]]: +def ensure_async(func: Union[A, F], /) -> A: """Ensures that the given function is asynchronous. In other terms, if the function is already async, it will stay the same. Else, it will be converted into an async function. (Note that it will still be ran synchronously.) @@ -172,4 +185,4 @@ def to_error_string(exc: Exception, /) -> str: if str(exc).strip(): return '{0.__class__.__name__}: {0}'.format(exc) else: - return exc.__class__.__name__ \ No newline at end of file + return exc.__class__.__name__ From ffefe6963ac0accc0196963e6a2a86ed1cc2475a Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Sun, 5 Dec 2021 15:07:10 -0800 Subject: [PATCH 03/22] Fix a few type errors Co-authored-by jay3332 <40323796+jay3332@users.noreply.github.com> Signed-off-by Cryptex <64497526+Cryptex-github@users.noreply.github.com> --- ferris/plugins/commands/parser.py | 8 +++++--- ferris/types/__init__.py | 1 + ferris/types/plugins/__init__.py | 1 + ferris/types/plugins/commands/__init__.py | 11 +++++++++++ setup.py | 2 +- 5 files changed, 19 insertions(+), 4 deletions(-) create mode 100644 ferris/types/plugins/__init__.py create mode 100644 ferris/types/plugins/commands/__init__.py diff --git a/ferris/plugins/commands/parser.py b/ferris/plugins/commands/parser.py index 6b8ca5d..6c68e3b 100644 --- a/ferris/plugins/commands/parser.py +++ b/ferris/plugins/commands/parser.py @@ -11,6 +11,7 @@ Union, overload) from ...utils import ensure_async + from .errors import * ConverterOutputT = TypeVar('ConverterOutputT') @@ -20,13 +21,14 @@ if TYPE_CHECKING: from .models import Context + from ferris.types import ParserCallbackProto ArgumentPrepareT = Callable[[str], str] ConverterT = Union['Converter', Type['Converter'], Callable[[str], ConverterOutputT]] - ParserCallback = Callable[[Context, Any, ...], Any] + ParserCallback = ParserCallbackProto ArgumentT = TypeVar('ArgumentT', bound='Argument') - BlacklistT = TypeVar('BlacklistT', bound=ConverterT) + BlacklistT = TypeVar('BlacklistT', bound=ConverterT) # type: ignore ParserT = TypeVar('ParserT', bound=Union['_Subparser', 'Parser']) _NoneType: Type[None] = type(None) @@ -761,7 +763,7 @@ def signature(self) -> str: return ' '.join(arg.signature for arg in self._arguments) @classmethod - def from_function(cls: Type[P], func: ParserCallback, /) -> P: + def from_function(cls: Type[_Subparser], func: ParserCallback, /) -> _Subparser: """Creates a new :class:`~.Parser` from a function.""" params = list(inspect.signature(func).parameters.values()) diff --git a/ferris/types/__init__.py b/ferris/types/__init__.py index 4b4c132..215a43e 100644 --- a/ferris/types/__init__.py +++ b/ferris/types/__init__.py @@ -9,6 +9,7 @@ from .role import * from .user import * from .ws import * +from .plugins import * Data = Union[ AuthResponse, diff --git a/ferris/types/plugins/__init__.py b/ferris/types/plugins/__init__.py new file mode 100644 index 0000000..d86b8d2 --- /dev/null +++ b/ferris/types/plugins/__init__.py @@ -0,0 +1 @@ +from commands import * \ No newline at end of file diff --git a/ferris/types/plugins/commands/__init__.py b/ferris/types/plugins/commands/__init__.py new file mode 100644 index 0000000..3e93e18 --- /dev/null +++ b/ferris/types/plugins/commands/__init__.py @@ -0,0 +1,11 @@ +from typing import TYPE_CHECKING, Protocol, Tuple, Any, Dict + +__all__ = ('ParserCallbackProto',) + +if TYPE_CHECKING: + from ferris.plugins.commands import Context + + +class ParserCallbackProto(Protocol): + def __call__(self, ctx: Context, *args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any: + ... diff --git a/setup.py b/setup.py index 919d1b6..569f66d 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ "Issue tracker": "https://github.com/FerrisChat/ferriswheel/issues/new", }, version=version, - packages=["ferris", "ferris.types", "ferris.plugins", "ferris.plugins.commands"], + packages=["ferris", "ferris.types", "ferris.plugins", "ferris.plugins.commands", "ferris.types.plugins", "ferris.types.plugins.commands"], license="MIT", description="An asynchronous Python wrapper around FerrisChat's API", long_description=readme, From 441d88c2a7ccef707ac0d4e757175d01db0368e2 Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Sun, 5 Dec 2021 15:13:06 -0800 Subject: [PATCH 04/22] Fix import error --- ferris/plugins/commands/parser.py | 2 +- ferris/types/plugins/__init__.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ferris/plugins/commands/parser.py b/ferris/plugins/commands/parser.py index 6c68e3b..1ac9f2a 100644 --- a/ferris/plugins/commands/parser.py +++ b/ferris/plugins/commands/parser.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from .models import Context - from ferris.types import ParserCallbackProto + from ferris.types.plugins.commands import ParserCallbackProto ArgumentPrepareT = Callable[[str], str] ConverterT = Union['Converter', Type['Converter'], Callable[[str], ConverterOutputT]] diff --git a/ferris/types/plugins/__init__.py b/ferris/types/plugins/__init__.py index d86b8d2..e255a6d 100644 --- a/ferris/types/plugins/__init__.py +++ b/ferris/types/plugins/__init__.py @@ -1 +1,3 @@ -from commands import * \ No newline at end of file +from . import commands + +__all__ = ('commands') \ No newline at end of file From 0bf4648acf26e0b518a0f2987cce7be06291daa9 Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Sun, 5 Dec 2021 15:14:54 -0800 Subject: [PATCH 05/22] Import annotations --- ferris/types/plugins/commands/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ferris/types/plugins/commands/__init__.py b/ferris/types/plugins/commands/__init__.py index 3e93e18..0c1a534 100644 --- a/ferris/types/plugins/commands/__init__.py +++ b/ferris/types/plugins/commands/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING, Protocol, Tuple, Any, Dict __all__ = ('ParserCallbackProto',) From 8a84cc92178ad7698190092655ffc2ba2034b307 Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Sun, 5 Dec 2021 15:16:29 -0800 Subject: [PATCH 06/22] a comma aaa --- ferris/types/plugins/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ferris/types/plugins/__init__.py b/ferris/types/plugins/__init__.py index e255a6d..7042230 100644 --- a/ferris/types/plugins/__init__.py +++ b/ferris/types/plugins/__init__.py @@ -1,3 +1,3 @@ from . import commands -__all__ = ('commands') \ No newline at end of file +__all__ = ('commands',) \ No newline at end of file From e61c49e49ab8a858cd685191251b1d8e97f5b3fe Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Sun, 5 Dec 2021 15:19:59 -0800 Subject: [PATCH 07/22] Reexport plugins in top level --- ferris/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ferris/__init__.py b/ferris/__init__.py index 14e70c6..a50480e 100644 --- a/ferris/__init__.py +++ b/ferris/__init__.py @@ -19,6 +19,7 @@ from .role import * from .user import * from .utils import * +from .plugins import * def create_user(username: str, password: str, email: str) -> PartialUser: From 10855ab4c920972b89dfad60661b687f50d5cf87 Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Sun, 5 Dec 2021 15:29:32 -0800 Subject: [PATCH 08/22] Can I make a commit without breaking things --- ferris/plugins/commands/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ferris/plugins/commands/__init__.py b/ferris/plugins/commands/__init__.py index ede9dd9..ab93b40 100644 --- a/ferris/plugins/commands/__init__.py +++ b/ferris/plugins/commands/__init__.py @@ -1,4 +1,4 @@ -from . import errors, parser, utils +from . import errors, parser from .core import CaseInsensitiveDict, CommandSink, Bot from .errors import * from .models import Command, Context From dff322b8335b0e2e805b4f9eba39a66a0c1dd97a Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Sun, 5 Dec 2021 16:57:02 -0800 Subject: [PATCH 09/22] Add alias --- ferris/plugins/commands/models.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/ferris/plugins/commands/models.py b/ferris/plugins/commands/models.py index b0aaa06..14c6cd9 100644 --- a/ferris/plugins/commands/models.py +++ b/ferris/plugins/commands/models.py @@ -1,27 +1,16 @@ from __future__ import annotations import inspect +from typing import (TYPE_CHECKING, Any, Awaitable, Callable, Collection, Dict, + Generic, List, Optional, Tuple, TypeVar) -from typing import ( - Any, - Awaitable, - Callable, - Collection, - Dict, - Generic, - List, - Optional, - Tuple, - TYPE_CHECKING, - TypeVar, -) - -from .parser import Parser from ...message import Message +from .parser import Parser R = TypeVar('R') if TYPE_CHECKING: + from ferris import Channel, Guild, Message, User from typing_extensions import Concatenate, ParamSpec from .core import Bot @@ -248,6 +237,11 @@ def __init__(self, bot: Bot, message: Message) -> None: self.kwargs: Optional[Dict[str, Any]] = None self.reader: Optional[StringReader] = None + self.channel: Optional[Channel] = message.channel + self.send: Optional[Callable[[str], Message]] = getattr(self.channel, 'send', None) + self.author: Optional[User] = message.author + self.guild: Optional[Guild] = message.guild + def __repr__(self) -> str: return f'' From aeb1886a8fb8ba36432a96a416a015a3a514a679 Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Tue, 7 Dec 2021 19:30:55 -0800 Subject: [PATCH 10/22] Add some converters and remove __init__.py in plugins so is a namespace package which mean anyone can install plugins in it --- ferris/__init__.py | 1 - ferris/plugins/__init__.py | 1 - ferris/plugins/commands/converters.py | 66 +++++++++++++++++++++++++++ ferris/plugins/commands/errors.py | 3 +- ferris/plugins/commands/models.py | 19 ++++++-- ferris/plugins/commands/parser.py | 27 +++++++++-- 6 files changed, 108 insertions(+), 9 deletions(-) delete mode 100644 ferris/plugins/__init__.py create mode 100644 ferris/plugins/commands/converters.py diff --git a/ferris/__init__.py b/ferris/__init__.py index a50480e..14e70c6 100644 --- a/ferris/__init__.py +++ b/ferris/__init__.py @@ -19,7 +19,6 @@ from .role import * from .user import * from .utils import * -from .plugins import * def create_user(username: str, password: str, email: str) -> PartialUser: diff --git a/ferris/plugins/__init__.py b/ferris/plugins/__init__.py deleted file mode 100644 index 9f215bd..0000000 --- a/ferris/plugins/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import commands diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py new file mode 100644 index 0000000..72e1887 --- /dev/null +++ b/ferris/plugins/commands/converters.py @@ -0,0 +1,66 @@ +from ferris.errors import NotFound +from .parser import Converter, ConverterOutputT +from .models import Context +from .errors import BadArgument + +class ChannelConverter(Converter): + async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: + c = None + + if argument.isdigit(): + id_ = int(argument) + + c = ctx.bot.get_channel(id_) + + if not c: + try: + c = await ctx.bot.fetch_channel(id_) + except NotFound: + if not c: + raise BadArgument(f'Channel {id_!r} not found') + else: + raise BadArgument(f'Argument must be an id.') + + return c + + +class GuildConverter(Converter): + async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: + g = None + + if argument.isdigit(): + id_ = int(argument) + + g = ctx.bot.get_guild(id_) + + if not g: + try: + g = await ctx.bot.fetch_guild(id_) + except NotFound: + if not g: + raise BadArgument(f'Guild {id_!r} not found') + else: + raise BadArgument(f'Argument must be an id.') + + return g + + +class UserConverter(Converter): + async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: + u = None + + if argument.isdigit(): + id_ = int(argument) + + u = ctx.bot.get_user(id_) + + if not u: + try: + u = await ctx.bot.fetch_user(id_) + except NotFound: + if not u: + raise BadArgument(f'User {id_!r} not found') + else: + raise BadArgument(f'Argument must be an id.') + + return u diff --git a/ferris/plugins/commands/errors.py b/ferris/plugins/commands/errors.py index 37cb877..669fc30 100644 --- a/ferris/plugins/commands/errors.py +++ b/ferris/plugins/commands/errors.py @@ -2,6 +2,7 @@ from typing import Any, Optional, Tuple, TYPE_CHECKING from ...utils import to_error_string +from ferris.errors import FerrisException if TYPE_CHECKING: from .models import Context @@ -23,7 +24,7 @@ ) -class CommandBasedError(Exception): # Should inherit from your lib's base exception +class CommandBasedError(FerrisException): """The base exception raised for errors related to the commands plugin.""" diff --git a/ferris/plugins/commands/models.py b/ferris/plugins/commands/models.py index 14c6cd9..9ad98a4 100644 --- a/ferris/plugins/commands/models.py +++ b/ferris/plugins/commands/models.py @@ -237,10 +237,7 @@ def __init__(self, bot: Bot, message: Message) -> None: self.kwargs: Optional[Dict[str, Any]] = None self.reader: Optional[StringReader] = None - self.channel: Optional[Channel] = message.channel self.send: Optional[Callable[[str], Message]] = getattr(self.channel, 'send', None) - self.author: Optional[User] = message.author - self.guild: Optional[Guild] = message.guild def __repr__(self) -> str: return f'' @@ -275,3 +272,19 @@ async def reinvoke(self) -> None: .. note:: No checks will be called here. """ await self.invoke(self.command, *self.args, **self.kwargs) + + @property + def channel(self) -> Optional[Channel]: + """Optional[:class:`~.Channel`]: The channel that this context was invoked in.""" + return self.message.channel + + @property + def author(self) -> Optional[User]: + """Optional[:class:`~.User`]: The user that this context was invoked by.""" + return self.message.author + + @property + def guild(self) -> Optional[Guild]: + """Optional[:class:`~.Guild`]: The guild that this context was invoked in.""" + return self.message.guild + diff --git a/ferris/plugins/commands/parser.py b/ferris/plugins/commands/parser.py index 1ac9f2a..bc0e4b8 100644 --- a/ferris/plugins/commands/parser.py +++ b/ferris/plugins/commands/parser.py @@ -10,19 +10,22 @@ Iterable, List, Literal, Optional, Tuple, Type, TypeVar, Union, overload) -from ...utils import ensure_async - +from ...utils import P, ensure_async +from .converters import ChannelConverter, GuildConverter, UserConverter from .errors import * +from ferris import Channel, Guild, User + ConverterOutputT = TypeVar('ConverterOutputT') GreedyT = TypeVar('GreedyT') LiteralT = TypeVar('LiteralT') NotT = TypeVar('NotT') if TYPE_CHECKING: - from .models import Context from ferris.types.plugins.commands import ParserCallbackProto + from .models import Context + ArgumentPrepareT = Callable[[str], str] ConverterT = Union['Converter', Type['Converter'], Callable[[str], ConverterOutputT]] ParserCallback = ParserCallbackProto @@ -31,6 +34,12 @@ BlacklistT = TypeVar('BlacklistT', bound=ConverterT) # type: ignore ParserT = TypeVar('ParserT', bound=Union['_Subparser', 'Parser']) +CONVERTERS_MAPPING = { + Channel: ChannelConverter, + Guild: GuildConverter, + User: UserConverter, +} + _NoneType: Type[None] = type(None) __all__ = ( @@ -328,8 +337,20 @@ def _convert_bool(argument: str) -> bool: async def _convert_one(ctx: Context, argument: str, converter: ConverterT) -> ConverterOutputT: if converter is bool: return _convert_bool(converter) + + try: + module = converter.__module__ + except AttributeError: + pass + else: + if module is not None and module.startswith('ferris') and not module.endswith('Converter'): + converter = CONVERTERS_MAPPING.get(converter, converter) try: + if inspect.isclass(converter) and issubclass(converter, Converter): + if not inspect.ismethod(converter.convert): + converter = converter() + if getattr(converter, '__is_converter__', False): return await converter.convert(ctx, argument) From ee88a48dcef627827a27360599283c0869c84fb0 Mon Sep 17 00:00:00 2001 From: DistortedPumpkin Date: Mon, 13 Dec 2021 18:12:59 -0500 Subject: [PATCH 11/22] add remove_command() to CommandSink --- ferris/plugins/commands/core.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/ferris/plugins/commands/core.py b/ferris/plugins/commands/core.py index 1dfcb91..086ba17 100644 --- a/ferris/plugins/commands/core.py +++ b/ferris/plugins/commands/core.py @@ -193,6 +193,32 @@ def commands(self) -> List[Command]: """List[:class:`.Command`]: A list of the commands this sink holds.""" return list(self.walk_commands()) + def remove_command(self, name: str) -> Optional[Command]: + """ + Remove a command or alias from this command sink. + + Parameters + ----------- + name: :class:`str` + The name of the command or alias to remove. + + Returns + -------- + Optional[:class:`~.Command`] + The command that was removed. + If the name is invalid, ``None`` is returned instead. + """ + command = self.command_mapping.pop(name) + + if command: + if name in command.aliases: # This is an alias, so don't remove the command, only this alias. + return command + + for alias in command.aliases: + self.command_mapping.pop(alias) + + return command + class Bot(Client, CommandSink): """Represents a bot with extra command handling support. @@ -293,7 +319,7 @@ async def get_prefix(self, message: Message, *, prefix: BasePrefixT = None) -> O except StopIteration: return None - if isinstance(prefix, callable): + if callable(prefix): prefix = await prefix(self, message) return self.get_prefix(message, prefix=prefix) From 1a73b10a968a07409ab0853f33c6ab4be86bbbba Mon Sep 17 00:00:00 2001 From: DistortedPumpkin Date: Tue, 14 Dec 2021 21:15:20 -0500 Subject: [PATCH 12/22] Add converters: MemberConverter, MessageConverter, RoleConverter, InviteConverter --- ferris/plugins/commands/converters.py | 81 +++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py index 72e1887..b16307e 100644 --- a/ferris/plugins/commands/converters.py +++ b/ferris/plugins/commands/converters.py @@ -2,6 +2,9 @@ from .parser import Converter, ConverterOutputT from .models import Context from .errors import BadArgument +from ferris.utils import find + +import re class ChannelConverter(Converter): async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: @@ -64,3 +67,81 @@ async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: raise BadArgument(f'Argument must be an id.') return u + + +class MemberConverter(Converter): + async def convert(self, ctx: Context, argument: str): + m = None + + if argument.isdigit(): + id_ = int(argument) + + m = ctx.guild.get_member(id_) + + if not m: + try: + m = await ctx.guild.fetch_member(id_) + except NotFound: + raise BadArgument(f'Member {id_!r} not found') + else: + raise BadArgument('Argument must be an id') + + return m + + +class MessageConverter(Converter): + async def convert(self, ctx: Context, argument: str): + m = None + + if argument.isdigit(): + id_ = int(argument) + + m = ctx.bot.get_message(id_) + + if not m: + try: + m = await ctx.bot.fetch_message(id_) + except NotFound: + raise BadArgument(f'Message {id_!r} not found') + else: # TODO: Convert from message url after webclient rewrite + raise BadArgument('Argument must be an id') + + return m + + +class RoleConverter(Converter): + async def convert(self, ctx: Context, argument: str): + r = None + + if argument.isdigit(): + id_ = int(argument) + + r = ctx.guild.get_role(id_) + + if not r: + try: + r = await ctx.guild.fetch_role(id_) + except NotFound: + raise BadArgument(f'Role {id_!r} not found') + else: + r = find(lambda r: r.name == argument, ctx.guild.roles) + + if not r: + raise BadArgument('Argument must be role id or name') + + return r + + +class InviteConverter(Converter): + async def convert(self, ctx: Context, argument: str): + invite_regex = re.compile(r'(https?:\/\/)?(www\.)?(ferris\.sh|ferris\.chat\/invite)\/([A-Za-z0-9]+)') + match = invite_regex.match(argument) + + if match: + argument = match.group(4) + + try: + i = await ctx.bot.fetch_invite(argument) + return i + except NotFound: + raise BadArgument('Argument passed must be a valid invite url or code') \ No newline at end of file From d1971b6b5f6294d4992081b4ea2492f1671c0248 Mon Sep 17 00:00:00 2001 From: DistortedPumpkin Date: Tue, 14 Dec 2021 21:18:28 -0500 Subject: [PATCH 13/22] Conformed --- ferris/plugins/commands/converters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py index b16307e..a3a16fa 100644 --- a/ferris/plugins/commands/converters.py +++ b/ferris/plugins/commands/converters.py @@ -84,7 +84,7 @@ async def convert(self, ctx: Context, argument: str): except NotFound: raise BadArgument(f'Member {id_!r} not found') else: - raise BadArgument('Argument must be an id') + raise BadArgument('Argument must be an id.') return m @@ -104,7 +104,7 @@ async def convert(self, ctx: Context, argument: str): except NotFound: raise BadArgument(f'Message {id_!r} not found') else: # TODO: Convert from message url after webclient rewrite - raise BadArgument('Argument must be an id') + raise BadArgument('Argument must be an id.') return m @@ -127,7 +127,7 @@ async def convert(self, ctx: Context, argument: str): r = find(lambda r: r.name == argument, ctx.guild.roles) if not r: - raise BadArgument('Argument must be role id or name') + raise BadArgument('Argument must be role id or name.') return r @@ -144,4 +144,4 @@ async def convert(self, ctx: Context, argument: str): i = await ctx.bot.fetch_invite(argument) return i except NotFound: - raise BadArgument('Argument passed must be a valid invite url or code') \ No newline at end of file + raise BadArgument('Argument passed must be a valid invite url or code.') From bde914257f8c8db74d1f103d16514ff8791aae8b Mon Sep 17 00:00:00 2001 From: DistortedPumpkin Date: Tue, 14 Dec 2021 21:32:51 -0500 Subject: [PATCH 14/22] Type-hint output of convert() method --- ferris/plugins/commands/converters.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py index a3a16fa..b44b822 100644 --- a/ferris/plugins/commands/converters.py +++ b/ferris/plugins/commands/converters.py @@ -70,7 +70,7 @@ async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: class MemberConverter(Converter): - async def convert(self, ctx: Context, argument: str): + async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: m = None if argument.isdigit(): @@ -90,7 +90,7 @@ async def convert(self, ctx: Context, argument: str): class MessageConverter(Converter): - async def convert(self, ctx: Context, argument: str): + async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: m = None if argument.isdigit(): @@ -110,7 +110,7 @@ async def convert(self, ctx: Context, argument: str): class RoleConverter(Converter): - async def convert(self, ctx: Context, argument: str): + async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: r = None if argument.isdigit(): @@ -133,7 +133,7 @@ async def convert(self, ctx: Context, argument: str): class InviteConverter(Converter): - async def convert(self, ctx: Context, argument: str): + async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: invite_regex = re.compile(r'(https?:\/\/)?(www\.)?(ferris\.sh|ferris\.chat\/invite)\/([A-Za-z0-9]+)') match = invite_regex.match(argument) @@ -144,4 +144,4 @@ async def convert(self, ctx: Context, argument: str): i = await ctx.bot.fetch_invite(argument) return i except NotFound: - raise BadArgument('Argument passed must be a valid invite url or code.') + raise BadArgument('Argument must be a valid invite url or code.') From fca035cab026d7e856de9852540b593f23c0a3fa Mon Sep 17 00:00:00 2001 From: DistortedPumpkin Date: Tue, 14 Dec 2021 21:53:22 -0500 Subject: [PATCH 15/22] Fix types --- ferris/plugins/commands/converters.py | 35 ++++++++++++++------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py index b44b822..795ce7a 100644 --- a/ferris/plugins/commands/converters.py +++ b/ferris/plugins/commands/converters.py @@ -1,13 +1,14 @@ -from ferris.errors import NotFound -from .parser import Converter, ConverterOutputT +from ...errors import NotFound +from ...utils import find +from ... import Channel, Guild, User, Member, Message, Role, Invite +from .parser import Converter from .models import Context from .errors import BadArgument -from ferris.utils import find import re -class ChannelConverter(Converter): - async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: +class ChannelConverter(Converter[Channel]): + async def convert(self, ctx: Context, argument: str): c = None if argument.isdigit(): @@ -27,8 +28,8 @@ async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: return c -class GuildConverter(Converter): - async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: +class GuildConverter(Converter[Guild]): + async def convert(self, ctx: Context, argument: str): g = None if argument.isdigit(): @@ -48,8 +49,8 @@ async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: return g -class UserConverter(Converter): - async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: +class UserConverter(Converter[User]): + async def convert(self, ctx: Context, argument: str): u = None if argument.isdigit(): @@ -69,8 +70,8 @@ async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: return u -class MemberConverter(Converter): - async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: +class MemberConverter(Converter[Member]): + async def convert(self, ctx: Context, argument: str): m = None if argument.isdigit(): @@ -89,8 +90,8 @@ async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: return m -class MessageConverter(Converter): - async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: +class MessageConverter(Converter[Message]): + async def convert(self, ctx: Context, argument: str): m = None if argument.isdigit(): @@ -109,8 +110,8 @@ async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: return m -class RoleConverter(Converter): - async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: +class RoleConverter(Converter[Role]): + async def convert(self, ctx: Context, argument: str): r = None if argument.isdigit(): @@ -132,8 +133,8 @@ async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: return r -class InviteConverter(Converter): - async def convert(self, ctx: Context, argument: str) -> ConverterOutputT: +class InviteConverter(Converter[Invite]): + async def convert(self, ctx: Context, argument: str): invite_regex = re.compile(r'(https?:\/\/)?(www\.)?(ferris\.sh|ferris\.chat\/invite)\/([A-Za-z0-9]+)') match = invite_regex.match(argument) From 072634f2e6703b6dea41cc8c73ad820566d15f8e Mon Sep 17 00:00:00 2001 From: DistortedPumpkin Date: Tue, 14 Dec 2021 22:05:28 -0500 Subject: [PATCH 16/22] Prevent circular imports --- ferris/plugins/commands/converters.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py index 795ce7a..241ad4a 100644 --- a/ferris/plugins/commands/converters.py +++ b/ferris/plugins/commands/converters.py @@ -1,9 +1,12 @@ from ...errors import NotFound from ...utils import find -from ... import Channel, Guild, User, Member, Message, Role, Invite -from .parser import Converter -from .models import Context from .errors import BadArgument +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ... import Channel, Guild, User, Member, Message, Role, Invite + from .parser import Converter + from .models import Context import re From ecfd09121b4957dd9a7cbab27e50473d91b2969e Mon Sep 17 00:00:00 2001 From: DistortedPumpkin Date: Wed, 15 Dec 2021 10:23:36 -0500 Subject: [PATCH 17/22] Add future import to make TYPE_CHECKING guard work. Order imports --- ferris/plugins/commands/converters.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py index 241ad4a..1027ecf 100644 --- a/ferris/plugins/commands/converters.py +++ b/ferris/plugins/commands/converters.py @@ -1,14 +1,17 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +from .errors import BadArgument from ...errors import NotFound from ...utils import find -from .errors import BadArgument -from typing import TYPE_CHECKING + if TYPE_CHECKING: from ... import Channel, Guild, User, Member, Message, Role, Invite - from .parser import Converter from .models import Context - -import re + from .parser import Converter class ChannelConverter(Converter[Channel]): async def convert(self, ctx: Context, argument: str): From d2a540c63e692f2b3f7c869ca4f64f145fc5c6ac Mon Sep 17 00:00:00 2001 From: DistortedPumpkin Date: Thu, 16 Dec 2021 20:37:43 -0500 Subject: [PATCH 18/22] Make INVITE_REGEX a constant in utils.py as requested by @ Cryptex-github --- ferris/plugins/commands/converters.py | 6 ++---- ferris/utils.py | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py index 1027ecf..5985fde 100644 --- a/ferris/plugins/commands/converters.py +++ b/ferris/plugins/commands/converters.py @@ -1,11 +1,10 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING from .errors import BadArgument from ...errors import NotFound -from ...utils import find +from ...utils import find, INVITE_REGEX if TYPE_CHECKING: @@ -141,8 +140,7 @@ async def convert(self, ctx: Context, argument: str): class InviteConverter(Converter[Invite]): async def convert(self, ctx: Context, argument: str): - invite_regex = re.compile(r'(https?:\/\/)?(www\.)?(ferris\.sh|ferris\.chat\/invite)\/([A-Za-z0-9]+)') - match = invite_regex.match(argument) + match = INVITE_REGEX.match(argument) if match: argument = match.group(4) diff --git a/ferris/utils.py b/ferris/utils.py index fbd20aa..78fff68 100644 --- a/ferris/utils.py +++ b/ferris/utils.py @@ -3,6 +3,8 @@ import functools import sys import inspect +import json +import re from datetime import datetime from typing import (TYPE_CHECKING, Any, Awaitable, Callable, Iterable, Optional, TypeVar, Union, overload) @@ -45,15 +47,13 @@ # import json -import json - HAS_ORJSON = False FERRIS_EPOCH_MS: int = 1_577_836_800_000 - FERRIS_EPOCH: int = 1_577_836_800 +INVITE_REGEX: re.Pattern = re.compile(r'(https?:\/\/)?(www\.)?(ferris\.sh|ferris\.chat\/invite)\/([A-Za-z0-9]+)') if HAS_ORJSON: From 12705143f3d2a6bd8be8aa7710982284171ee4fb Mon Sep 17 00:00:00 2001 From: DistortedPumpkin Date: Thu, 16 Dec 2021 20:41:44 -0500 Subject: [PATCH 19/22] sort imports --- ferris/plugins/commands/converters.py | 7 +++---- ferris/utils.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py index 5985fde..6b95fb7 100644 --- a/ferris/plugins/commands/converters.py +++ b/ferris/plugins/commands/converters.py @@ -2,13 +2,12 @@ from typing import TYPE_CHECKING -from .errors import BadArgument from ...errors import NotFound -from ...utils import find, INVITE_REGEX - +from ...utils import INVITE_REGEX, find +from .errors import BadArgument if TYPE_CHECKING: - from ... import Channel, Guild, User, Member, Message, Role, Invite + from ... import Channel, Guild, Invite, Member, Message, Role, User from .models import Context from .parser import Converter diff --git a/ferris/utils.py b/ferris/utils.py index 78fff68..18b20e3 100644 --- a/ferris/utils.py +++ b/ferris/utils.py @@ -1,10 +1,10 @@ from __future__ import annotations import functools -import sys import inspect import json import re +import sys from datetime import datetime from typing import (TYPE_CHECKING, Any, Awaitable, Callable, Iterable, Optional, TypeVar, Union, overload) From 7ef8677d8491f841954e5a83347d77f54770f59a Mon Sep 17 00:00:00 2001 From: DistortedPumpkin Date: Thu, 16 Dec 2021 20:54:16 -0500 Subject: [PATCH 20/22] The oh so great maintainers asked me to walrus that --- ferris/plugins/commands/converters.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py index 6b95fb7..c6974a0 100644 --- a/ferris/plugins/commands/converters.py +++ b/ferris/plugins/commands/converters.py @@ -139,13 +139,10 @@ async def convert(self, ctx: Context, argument: str): class InviteConverter(Converter[Invite]): async def convert(self, ctx: Context, argument: str): - match = INVITE_REGEX.match(argument) - - if match: + if match := INVITE_REGEX.match(argument): argument = match.group(4) - + try: - i = await ctx.bot.fetch_invite(argument) - return i + return await ctx.bot.fetch_invite(argument) except NotFound: raise BadArgument('Argument must be a valid invite url or code.') From 7527a53cfe072c498e2aba7345212d5b3c6509b0 Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Thu, 16 Dec 2021 20:14:41 -0800 Subject: [PATCH 21/22] Add `__all__` --- ferris/plugins/commands/converters.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ferris/plugins/commands/converters.py b/ferris/plugins/commands/converters.py index c6974a0..452a447 100644 --- a/ferris/plugins/commands/converters.py +++ b/ferris/plugins/commands/converters.py @@ -6,11 +6,21 @@ from ...utils import INVITE_REGEX, find from .errors import BadArgument + +__all__ = ('ChannelConverter', + 'GuildConverter', + 'UserConverter', + 'MemberConverter', + 'MessageConverter', + 'RoleConverter', + 'InviteConverter') + if TYPE_CHECKING: from ... import Channel, Guild, Invite, Member, Message, Role, User from .models import Context from .parser import Converter + class ChannelConverter(Converter[Channel]): async def convert(self, ctx: Context, argument: str): c = None From f37a238aef4964c23585a9fe42db8caeb6f550c6 Mon Sep 17 00:00:00 2001 From: Cryptex <64497526+Cryptex-github@users.noreply.github.com> Date: Thu, 16 Dec 2021 20:15:54 -0800 Subject: [PATCH 22/22] Reexport --- ferris/plugins/commands/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ferris/plugins/commands/__init__.py b/ferris/plugins/commands/__init__.py index ab93b40..a671124 100644 --- a/ferris/plugins/commands/__init__.py +++ b/ferris/plugins/commands/__init__.py @@ -1,5 +1,6 @@ from . import errors, parser from .core import CaseInsensitiveDict, CommandSink, Bot +from .converters import * from .errors import * from .models import Command, Context from .parser import (