diff --git a/CHANGELOG.md b/CHANGELOG.md index 57a43b923..d9ff96516 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -116,6 +116,7 @@ The table below shows which release corresponds to each branch, and what date th - [#2647][2647] packing: Add `overlap` to overlap structures easily - [#2669][2669] asm: try native binutils before fallback architectures - [#2673][2673] Add libc module for libc-related functions +- [#2679][2679] Add type hints to parts of pwnlib.utils - [#2680][2680] Cleanup Python 2 legacy - [#2687][2687] Add (un)pack shorthands for 40-56 bit numbers `u48()`/`p48()` - [#2699][2699] Fix `tty` and `raw` arguments in `ssh.process()` @@ -169,6 +170,7 @@ The table below shows which release corresponds to each branch, and what date th [2647]: https://github.com/Gallopsled/pwntools/pull/2647 [2669]: https://github.com/Gallopsled/pwntools/pull/2669 [2673]: https://github.com/Gallopsled/pwntools/pull/2673 +[2679]: https://github.com/Gallopsled/pwntools/pull/2679 [2680]: https://github.com/Gallopsled/pwntools/pull/2680 [2687]: https://github.com/Gallopsled/pwntools/pull/2687 [2699]: https://github.com/Gallopsled/pwntools/pull/2699 diff --git a/pwnlib/util/crc/known.py b/pwnlib/util/crc/known.py index 35316b7ff..9d147de85 100644 --- a/pwnlib/util/crc/known.py +++ b/pwnlib/util/crc/known.py @@ -1,8 +1,9 @@ import os import re +from typing import Any -def generate(): +def generate() -> dict[str, dict]: """Generates a dictionary of all the known CRC formats from: https://reveng.sourceforge.io/crc-catalogue/all.htm @@ -15,7 +16,7 @@ def generate(): data = fd.read() out = {} - def fixup(s): + def fixup(s: str) -> Any | int | bool: if s == 'true': return True elif s == 'false': diff --git a/pwnlib/util/fiddling.py b/pwnlib/util/fiddling.py index cf8394769..38ec16fd0 100644 --- a/pwnlib/util/fiddling.py +++ b/pwnlib/util/fiddling.py @@ -6,6 +6,8 @@ import string from io import BytesIO +from typing import Any, Generator, Iterable, Optional, BinaryIO +from collections.abc import Sequence from pwnlib.context import LocalNoarchContext from pwnlib.context import context @@ -20,7 +22,7 @@ log = getLogger(__name__) -def unhex(s): +def unhex(s: Sequence) -> bytes: r"""unhex(s) -> str Hex-decodes a string. @@ -42,7 +44,7 @@ def unhex(s): s = '0' + s return binascii.unhexlify(s) -def enhex(x): +def enhex(x: Sequence) -> str: """enhex(x) -> str Hex-encodes a string. @@ -58,7 +60,7 @@ def enhex(x): return x -def hexstr(s, force=False): +def hexstr(s: bytes, force: bool = False) -> str: r""" hexstr(x, force=False) -> str @@ -82,7 +84,7 @@ def hexstr(s, force=False): return out.decode() -def urlencode(s): +def urlencode(s: str) -> str: """urlencode(s) -> str URL-encodes a string. @@ -94,7 +96,7 @@ def urlencode(s): """ return ''.join(['%%%02x' % ord(c) for c in s]) -def urldecode(s, ignore_invalid = False): +def urldecode(s: str, ignore_invalid: bool = False) -> str: """urldecode(s, ignore_invalid = False) -> str URL-decodes a string. @@ -128,13 +130,13 @@ def urldecode(s, ignore_invalid = False): raise ValueError("Invalid input to urldecode") return res -def bits(s, endian = 'big', zero = 0, one = 1): +def bits(s: int | bytes, endian: str = 'big', zero: str = 0, one: str = 1) -> list[int | str]: """bits(s, endian = 'big', zero = 0, one = 1) -> list Converts the argument into a list of bits. Arguments: - s: A string or number to be converted into bits. + s: A bytestring or number to be converted into bits. endian (str): The binary endian, default 'big'. zero: The representing a 0-bit. one: The representing a 1-bit. @@ -185,7 +187,7 @@ def bits(s, endian = 'big', zero = 0, one = 1): return out -def bits_str(s, endian = 'big', zero = '0', one = '1'): +def bits_str(s: Sequence, endian: str = 'big', zero: str = '0', one: str = '1') -> str: """bits_str(s, endian = 'big', zero = '0', one = '1') -> str A wrapper around :func:`bits`, which converts the output into a string. @@ -199,7 +201,7 @@ def bits_str(s, endian = 'big', zero = '0', one = '1'): """ return ''.join(bits(s, endian, zero, one)) -def unbits(s, endian = 'big'): +def unbits(s: Iterable, endian: str = 'big') -> str: r"""unbits(s, endian = 'big') -> str Converts an iterable of bits into a string. @@ -247,7 +249,7 @@ def unbits(s, endian = 'big'): return out -def bitswap(s): +def bitswap(s: bytes) -> bytes: r"""bitswap(s) -> str Reverses the bits in every byte of a given string. @@ -265,7 +267,7 @@ def bitswap(s): return b''.join(out) -def bitswap_int(n, width): +def bitswap_int(n: int, width: int) -> str: """bitswap_int(n) -> int Reverses the bits of a numbers and returns the result as a new number. @@ -295,7 +297,7 @@ def bitswap_int(n, width): return int(s, 2) -def b64e(s): +def b64e(s: bytes) -> str: """b64e(s) -> str Base64 encodes a string @@ -310,7 +312,7 @@ def b64e(s): x = x.decode('ascii') return x -def b64d(s): +def b64d(s: str) -> bytes: """b64d(s) -> str Base64 decodes a string @@ -323,7 +325,7 @@ def b64d(s): return base64.b64decode(s) # misc binary functions -def xor(*args, **kwargs): +def xor(*args: tuple, **kwargs: dict[str, Any]) -> bytes: """xor(*args, cut = 'max') -> str Flattens its arguments using :func:`pwnlib.util.packing.flat` and @@ -375,15 +377,15 @@ def xor(*args, **kwargs): else: raise ValueError("Not a valid argument for 'cut'") - def get(n): + def get(n: int) -> bytes: rv = 0 for s in strs: rv ^= s[n%len(s)] return packing._p8lu(rv) return b''.join(map(get, range(cut))) -def xor_pair(data, avoid = b'\x00\n'): - r"""xor_pair(data, avoid = '\x00\n') -> None or (str, str) +def xor_pair(data: int | Sequence, avoid: bytes = b'\x00\n') -> Optional[tuple[str, str]]: + r"""xor_pair(data, avoid = '\\x00\\n') -> None or (str, str) Finds two strings that will xor into a given string, while only using a given alphabet. @@ -427,7 +429,7 @@ def xor_pair(data, avoid = b'\x00\n'): return res1, res2 -def xor_key(data, avoid=b'\x00\n', size=None): +def xor_key(data: str, avoid: bytes = b'\x00\n', size: int = None) -> Optional[tuple[str, str]]: r"""xor_key(data, size=None, avoid='\x00\n') -> None or (int, str) Finds a ``size``-width value that can be XORed with a string @@ -476,7 +478,7 @@ def xor_key(data, avoid=b'\x00\n', size=None): return result, xor(data, result) -def randoms(count, alphabet = string.ascii_lowercase): +def randoms(count: int, alphabet: str = string.ascii_lowercase) -> str: """randoms(count, alphabet = string.ascii_lowercase) -> str Returns a random string of a given length using only the specified alphabet. @@ -497,7 +499,7 @@ def randoms(count, alphabet = string.ascii_lowercase): return ''.join(random.choice(alphabet) for _ in range(count)) -def rol(n, k, word_size = None): +def rol(n: Sequence | int, k: int, word_size: int = None) -> str: """Returns a rotation by `k` of `n`. When `n` is a number, then means ``((n << k) | (n >> (word_size - k)))`` truncated to `word_size` bits. @@ -540,12 +542,12 @@ def rol(n, k, word_size = None): else: raise ValueError("rol(): 'n' must be an integer, string, list or tuple") -def ror(n, k, word_size = None): +def ror(n: Sequence | int, k: int, word_size: int = None) -> str: """A simple wrapper around :func:`rol`, which negates the values of `k`.""" return rol(n, -k, word_size) -def naf(n): +def naf(n: int) -> Generator[int, None, None]: """naf(int) -> int generator Returns a generator for the non-adjacent form (NAF[1]) of a number, `n`. If @@ -571,7 +573,7 @@ def naf(n): n = (n - z) // 2 yield z -def isprint(c): +def isprint(c: str | int) -> bool: """isprint(c) -> bool Return True if a character is printable""" @@ -581,7 +583,7 @@ def isprint(c): return c in t -def hexii(s, width = 16, skip = True): +def hexii(s: str, width: int = 16, skip: bool = True) -> str: """hexii(s, width = 16, skip = True) -> str Return a HEXII-dump of a string. @@ -597,7 +599,7 @@ def hexii(s, width = 16, skip = True): return hexdump(s, width, skip, True) -def _hexiichar(c): +def _hexiichar(c: int) -> str: HEXII = bytearray((string.punctuation + string.digits + string.ascii_letters).encode()) if c in HEXII: return ".%c " % c @@ -619,16 +621,16 @@ def _hexiichar(c): cyclic_pregen = b'' de_bruijn_gen = de_bruijn() -def sequential_lines(a,b): +def sequential_lines(a: int, b: int) -> bool: return (a+b) in cyclic_pregen -def update_cyclic_pregenerated(size): +def update_cyclic_pregenerated(size: int) -> None: global cyclic_pregen while size > len(cyclic_pregen): cyclic_pregen += packing._p8lu(next(de_bruijn_gen)) -def hexdump_iter(fd, width=16, skip=True, hexii=False, begin=0, style=None, - highlight=None, cyclic=False, groupsize=4, total=True): +def hexdump_iter(fd: BinaryIO, width: int = 16, skip: bool = True, hexii: bool = False, begin: int = 0, style: dict = None, + highlight: Iterable = None, cyclic: bool = False, groupsize: int = 4, total: bool = True) -> Generator[str, None, None]: r"""hexdump_iter(s, width = 16, skip = True, hexii = False, begin = 0, style = None, highlight = None, cyclic = False, groupsize=4, total = True) -> str generator @@ -688,7 +690,7 @@ def hexdump_iter(fd, width=16, skip=True, hexii=False, begin=0, style=None, marker = (style.get('marker') or (lambda s:s))('โ”‚') if not hexii: - def style_byte(by): + def style_byte(by: int) -> tuple[str, str]: hbyte = '%02x' % by b = packing._p8lu(by) abyte = chr(by) if isprint(b) else 'ยท' @@ -791,8 +793,8 @@ def style_byte(by): line = "%08x" % (begin + numb) yield line -def hexdump(s, width=16, skip=True, hexii=False, begin=0, style=None, - highlight=None, cyclic=False, groupsize=4, total=True): +def hexdump(s: bytes, width: int = 16, skip: bool = True, hexii: bool = False, begin: int = 0, style: dict = None, + highlight: Iterable = None, cyclic: bool = False, groupsize: int = 4, total: bool = True) -> str: r"""hexdump(s, width = 16, skip = True, hexii = False, begin = 0, style = None, highlight = None, cyclic = False, groupsize=4, total = True) -> str @@ -988,7 +990,7 @@ def hexdump(s, width=16, skip=True, hexii=False, begin=0, style=None, groupsize, total)) -def negate(value, width = None): +def negate(value: int, width: int = None) -> int: """ Returns the two's complement of 'value'. """ @@ -997,7 +999,7 @@ def negate(value, width = None): mask = ((1< int: """ Returns the binary inverse of 'value'. """ @@ -1007,7 +1009,7 @@ def bnot(value, width=None): return mask ^ value @LocalNoarchContext -def js_escape(data, padding=context.cyclic_alphabet[0:1], **kwargs): +def js_escape(data: bytes, padding: bytes = context.cyclic_alphabet[0:1], **kwargs: dict[str, Any]) -> bytes: r"""js_escape(data, padding=context.cyclic_alphabet[0:1], endian = None, **kwargs) -> str Pack data as an escaped Unicode string for use in JavaScript's `unescape()` function @@ -1049,7 +1051,7 @@ def js_escape(data, padding=context.cyclic_alphabet[0:1], **kwargs): return ''.join(f'%u{a:02x}{b:02x}' for a, b in iters.group(2, data)) @LocalNoarchContext -def js_unescape(s, **kwargs): +def js_unescape(s: str, **kwargs: dict[str, Any]) -> bytes: r"""js_unescape(s, endian = None, **kwargs) -> bytes Unpack an escaped Unicode string from JavaScript's `escape()` function @@ -1117,7 +1119,7 @@ def js_unescape(s, **kwargs): return b''.join(res) -def tty_escape(s, lnext=b'\x16', dangerous=bytes(bytearray(range(0x20)))): +def tty_escape(s: bytes, lnext: bytes = b'\x16', dangerous: bytes = bytes(bytearray(range(0x20)))) -> bytes: r"""tty_escape(s, lnext=b'\x16', dangerous=bytes(bytearray(range(0x20)))) -> bytes Escape data for terminal output. This is useful when sending data to a diff --git a/pwnlib/util/getdents.py b/pwnlib/util/getdents.py index 8214f650a..3311e84b7 100644 --- a/pwnlib/util/getdents.py +++ b/pwnlib/util/getdents.py @@ -65,7 +65,7 @@ class linux_dirent: d_type: Dtype d_name: str - def __init__(self, buf: bytes, is_dirent64: bool): + def __init__(self, buf: bytes, is_dirent64: bool) -> None: size_t = 8 if is_dirent64 else context.bytes self.d_ino = unpack(buf[0:size_t], size_t * 8) @@ -82,13 +82,13 @@ def __init__(self, buf: bytes, is_dirent64: bool): self.d_name = d_name.split(b'\x00', 1)[0].decode('utf-8') self.d_type = Dtype(d_type) - def __len__(self): + def __len__(self) -> int: return self.d_reclen - def __str__(self): + def __str__(self) -> str: return self.d_name - def __repr__(self): + def __repr__(self) -> str: return f'{self.d_type.name:<8}{self.d_name}' diff --git a/pwnlib/util/lists.py b/pwnlib/util/lists.py index 0c669f902..ed136e3a2 100644 --- a/pwnlib/util/lists.py +++ b/pwnlib/util/lists.py @@ -1,7 +1,9 @@ import collections +from collections.abc import Sequence +from typing import Any, Callable, Generator, Literal -def partition(lst, f, save_keys = False): +def partition(lst: list, f: Callable, save_keys: bool = False) -> list: """partition(lst, f, save_keys = False) -> list Partitions an iterable into sublists using a function to specify which @@ -34,7 +36,7 @@ def partition(lst, f, save_keys = False): else: return list(d.values()) -def group(n, lst, underfull_action = 'ignore', fill_value = None): +def group(n: int, lst: Sequence, underfull_action: str = 'ignore', fill_value: list | tuple | bytes | str = None) -> list: """group(n, lst, underfull_action = 'ignore', fill_value = None) -> list Split sequence into subsequences of given size. If the values cannot be @@ -92,7 +94,7 @@ def group(n, lst, underfull_action = 'ignore', fill_value = None): return out -def concat(l): +def concat(l: list) -> list: """concat(l) -> list Concats a list of lists into a list. @@ -110,7 +112,7 @@ def concat(l): return res -def concat_all(*args): +def concat_all(*args: tuple) -> list: """concat_all(*args) -> list Concats all the arguments together. @@ -121,7 +123,7 @@ def concat_all(*args): [0, 1, 2, 3, 4, 5, 6] """ - def go(arg, output): + def go(arg: tuple | list | Any, output: list) -> list: if isinstance(arg, (tuple, list)): for e in arg: go(e, output) @@ -131,7 +133,7 @@ def go(arg, output): return go(args, []) -def ordlist(s): +def ordlist(s: str) -> list[int]: """ordlist(s) -> list Turns a string into a list of the corresponding ascii values. @@ -143,7 +145,7 @@ def ordlist(s): """ return list(map(ord, s)) -def unordlist(cs): +def unordlist(cs: list[int]) -> str: """unordlist(cs) -> str Takes a list of ascii values and returns the corresponding string. @@ -155,7 +157,7 @@ def unordlist(cs): """ return ''.join(chr(c) for c in cs) -def findall(haystack, needle): +def findall(haystack: list, needle: Any) -> list | Any: """findall(l, e) -> l Generate all indices of needle in haystack, using the @@ -175,7 +177,7 @@ def findall(haystack, needle): >>> list(findall("aaabaaabc", "aab")) [1, 5] """ - def __kmp_table(W): + def __kmp_table(W: Any) -> list: pos = 1 cnd = 0 T = [] @@ -193,7 +195,7 @@ def __kmp_table(W): T.append(0) return T - def __kmp_search(S, W): + def __kmp_search(S: list, W: Any) -> Generator[Any | Literal[0], Any, None]: m = 0 i = 0 T = __kmp_table(W) @@ -208,7 +210,7 @@ def __kmp_search(S, W): m += i - T[i] i = max(T[i], 0) - def __single_search(S, w): + def __single_search(S: list, w: Any) -> Generator[int, Any, None]: for i, v in enumerate(S): if v == w: yield i diff --git a/pwnlib/util/net.py b/pwnlib/util/net.py index aa6f905e3..895e1e66e 100644 --- a/pwnlib/util/net.py +++ b/pwnlib/util/net.py @@ -1,11 +1,16 @@ import ctypes import ctypes.util import socket +from typing import Callable +from pwnlib import tubes +from pwnlib.log import getLogger from pwnlib.util.packing import p16 from pwnlib.util.packing import p32 from pwnlib.util.packing import pack +log = getLogger(__name__) + __all__ = ['getifaddrs', 'interfaces', 'interfaces4', 'interfaces6', 'sockaddr'] # /usr/src/linux-headers-3.12-1-common/include/uapi/linux/socket.h @@ -58,7 +63,7 @@ class struct_ifaddrs(ctypes.Structure): AddressFamily = getattr(socket, 'AddressFamily', int) -def sockaddr_fixup(saptr): +def sockaddr_fixup(saptr: ctypes.POINTER) -> tuple[int, dict]: family = AddressFamily(saptr.contents.sa_family) addr = {} if family == socket.AF_INET: @@ -73,7 +78,7 @@ def sockaddr_fixup(saptr): addr['scope_id'] = sa.sin6_scope_id return family, addr -def getifaddrs(): +def getifaddrs() -> list[dict]: """getifaddrs() -> dict list A wrapper for libc's ``getifaddrs``. @@ -122,7 +127,7 @@ def getifaddrs(): finally: freeifaddrs(ifaptr) -def interfaces(all = False): +def interfaces(all: bool = False) -> dict: """interfaces(all = False) -> dict Arguments: @@ -148,7 +153,7 @@ def interfaces(all = False): out = {k: v for k, v in out.items() if v} return out -def interfaces4(all = False): +def interfaces4(all: bool = False) -> dict: """interfaces4(all = False) -> dict As :func:`interfaces` but only includes IPv4 addresses and the lists in the @@ -174,7 +179,7 @@ def interfaces4(all = False): out[name] = addrs return out -def interfaces6(all = False): +def interfaces6(all: bool = False) -> dict: """interfaces6(all = False) -> dict As :func:`interfaces` but only includes IPv6 addresses and the lists in the @@ -200,7 +205,7 @@ def interfaces6(all = False): out[name] = addrs return out -def sockaddr(host, port, network = 'ipv4'): +def sockaddr(host: str, port: int, network: str = 'ipv4') -> tuple[bytes, int, int]: """sockaddr(host, port, network = 'ipv4') -> (data, length, family) Creates a sockaddr_in or sockaddr_in6 memory buffer for use in shellcode. @@ -237,13 +242,13 @@ def sockaddr(host, port, network = 'ipv4'): length = len(sockaddr) + 4 # Save five bytes 'push 0' return (sockaddr, length, getattr(address_family, "name", address_family)) -def sock_match(local, remote, fam=socket.AF_UNSPEC, typ=0): +def sock_match(local: tuple | tubes.sock.sock, remote: str, fam: int =socket.AF_UNSPEC, typ: int = 0) -> Callable[[], bool]: """ Given two addresses, returns a function comparing address pairs from psutil library against these two. Useful for filtering done in :func:`pwnlib.util.proc.pidof`. """ - def sockinfos(addr, f, t): + def sockinfos(addr: dict, f: int, t: int) -> set: if not addr: return set() if f not in (socket.AF_UNSPEC, socket.AF_INET, socket.AF_INET6): @@ -260,7 +265,7 @@ def sockinfos(addr, f, t): if remote is not None: remote = sockinfos(remote, fam, typ) - def match(c): + def match(c) -> bool: # noqa: ANN001 laddrs = sockinfos(c.laddr, c.family, c.type) raddrs = sockinfos(c.raddr, c.family, c.type) if not (laddrs & local): diff --git a/pwnlib/util/packing.py b/pwnlib/util/packing.py index 3a6466ad5..ca501fcb2 100644 --- a/pwnlib/util/packing.py +++ b/pwnlib/util/packing.py @@ -29,9 +29,10 @@ >>> with context.local(endian='big'): print(repr(p(0x1ff))) b'\xff\x01' """ -import collections import struct import sys +from typing import Any, Callable, Iterable, BinaryIO +from collections.abc import Sequence import warnings from pwnlib.context import LocalNoarchContext @@ -43,7 +44,7 @@ mod = sys.modules[__name__] log = getLogger(__name__) -def pack(number, word_size = None, endianness = None, sign = None, **kwargs): +def pack(number: int, word_size: str | int = None, endianness: str = None, sign: bool = None, **kwargs: dict[str, Any]) -> bytes: r"""pack(number, word_size = None, endianness = None, sign = None, **kwargs) -> str Packs arbitrary-sized integer. @@ -159,7 +160,7 @@ def pack(number, word_size = None, endianness = None, sign = None, **kwargs): return b''.join(reversed(out)) @LocalNoarchContext -def unpack(data, word_size = None): +def unpack(data: bytes, word_size: str | int = None) -> str: r"""unpack(data, word_size = None, *, endianness = None, sign = None, **kwargs) -> int Unpacks arbitrary-sized integer. @@ -233,7 +234,7 @@ def unpack(data, word_size = None): return int(number - 2*signbit) @LocalNoarchContext -def unpack_many(data, word_size = None): +def unpack_many(data: bytes, word_size: int | str = None) -> str: r"""unpack_many(data, word_size = None, *, endianness = None, sign = None) -> int list Splits `data` into groups of ``word_size//8`` bytes and calls :func:`unpack` on each group. Returns a list of the results. @@ -265,8 +266,6 @@ def unpack_many(data, word_size = None): """ # Lookup in context if None word_size = word_size or context.word_size - endianness = context.endianness - sign = context.sign if word_size == 'all': return [unpack(data, word_size)] @@ -295,7 +294,7 @@ def unpack_many(data, word_size = None): op_verbs = {'p': 'pack', 'u': 'unpack'} -def make_single(op,size,end,sign): +def make_single(op: str, size: str, end: str, sign: str) -> tuple[str, Callable]: name = '_%s%s%s%s' % (op, size, end, sign) fmt = sizes[size] @@ -303,11 +302,11 @@ def make_single(op,size,end,sign): if fmt == '': endianess = 'big' if end == 'b' else 'little' if op == 'u': - def routine(data, stacklevel=1): + def routine(data: bytes | bytearray | str, stacklevel: int = 1) -> str: data = _need_bytes(data, stacklevel) return unpack(data, size, endianness=endianess, sign=sign == 's') else: - def routine(data, stacklevel=None): + def routine(data: int, stacklevel: int = None) -> bytes: return pack(data, size, endianness=endianess, sign=sign == 's') routine.__name__ = routine.__qualname__ = name return name, routine @@ -320,11 +319,11 @@ def routine(data, stacklevel=None): struct_op = getattr(struct.Struct(fmt), op_verbs[op]) if op == 'u': - def routine(data, stacklevel=1): + def routine(data: Sequence, stacklevel: int = 1) -> Any: data = _need_bytes(data, stacklevel) return struct_op(data)[0] else: - def routine(data, stacklevel=None): + def routine(data: Sequence, stacklevel: int = None) -> Any: return struct_op(data) routine.__name__ = routine.__qualname__ = name @@ -339,7 +338,7 @@ def routine(data, stacklevel=None): # # Make normal user-oriented packers, e.g. p8 # -def _do_packing(op, size, number, endianness=None): +def _do_packing(op: str, size: int, number: int, endianness: str = None) -> bytes: name = "%s%s" % (op,size) mod = sys.modules[__name__] @@ -357,7 +356,7 @@ def _do_packing(op, size, number, endianness=None): ("big", False): bu}[endian, signed](number, 3) @LocalNoarchContext -def p8(number, endianness = None, **kwargs): +def p8(number: int, endianness: str = None, **kwargs: dict[str, Any]) -> bytes: """p8(number, endianness, sign, ...) -> bytes Packs an 8-bit integer @@ -375,7 +374,7 @@ def p8(number, endianness = None, **kwargs): return _do_packing('p', 8, number, endianness) @LocalNoarchContext -def p16(number, endianness = None, **kwargs): +def p16(number: int, endianness: str = None, **kwargs: dict[str, Any]) -> bytes: """p16(number, endianness, sign, ...) -> bytes Packs an 16-bit integer @@ -400,7 +399,7 @@ def p16(number, endianness = None, **kwargs): return _do_packing('p', 16, number, endianness) @LocalNoarchContext -def p32(number, endianness = None, **kwargs): +def p32(number: int, endianness: str = None, **kwargs: dict[str, Any]) -> bytes: """p32(number, endianness, sign, ...) -> bytes Packs an 32-bit integer @@ -425,7 +424,7 @@ def p32(number, endianness = None, **kwargs): return _do_packing('p', 32, number, endianness) @LocalNoarchContext -def p40(number, endianness = None, **kwargs): +def p40(number: int, endianness: str = None, **kwargs: dict[str, Any]) -> bytes: """p40(number, endianness, sign, ...) -> bytes Packs an 40-bit integer @@ -450,7 +449,7 @@ def p40(number, endianness = None, **kwargs): return _do_packing('p', 40, number, endianness) @LocalNoarchContext -def p48(number, endianness = None, **kwargs): +def p48(number: int, endianness: str = None, **kwargs: dict[str, Any]) -> bytes: """p48(number, endianness, sign, ...) -> bytes Packs an 48-bit integer @@ -475,7 +474,7 @@ def p48(number, endianness = None, **kwargs): return _do_packing('p', 48, number, endianness) @LocalNoarchContext -def p56(number, endianness = None, **kwargs): +def p56(number: int, endianness: str = None, **kwargs: dict[str, Any]) -> bytes: """p56(number, endianness, sign, ...) -> bytes Packs an 56-bit integer @@ -500,7 +499,7 @@ def p56(number, endianness = None, **kwargs): return _do_packing('p', 56, number, endianness) @LocalNoarchContext -def p64(number, endianness = None, **kwargs): +def p64(number: int, endianness: str = None, **kwargs: dict[str, Any]) -> bytes: """p64(number, endianness, sign, ...) -> bytes Packs an 64-bit integer @@ -525,7 +524,7 @@ def p64(number, endianness = None, **kwargs): return _do_packing('p', 64, number, endianness) @LocalNoarchContext -def u8(data, endianness = None, **kwargs): +def u8(data: bytes, endianness: str = None, **kwargs: dict[str, Any]) -> int: """u8(data, endianness, sign, ...) -> int Unpacks an 8-bit integer @@ -543,7 +542,7 @@ def u8(data, endianness = None, **kwargs): return _do_packing('u', 8, data, endianness) @LocalNoarchContext -def u16(data, endianness = None, **kwargs): +def u16(data: bytes, endianness: str = None, **kwargs: dict[str, Any]) -> int: """u16(data, endianness, sign, ...) -> int Unpacks an 16-bit integer @@ -561,7 +560,7 @@ def u16(data, endianness = None, **kwargs): return _do_packing('u', 16, data, endianness) @LocalNoarchContext -def u32(data, endianness = None, **kwargs): +def u32(data: bytes, endianness: str = None, **kwargs: dict[str, Any]) -> int: """u32(data, endianness, sign, ...) -> int Unpacks an 32-bit integer @@ -579,7 +578,7 @@ def u32(data, endianness = None, **kwargs): return _do_packing('u', 32, data, endianness) @LocalNoarchContext -def u40(data, endianness = None, **kwargs): +def u40(data: bytes, endianness: str = None, **kwargs: dict[str, Any]) -> int: """u40(data, endianness, sign, ...) -> int Unpacks an 40-bit integer @@ -597,7 +596,7 @@ def u40(data, endianness = None, **kwargs): return _do_packing('u', 40, data, endianness) @LocalNoarchContext -def u48(data, endianness = None, **kwargs): +def u48(data: bytes, endianness: str = None, **kwargs: dict[str, Any]) -> int: """u48(data, endianness, sign, ...) -> int Unpacks an 48-bit integer @@ -615,7 +614,7 @@ def u48(data, endianness = None, **kwargs): return _do_packing('u', 48, data, endianness) @LocalNoarchContext -def u56(data, endianness = None, **kwargs): +def u56(data: bytes, endianness: str = None, **kwargs: dict[str, Any]) -> int: """u56(data, endianness, sign, ...) -> int Unpacks an 56-bit integer @@ -633,7 +632,7 @@ def u56(data, endianness = None, **kwargs): return _do_packing('u', 56, data, endianness) @LocalNoarchContext -def u64(data, endianness = None, **kwargs): +def u64(data: bytes, endianness: str = None, **kwargs: dict[str, Any]) -> int: """u64(data, endianness, sign, ...) -> int Unpacks an 64-bit integer @@ -650,7 +649,7 @@ def u64(data, endianness = None, **kwargs): """ return _do_packing('u', 64, data, endianness) -def make_packer(word_size = None, sign = None, **kwargs): +def make_packer(word_size: int = None, sign: str = None, **kwargs: dict[str, Any]) -> Callable[[int], str]: r"""make_packer(word_size = None, endianness = None, sign = None) -> number โ†’ str Creates a packer by "freezing" the given arguments. @@ -714,7 +713,7 @@ def make_packer(word_size = None, sign = None, **kwargs): return lambda number: pack(number, word_size, endianness, sign) @LocalNoarchContext -def make_unpacker(word_size = None, endianness = None, sign = None, **kwargs): +def make_unpacker(word_size: int = None, endianness: str = None, sign: str = None, **kwargs: dict[str, Any]) -> Callable[[str], int]: """make_unpacker(word_size = None, endianness = None, sign = None, **kwargs) -> str โ†’ number Creates an unpacker by "freezing" the given arguments. @@ -775,12 +774,12 @@ def make_unpacker(word_size = None, endianness = None, sign = None, **kwargs): else: return lambda number: unpack(number, word_size, endianness, sign) -def _fit(pieces, preprocessor, packer, filler, stacklevel=1): +def _fit(pieces: dict, preprocessor: Callable, packer: Callable, filler: Iterable, stacklevel: int = 1) -> tuple[iters.chain[int], bytes]: # Pulls bytes from `filler` and adds them to `pad` until it ends in `key`. # Returns the index of `key` in `pad`. pad = bytearray() - def fill(key): + def fill(key: bytes | bytearray) -> int: key = bytearray(key) offset = pad.find(key) while offset == -1: @@ -850,7 +849,7 @@ def fill(key): return filler, out_negative + out -def _flat(args, preprocessor, packer, filler, stacklevel=1): +def _flat(args: list[list | tuple | dict | bytes | str | int], preprocessor: Callable, packer: Callable, filler: Iterable, stacklevel: int = 1) -> bytes: out = [] for arg in args: @@ -886,7 +885,7 @@ def _flat(args, preprocessor, packer, filler, stacklevel=1): return b''.join(out) @LocalNoarchContext -def flat(*args, **kwargs): +def flat(*args: tuple, **kwargs: dict[str, Any]) -> str: r"""flat(\*args, preprocessor = None, length = None, filler = de_bruijn(), word_size = None, endianness = None, sign = None) -> str @@ -1058,7 +1057,7 @@ def flat(*args, **kwargs): return out -def fit(*args, **kwargs): +def fit(*args: tuple, **kwargs: dict[str, Any]) -> bytes: """Legacy alias for :func:`flat`""" kwargs['stacklevel'] = kwargs.get('stacklevel', 0) + 1 return flat(*args, **kwargs) @@ -1100,13 +1099,13 @@ def fit(*args, **kwargs): """ -def signed(integer): +def signed(integer: int) -> str: return unpack(pack(integer), signed=True) -def unsigned(integer): +def unsigned(integer: int) -> str: return unpack(pack(integer)) -def dd(dst, src, count = 0, skip = 0, seek = 0, truncate = False): +def dd(dst: BinaryIO | Sequence, src: Iterable, count: int = 0, skip: int = 0, seek: int = 0, truncate: bool = False) -> BinaryIO | Sequence: r"""dd(dst, src, count = 0, skip = 0, seek = 0, truncate = False) -> dst Inspired by the command line tool ``dd``, this function copies `count` byte @@ -1291,7 +1290,7 @@ def dd(dst, src, count = 0, skip = 0, seek = 0, truncate = False): return dst -def _need_bytes(s, level=1, min_wrong=0): +def _need_bytes(s: Sequence, level: int = 1, min_wrong: int = 0) -> bytes: if isinstance(s, (bytes, bytearray)): return s # already bytes @@ -1313,7 +1312,7 @@ def _need_bytes(s, level=1, min_wrong=0): BytesWarning, level + 2) return s.encode(encoding, errors) -def _need_text(s, level=1): +def _need_text(s: str | bytes | bytearray, level: int = 1) -> str: if isinstance(s, str): return s # already text @@ -1335,7 +1334,7 @@ def _need_text(s, level=1): BytesWarning, level + 2) return s.decode(encoding, errors) -def _encode(s): +def _encode(s: Sequence) -> bytes: if isinstance(s, (bytes, bytearray)): return s # already bytes @@ -1346,7 +1345,7 @@ def _encode(s): return s.encode('utf-8', 'surrogateescape') return s.encode(context.encoding) -def _decode(b): +def _decode(b: str | bytes) -> str: if isinstance(b, str): return b # already text diff --git a/pwnlib/util/proc.py b/pwnlib/util/proc.py index 65bdb76ee..24a7dd534 100644 --- a/pwnlib/util/proc.py +++ b/pwnlib/util/proc.py @@ -2,6 +2,7 @@ import socket import sys import time +from typing import Any import psutil @@ -14,7 +15,7 @@ all_pids = psutil.pids -def pidof(target): +def pidof(target: tubes.ssh.ssh_channel | tubes.sock.sock | tuple | tubes.process.process) -> list[int]: """pidof(target) -> int list Get PID(s) of `target`. The returned PID(s) depends on the type of `target`: @@ -62,7 +63,7 @@ def pidof(target): else: return pid_by_name(target) -def pid_by_name(name): +def pid_by_name(name: str) -> list[int]: """pid_by_name(name) -> int list Arguments: @@ -76,7 +77,7 @@ def pid_by_name(name): >>> os.getpid() in pid_by_name(name(os.getpid())) True """ - def match(p): + def match(p: psutil.Process) -> bool: if p.status() == 'zombie': return False if p.name() == name: @@ -96,7 +97,7 @@ def match(p): return [p.pid for p in processes] -def name(pid): +def name(pid: int) -> str: """name(pid) -> str Arguments: @@ -113,7 +114,7 @@ def name(pid): """ return psutil.Process(pid).name() -def parent(pid): +def parent(pid: int) -> int: """parent(pid) -> int Arguments: @@ -128,7 +129,7 @@ def parent(pid): except Exception: return 0 -def children(ppid): +def children(ppid: int) -> list[int]: """children(ppid) -> int list Arguments: @@ -139,7 +140,7 @@ def children(ppid): """ return [p.pid for p in psutil.Process(ppid).children()] -def ancestors(pid): +def ancestors(pid: int) -> list[int]: """ancestors(pid) -> int list Arguments: @@ -159,7 +160,7 @@ def ancestors(pid): pid = parent(pid) return pids -def descendants(pid): +def descendants(pid: int) -> dict[int, dict[int, int]]: #TODO: verify """descendants(pid) -> dict Arguments: @@ -177,17 +178,17 @@ def descendants(pid): this_pid = pid allpids = all_pids() ppids = {} - def _parent(pid): + def _parent(pid: int) -> dict[int, int]: if pid not in ppids: ppids[pid] = parent(pid) return ppids[pid] - def _children(ppid): + def _children(ppid: int) -> list[int]: return [pid for pid in allpids if _parent(pid) == ppid] - def _loop(ppid): + def _loop(ppid: int) -> dict[int, dict[int, int]]: return {pid: _loop(pid) for pid in _children(ppid)} return _loop(pid) -def exe(pid): +def exe(pid: int) -> str: """exe(pid) -> str Arguments: @@ -203,7 +204,7 @@ def exe(pid): """ return psutil.Process(pid).exe() -def cwd(pid): +def cwd(pid: int) -> str: """cwd(pid) -> str Arguments: @@ -220,7 +221,7 @@ def cwd(pid): """ return psutil.Process(pid).cwd() -def cmdline(pid): +def cmdline(pid: int) -> list[str]: """cmdline(pid) -> str list Arguments: @@ -236,7 +237,7 @@ def cmdline(pid): """ return psutil.Process(pid).cmdline() -def memory_maps(pid): +def memory_maps(pid: int) -> list[tuple[str, str]]: """memory_maps(pid) -> list Arguments: @@ -254,7 +255,7 @@ def memory_maps(pid): """ return psutil.Process(pid).memory_maps(grouped=False) -def stat(pid): +def stat(pid: int) -> list[str]: """stat(pid) -> str list Arguments: @@ -276,7 +277,7 @@ def stat(pid): name = s[i+1:j] return s[:i].split() + [name] + s[j+1:].split() -def starttime(pid): +def starttime(pid: int) -> float: """starttime(pid) -> float Arguments: @@ -292,7 +293,7 @@ def starttime(pid): """ return psutil.Process(pid).create_time() - psutil.boot_time() -def status(pid): +def status(pid: int) -> dict[str, str]: """status(pid) -> dict Get the status of a process. @@ -320,11 +321,11 @@ def status(pid): raise return out -def _tracer_windows(pid): +def _tracer_windows(pid: int) -> int: import ctypes from ctypes import wintypes - def _check_bool(result, func, args): + def _check_bool(result: Any, func: callable, args: list) -> list: if not result: raise ctypes.WinError(ctypes.get_last_error()) return args @@ -356,7 +357,7 @@ def _check_bool(result, func, args): return ret -def tracer(pid): +def tracer(pid: int) -> int | None: """tracer(pid) -> int Arguments: @@ -376,7 +377,7 @@ def tracer(pid): tpid = int(status(pid)['TracerPid']) return tpid if tpid > 0 else None -def state(pid): +def state(pid: int) -> str: """state(pid) -> str Arguments: @@ -392,7 +393,7 @@ def state(pid): """ return status(pid)['State'] -def wait_for_debugger(pid, debugger_pid=None): +def wait_for_debugger(pid: int, debugger_pid: int | None = None) -> int | None: """wait_for_debugger(pid, debugger_pid=None) -> None Sleeps until the process with PID `pid` is being traced. diff --git a/pwnlib/util/safeeval.py b/pwnlib/util/safeeval.py index f8eee648c..4be66bb80 100644 --- a/pwnlib/util/safeeval.py +++ b/pwnlib/util/safeeval.py @@ -1,3 +1,6 @@ +from types import CodeType +from typing import Any + _const_codes = [ 'POP_TOP','ROT_TWO','ROT_THREE','ROT_FOUR','DUP_TOP', 'BUILD_LIST','BUILD_MAP', 'MAP_ADD', 'BUILD_TUPLE','BUILD_SET', @@ -19,7 +22,7 @@ _values_codes = _expr_codes + ['LOAD_NAME'] -def _get_opcodes(codeobj): +def _get_opcodes(codeobj: CodeType) -> list[int]: """_get_opcodes(codeobj) -> [opcodes] Extract the actual opcodes as a list from a code object @@ -31,7 +34,7 @@ def _get_opcodes(codeobj): import dis return [ins.opcode for ins in dis.get_instructions(codeobj)] -def test_expr(expr, allowed_codes): +def test_expr(expr: str, allowed_codes: list[str]) -> CodeType: """test_expr(expr, allowed_codes) -> codeobj Test that the expression contains only the listed opcodes. @@ -50,7 +53,7 @@ def test_expr(expr, allowed_codes): raise ValueError("opcode %s not allowed" % dis.opname[code]) return c -def const(expr): +def const(expr: str) -> Any: """const(expression) -> value Safe Python constant evaluation @@ -74,7 +77,7 @@ def const(expr): c = test_expr(expr, _const_codes) return eval(c) -def expr(expr): +def expr(expr: str) -> Any: """expr(expression) -> value Safe Python expression evaluation @@ -98,7 +101,7 @@ def expr(expr): c = test_expr(expr, _expr_codes) return eval(c) -def values(expr, env): +def values(expr: str, env: dict[str, Any]) -> Any: """values(expression, dict) -> value Safe Python expression evaluation diff --git a/pwnlib/util/sh_string.py b/pwnlib/util/sh_string.py index 34ca4c195..a733491e3 100644 --- a/pwnlib/util/sh_string.py +++ b/pwnlib/util/sh_string.py @@ -239,6 +239,7 @@ """ import string import subprocess +from typing import Callable from pwnlib.context import context from pwnlib.log import getLogger @@ -248,7 +249,7 @@ log = getLogger(__name__) -def test_all(): +def test_all() -> None: test('a') ## test('ab') ## test('a b') ## @@ -269,10 +270,10 @@ def test_all(): everything_2 = b''.join(bytes([c,c]) for c in range(1,256)) ## test(everything_2) - test(randoms(1000, everything_1)) + test(fiddling.randoms(1000, everything_1)) -def test(original): +def test(original: bytes) -> None: r"""Tests the output provided by a shell interpreting a string .. doctest:: @@ -358,7 +359,7 @@ def test(original): # '\\': '"\\\\\\\\"' } -def sh_string(s): +def sh_string(s: str) -> bytes | str: r"""Outputs a string in a format that will be understood by /bin/sh. If the string does not contain any bad characters, it will simply be @@ -437,7 +438,7 @@ def sh_string(s): quoted_string = quoted_string.encode('latin1') return quoted_string -def sh_prepare(variables, export = False): +def sh_prepare(variables: dict, export: bool = False) -> bytes: r"""Outputs a posix compliant shell command that will put the data specified by the dictionary into the environment. @@ -488,7 +489,7 @@ def sh_prepare(variables, export = False): return b';'.join(out) -def sh_command_with(f, *args): +def sh_command_with(f: Callable, *args: tuple) -> str: r"""sh_command_with(f, arg0, ..., argN) -> command Returns a command create by evaluating `f(new_arg0, ..., new_argN)` diff --git a/pwnlib/util/splash.py b/pwnlib/util/splash.py index 43434d78d..0e36fe4ab 100644 --- a/pwnlib/util/splash.py +++ b/pwnlib/util/splash.py @@ -23,13 +23,13 @@ .:*~*:._.:*~*:._.:*~*:._.:*~*:._.:*~*:._.:*~*:._.:*~*:._.:*~*:._.:*~*:. ''' -def splash(): +def splash() -> None: """Put this at the beginning of your exploit to create the illusion that your sploit is enterprisey and top notch quality""" - def updater(): + def updater() -> None: - colors = [ + colors: list[int] = [ text.blue , text.bold_blue , text.magenta, text.bold_magenta, text.red , text.bold_red , @@ -37,7 +37,7 @@ def updater(): text.green , text.bold_green , text.cyan , text.bold_cyan , ] - def getcolor(n): + def getcolor(n: int) -> int: return colors[(n // 4) % len(colors)] lines = [' ' + line + '\n' for line in _banner.strip('\n').split('\n')] diff --git a/pwnlib/util/web.py b/pwnlib/util/web.py index 4faa0aa7b..34fd3f4ba 100644 --- a/pwnlib/util/web.py +++ b/pwnlib/util/web.py @@ -1,5 +1,6 @@ import os import tempfile +from typing import Any from pwnlib.log import getLogger from pwnlib.tubes.buffer import Buffer @@ -7,7 +8,7 @@ log = getLogger(__name__) -def wget(url, save=None, timeout=5, **kwargs): +def wget(url: str, save: str | bool=None, timeout: int=5, **kwargs: dict[str, Any]) -> None | bytes: r"""wget(url, save=None, timeout=5) -> str Downloads a file via HTTP/HTTPS.