diff --git a/.github/instructions/style-guide.instructions.md b/.github/instructions/style-guide.instructions.md index b79834867b..538e9672f7 100644 --- a/.github/instructions/style-guide.instructions.md +++ b/.github/instructions/style-guide.instructions.md @@ -196,28 +196,43 @@ def calculate_score( The PyRIT docs build uses **MyST** (Markdown-flavoured), not reStructuredText. Do **not** use reST cross-reference roles in docstrings or module comments — -they render as raw text under MyST and are inconsistent with the rest of the -codebase, which uses plain double-backtick code spans for symbol names. +they render as raw text under MyST. A pre-commit hook +(`check_no_rest_roles`) blocks new ones from landing. + +Use plain double-backticks for symbol references. The API page generator +(`build_scripts/gen_api_md.py`) automatically rewrites known PyRIT symbol +names into MyST cross-reference links at build time, so you get clickable +navigation in the rendered docs without any extra markup. ```python -# WRONG — reST roles render as literal `:class:\`SeedPrompt\`` under MyST +# WRONG — reST roles render as literal `:class:\`SeedPrompt\`` under MyST, +# and the pre-commit guard will reject them """Returns a :class:`SeedPrompt` instance.""" """Delegate to :func:`download_files_async` (deprecated alias).""" """See :meth:`PromptTarget.apply_capabilities` for details.""" -# CORRECT — plain double-backtick code span (matches existing convention) +# CORRECT — plain double-backticks; gen_api_md.py auto-links these """Returns a ``SeedPrompt`` instance.""" """Delegate to ``download_files_async`` (deprecated alias).""" """See ``PromptTarget.apply_capabilities`` for details.""" ``` -Roles to avoid include `:class:`, `:func:`, `:meth:`, `:mod:`, `:attr:`, -`:data:`, `:exc:`, `:obj:`, `:ref:`, and any `:py:*:` variants -(e.g. `:py:class:`, `:py:func:`). 
+The auto-linker resolves: + +- bare class/function names (`` ``SeedPrompt`` ``) +- `Class.method` references (`` ``PromptTarget.apply_capabilities`` ``) +- fully-qualified paths (`` ``pyrit.models.SeedPrompt`` ``) +- bare method names when the docstring is on the owning class + (`` ``send_prompt_async`` `` inside `PromptTarget`) + +Ambiguous short names (e.g. two unrelated classes both called `Scorer`) +are left as plain code-spans; spell out the FQN when you need a stable +cross-reference. Unknown names also stay as plain code-spans, so +docstrings remain safe to write without consulting the symbol index. -If you genuinely need a Sphinx cross-reference (rare in PyRIT — most -docstrings just name the symbol in backticks), use the MyST role syntax -`` {class}`Name` `` instead. The default, though, is plain double-backticks. +If you need an explicit MyST link in markdown documentation, use the +standard syntax `` [`Name`](#api-pyrit_module-Name) `` — but inside +Python docstrings this should be rare; plain backticks are the default. 
### Class-Level Constants - Define constants as class attributes, not module-level diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 896b51fe71..02e7425515 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,6 +30,11 @@ repos: files: ^(doc/.*\.(py|ipynb|md)|doc/myst\.yml)$ pass_filenames: false additional_dependencies: ['pyyaml'] + - id: check-no-rest-roles + name: Reject Sphinx reST cross-reference roles + entry: python ./build_scripts/check_no_rest_roles.py + language: python + files: ^pyrit/.*\.py$ - id: enforce_alembic_revision_immutability name: Enforce Alembic Revision Immutability entry: python ./build_scripts/enforce_alembic_revision_immutability.py diff --git a/build_scripts/check_no_rest_roles.py b/build_scripts/check_no_rest_roles.py new file mode 100644 index 0000000000..a0c1f38cbd --- /dev/null +++ b/build_scripts/check_no_rest_roles.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Pre-commit guard against Sphinx reST cross-reference roles in source. + +PyRIT docs render docstrings through MyST (jupyter-book 2), not Sphinx, so +reST roles like ``:class:`Foo``` show up as raw literal text in the built +site. The standing convention (style-guide.instructions.md, PR #1782) is to +use plain double-backticks; ``build_scripts/gen_api_md.py`` then auto-links +known PyRIT symbols at render time. + +This hook flags any newly introduced reST role inside ``pyrit/`` so it can +be replaced before landing. Run it manually by passing the files to check +(with no paths it exits 0 without checking anything):: + + uv run python build_scripts/check_no_rest_roles.py pyrit/path/to/file.py + +or rely on the ``check-no-rest-roles`` pre-commit hook in +``.pre-commit-config.yaml``. +""" + +from __future__ import annotations + +import re +import sys +from pathlib import Path + +# Roles flagged by this guard, including any ``:py:*:`` variant. The style +# guide no longer enumerates them, so this pattern is the list of record. The +# pattern matches the leading colon, role name, and the opening backtick of +# the role argument (e.g.
``:class:`Foo```), so backticked code spans that +# happen to start with a colon character are not caught. +_REST_ROLE_RE = re.compile(r":(?:class|func|meth|mod|attr|data|exc|obj|ref|py:[a-z]+):`") + + +def _check_file(path: Path) -> list[tuple[int, str]]: + findings: list[tuple[int, str]] = [] + try: + text = path.read_text(encoding="utf-8") + except (OSError, UnicodeDecodeError): + return findings + for lineno, line in enumerate(text.splitlines(), start=1): + if _REST_ROLE_RE.search(line): + findings.append((lineno, line.rstrip())) + return findings + + +def main(argv: list[str] | None = None) -> int: + args = list(argv if argv is not None else sys.argv[1:]) + if not args: + return 0 + + failures: list[tuple[Path, list[tuple[int, str]]]] = [] + for raw in args: + path = Path(raw) + if path.suffix != ".py": + continue + findings = _check_file(path) + if findings: + failures.append((path, findings)) + + if not failures: + return 0 + + print("\nreST cross-reference roles are not allowed in PyRIT source.") + print("PyRIT renders docstrings with MyST, not Sphinx — these roles show") + print("up as raw literal text in the built docs.\n") + print("Replace ``:class:`Foo``` / ``:func:`bar``` / ``:meth:`Baz.do``` etc.") + print("with plain double-backticks (``Foo``). build_scripts/gen_api_md.py") + print("auto-links known PyRIT symbols at render time.\n") + for path, findings in failures: + for lineno, snippet in findings: + print(f" {path}:{lineno}: {snippet}") + print() + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/build_scripts/gen_api_md.py b/build_scripts/gen_api_md.py index da128dc2da..a3da0fedaa 100644 --- a/build_scripts/gen_api_md.py +++ b/build_scripts/gen_api_md.py @@ -17,7 +17,9 @@ """ import json +import re import sys +from dataclasses import dataclass from pathlib import Path # Import sibling script for post-generation TOC validation. 
@@ -33,6 +35,45 @@ } +@dataclass(frozen=True) +class SymbolEntry: + """A resolved API symbol that can be cross-referenced from a docstring.""" + + module: str # dotted module path, e.g. "pyrit.prompt_target" + kind: str # "class" | "function" | "method" + name: str # short name (last segment) + qualname: str # "PromptTarget" or "PromptTarget.send_prompt_async" + anchor: str # MyST label, e.g. "api-pyrit_prompt_target-PromptTarget" + + +# Backtick code spans that look like Python identifiers (with optional +# dotted paths) — candidates for symbol cross-reference rewriting. Matches +# either `name` or ``name``. The leading negative lookbehind prevents +# touching spans inside an already-rendered MyST link such as +# ``[`Name`](#anchor)`` and also prevents the single-backtick branch from +# matching the inner portion of a ``\u0060\u0060Name\u0060\u0060`` pair. +# A leading tilde or dot is tolerated because reST cross-reference syntax +# like ``:class:`~pyrit.foo.Bar``` may have leaked through earlier cleanups. +# NOTE(review): the pattern below was reconstructed from the comment above and +# from the group(1)/group(2)/group(3) usage in _rewrite_symbol_refs after the +# original text was lost in extraction — confirm against the upstream file. +_SYMBOL_REF_RE = re.compile(r"(?<![\[`])(``([~.]*[A-Za-z_][\w.]*)``|`([~.]*[A-Za-z_][\w.]*)`)(?!`)") + + +def _module_slug(module: str) -> str: + """Convert a dotted module path to a MyST-label-safe slug.""" + return module.replace(".", "_") + + +def _class_anchor(module: str, class_name: str) -> str: + return f"api-{_module_slug(module)}-{class_name}" + + +def _function_anchor(module: str, func_name: str) -> str: + return f"api-{_module_slug(module)}-{func_name}" + + +def _method_anchor(module: str, class_name: str, method_name: str) -> str: + return f"api-{_module_slug(module)}-{class_name}-{method_name}" + + def render_params(params: list[dict]) -> str: """Render parameter list as a markdown table.""" if not params: @@ -122,7 +163,233 @@ def _escape_docstring_examples(text: str) -> str: return "\n".join(result) -def render_function(func: dict, heading_level: str = "###") -> str: +def _build_symbol_index(modules: list[dict]) -> dict[str, list[SymbolEntry]]: + """Build a lookup of every API symbol that the rewriter can target.
+ + The returned dict is keyed by both the short name (e.g. ``"PromptTarget"``, + ``"send_prompt_async"``) and several qualified forms + (``"PromptTarget.send_prompt_async"``, ``"pyrit.prompt_target.PromptTarget"``, + ``"pyrit.prompt_target.PromptTarget.send_prompt_async"``). Each entry holds + the module, kind, and final anchor that ``_rewrite_symbol_refs`` will link + to. + + Multiple entries under the same key indicate an ambiguous reference; the + rewriter intentionally skips those so we don't pick a wrong target. + """ + index: dict[str, list[SymbolEntry]] = {} + + def _add(key: str, entry: SymbolEntry) -> None: + index.setdefault(key, []).append(entry) + + for module in modules: + mod_name = module.get("name", "") + for member in module.get("members", []): + kind = member.get("kind", "") + name = member.get("name", "") + if not name or name.startswith("_"): + continue + if kind == "class": + entry = SymbolEntry( + module=mod_name, + kind="class", + name=name, + qualname=name, + anchor=_class_anchor(mod_name, name), + ) + _add(name, entry) + _add(f"{mod_name}.{name}", entry) + for method in member.get("methods", []) or []: + mname = method.get("name", "") + if not mname or mname.startswith("_"): + continue + m_entry = SymbolEntry( + module=mod_name, + kind="method", + name=mname, + qualname=f"{name}.{mname}", + anchor=_method_anchor(mod_name, name, mname), + ) + _add(mname, m_entry) + _add(f"{name}.{mname}", m_entry) + _add(f"{mod_name}.{name}.{mname}", m_entry) + elif kind == "function": + entry = SymbolEntry( + module=mod_name, + kind="function", + name=name, + qualname=name, + anchor=_function_anchor(mod_name, name), + ) + _add(name, entry) + _add(f"{mod_name}.{name}", entry) + return index + + +def _resolve_symbol(raw: str, index: dict[str, list[SymbolEntry]], current_class: str | None) -> SymbolEntry | None: + """Return the cross-reference target for a bare backtick-quoted symbol. 
+ + ``raw`` is the contents between backticks — already stripped of surrounding + syntax. The lookup is conservative: if more than one symbol matches, we + return ``None`` to leave the original markup untouched. Leading tilde + prefixes (``~pyrit.foo.Bar``) and leading dots are tolerated because they + occasionally survive from Sphinx-style cross-references. + """ + cleaned = raw.lstrip("~").lstrip(".") + if not cleaned: + return None + + # Try the literal lookup first (handles FQN and Class.method forms). + entries = index.get(cleaned) + if entries and len(entries) == 1: + return entries[0] + + # When inside a class context, a bare method name should resolve to that + # class's method even if other classes share the same method name. + if current_class and "." not in cleaned: + scoped = index.get(f"{current_class}.{cleaned}") + if scoped and len(scoped) == 1: + return scoped[0] + + return None + + +def _rewrite_symbol_refs( + text: str, + index: dict[str, list[SymbolEntry]], + *, + current_class: str | None = None, +) -> str: + """Convert ``Name`` / ``Class.method`` backtick spans to MyST links. + + Fenced code blocks are preserved verbatim so doctest examples and Python + snippets don't get mangled. Within prose, each backtick code span is + looked up against ``index``; matches become ``[`Name`](#anchor)`` links, + and everything else is left unchanged.
+ """ + if not text: + return text + + lines = text.split("\n") + output: list[str] = [] + in_fence = False + fence_marker: str | None = None + + for line in lines: + stripped = line.lstrip() + if not in_fence and stripped.startswith(("```", "~~~")): + in_fence = True + fence_marker = stripped[:3] + output.append(line) + continue + if in_fence: + output.append(line) + if stripped.startswith(fence_marker or "```"): + in_fence = False + fence_marker = None + continue + + def _sub(match: re.Match[str]) -> str: + full = match.group(1) + symbol = match.group(2) or match.group(3) or "" + entry = _resolve_symbol(symbol, index, current_class) + if entry is None: + return full + return f"[{full}](#{entry.anchor})" + + output.append(_SYMBOL_REF_RE.sub(_sub, line)) + + return "\n".join(output) + + +def _rewrite_param_table(params: list[dict], index: dict[str, list[SymbolEntry]], current_class: str | None) -> None: + """Run the symbol rewriter over parameter descriptions in-place.""" + for p in params: + if p.get("desc"): + p["desc"] = _rewrite_symbol_refs(p["desc"], index, current_class=current_class) + + +def _format_bases(bases: list[str], symbol_index: dict[str, list[SymbolEntry]] | None) -> str: + """Render each base class as an individually-linkable code span. + + Each base is wrapped in single backticks and run through the symbol + rewriter separately so that known PyRIT bases become MyST cross-reference + links while external bases (e.g. ``str``, ``Enum``) stay as plain code + spans. The comma-joined output keeps the rendered ``Bases:`` line readable + even when only some bases resolve. + """ + if not bases: + return "" + if symbol_index is None: + return ", ".join(f"`{b}`" for b in bases if b) + return ", ".join(_rewrite_symbol_refs(f"`{b}`", symbol_index) for b in bases if b) + + +def _format_reexport_alias( + mod_name: str, + name: str, + symbol_index: dict[str, list[SymbolEntry]] | None, +) -> str: + """Render a re-export alias name as a MyST link when unambiguous. 
+ + Aliases usually live on the current module, so the module-qualified path + is tried first. If that lookup is unambiguous we link directly to it; + otherwise we fall back to the regular short-name rewriter so unresolvable + aliases get the same plain code-span treatment as the rest of the docs. + """ + if not name: + return "" + if symbol_index is None: + return f"`{name}`" + fqn = f"{mod_name}.{name}" if mod_name else name + entries = symbol_index.get(fqn) + if entries and len(entries) == 1: + return f"[`{name}`](#{entries[0].anchor})" + return _rewrite_symbol_refs(f"`{name}`", symbol_index) + + +def _format_reexport_target( + target: str, + symbol_index: dict[str, list[SymbolEntry]] | None, +) -> str: + """Render a re-export target FQN as a MyST link when it resolves.""" + if not target: + return "" + if symbol_index is None: + return f"`{target}`" + return _rewrite_symbol_refs(f"`{target}`", symbol_index) + + +def _rewrite_returns_or_raises( + items: list[dict], index: dict[str, list[SymbolEntry]], current_class: str | None +) -> None: + """Run the symbol rewriter over returns/raises description text in-place.""" + for item in items: + if item.get("desc"): + item["desc"] = _rewrite_symbol_refs(item["desc"], index, current_class=current_class) + + +def _process_docstring_text( + text: str | None, + symbol_index: dict[str, list[SymbolEntry]] | None, + current_class: str | None, +) -> str | None: + """Apply doctest-fence wrapping then symbol cross-reference rewriting.""" + if not text: + return text + escaped = _escape_docstring_examples(text) + if symbol_index is None: + return escaped + return _rewrite_symbol_refs(escaped, symbol_index, current_class=current_class) + + +def render_function( + func: dict, + *, + heading_level: str = "###", + module: str, + class_name: str | None = None, + symbol_index: dict[str, list[SymbolEntry]] | None = None, +) -> str: """Render a function as markdown.""" name = func["name"] is_async = func.get("is_async", False) @@ -131,54 
+398,88 @@ def render_function(func: dict, heading_level: str = "###") -> str: ret = func.get("returns_annotation", "") ret_str = f" → {ret}" if ret else "" - # Heading shows just the name; full signature in a code block below - parts = [f"{heading_level} `{prefix}{name}`\n"] + anchor = _method_anchor(module, class_name, name) if class_name else _function_anchor(module, name) + + # Anchor label precedes the heading so MyST cross-refs can target it. + parts = [f"({anchor})=", f"{heading_level} `{prefix}{name}`\n"] parts.append(f"```python\n{prefix}{name}{sig}{ret_str}\n```\n") ds = func.get("docstring", {}) if ds: - if ds.get("text"): - parts.append(_escape_docstring_examples(ds["text"]) + "\n") - params_table = render_params(ds.get("params", [])) + text = _process_docstring_text(ds.get("text"), symbol_index, current_class=class_name) + if text: + parts.append(text + "\n") + params = list(ds.get("params", [])) + if params and symbol_index is not None: + params = [dict(p) for p in params] + _rewrite_param_table(params, symbol_index, class_name) + params_table = render_params(params) if params_table: parts.append(params_table + "\n") - returns = render_returns(ds.get("returns", [])) - if returns: - parts.append(returns + "\n") - raises = render_raises(ds.get("raises", [])) - if raises: - parts.append(raises + "\n") + returns = list(ds.get("returns", [])) + if returns and symbol_index is not None: + returns = [dict(r) for r in returns] + _rewrite_returns_or_raises(returns, symbol_index, class_name) + returns_md = render_returns(returns) + if returns_md: + parts.append(returns_md + "\n") + raises = list(ds.get("raises", [])) + if raises and symbol_index is not None: + raises = [dict(r) for r in raises] + _rewrite_returns_or_raises(raises, symbol_index, class_name) + raises_md = render_raises(raises) + if raises_md: + parts.append(raises_md + "\n") return "\n".join(parts) -def render_class(cls: dict) -> str: +def render_class( + cls: dict, + *, + module: str, + 
symbol_index: dict[str, list[SymbolEntry]] | None = None, +) -> str: """Render a class as markdown.""" name = cls["name"] bases = cls.get("bases", []) - bases_str = f"({', '.join(bases)})" if bases else "" - parts = [f"## `{name}`\n"] - if bases_str: - parts.append(f"Bases: `{bases_str[1:-1]}`\n") + anchor = _class_anchor(module, name) + parts = [f"({anchor})=", f"## `{name}`\n"] + bases_md = _format_bases(bases, symbol_index) + if bases_md: + parts.append(f"Bases: {bases_md}\n") ds = cls.get("docstring", {}) - if ds and ds.get("text"): - parts.append(_escape_docstring_examples(ds["text"]) + "\n") + text = _process_docstring_text(ds.get("text") if ds else None, symbol_index, current_class=name) + if text: + parts.append(text + "\n") # __init__ init = cls.get("init") if init: init_ds = init.get("docstring", {}) if init_ds and init_ds.get("params"): + init_params = [dict(p) for p in init_ds["params"]] + if symbol_index is not None: + _rewrite_param_table(init_params, symbol_index, name) parts.append("**Constructor Parameters:**\n") - parts.append(render_params(init_ds["params"]) + "\n") + parts.append(render_params(init_params) + "\n") # Methods methods = cls.get("methods", []) if methods: parts.append("**Methods:**\n") - parts.extend(render_function(m, heading_level="####") for m in methods) + parts.extend( + render_function( + m, + heading_level="####", + module=module, + class_name=name, + symbol_index=symbol_index, + ) + for m in methods + ) return "\n".join(parts) @@ -193,20 +494,27 @@ def render_alias(alias: dict) -> str: return "\n".join(parts) -def render_module(data: dict) -> str: +def render_module( + data: dict, + *, + symbol_index: dict[str, list[SymbolEntry]] | None = None, +) -> str: """Render a full module page.""" mod_name = data["name"] short_name = mod_name.rsplit(".", 1)[-1] + mod_label = f"api-{_module_slug(mod_name)}" parts = [ "---", + f"label: {mod_label}", f"short_title: {short_name}", "---\n", f"# {mod_name}\n", ] ds = data.get("docstring", 
{}) - if ds and ds.get("text"): - parts.append(ds["text"] + "\n") + text = _process_docstring_text(ds.get("text") if ds else None, symbol_index, current_class=None) + if text: + parts.append(text + "\n") members = data.get("members", []) @@ -216,15 +524,19 @@ def render_module(data: dict) -> str: if functions: parts.append("## Functions\n") - parts.extend(render_function(f) for f in functions) + parts.extend(render_function(f, module=mod_name, symbol_index=symbol_index) for f in functions) - parts.extend(render_class(cls) for cls in classes) + parts.extend(render_class(cls, module=mod_name, symbol_index=symbol_index) for cls in classes) if aliases: parts.append("## Re-exports\n") for a in aliases: - target = a.get("target", "") - parts.append(f"- `{a['name']}` → `{target}`\n") + name_md = _format_reexport_alias(mod_name, a.get("name", ""), symbol_index) + target_md = _format_reexport_target(a.get("target", ""), symbol_index) + if target_md: + parts.append(f"- {name_md} → {target_md}\n") + else: + parts.append(f"- {name_md}\n") return "\n".join(parts) @@ -368,12 +680,16 @@ def main() -> None: _build_definition_index(data, definition_index, name_to_modules) _resolve_aliases(modules, definition_index, name_to_modules) + # Build a symbol index over the post-resolution module tree so the + # docstring rewriter can turn backticked names into MyST cross-references. 
+ symbol_index = _build_symbol_index(modules) + # Generate per-module pages for data in modules: mod_name = data["name"] slug = mod_name.replace(".", "_") md_path = API_MD_DIR / f"{slug}.md" - content = render_module(data) + content = render_module(data, symbol_index=symbol_index) members = data.get("members", []) rendered_count = sum(1 for m in members if m.get("kind") in ("class", "function")) md_path.write_text(content, encoding="utf-8") @@ -386,10 +702,18 @@ def main() -> None: members = data.get("members", []) slug = mod_name.replace(".", "_") - classes = [f"`{m['name']}`" for m in members if m.get("kind") == "class"] - functions = [f"`{m['name']}()`" for m in members if m.get("kind") == "function"] - rendered_count = len(classes) + len(functions) - preview_items = (classes + functions)[:8] + # Link each class/function in the preview directly to its anchor so the + # index page is a fast jumping-off point. + class_links = [ + f"[`{m['name']}`](#{_class_anchor(mod_name, m['name'])})" for m in members if m.get("kind") == "class" + ] + function_links = [ + f"[`{m['name']}()`](#{_function_anchor(mod_name, m['name'])})" + for m in members + if m.get("kind") == "function" + ] + rendered_count = len(class_links) + len(function_links) + preview_items = (class_links + function_links)[:8] preview = ", ".join(preview_items) if rendered_count > len(preview_items): preview += f" ... ({rendered_count} total)" diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index 791634d71e..e2652b3be6 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -292,7 +292,7 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: ``--max-turns 7``) don't cause an error before we've had a chance to fetch the scenario's declared parameters from the server. 
The unknown leftovers are stashed on the returned Namespace as ``_unknown_args`` - so :func:`_reparse_with_scenario_params` can detect truly unknown flags + so ``_reparse_with_scenario_params`` can detect truly unknown flags when no scenario was specified. Args: @@ -470,7 +470,7 @@ def _reparse_with_scenario_params( Re-parse the original args with scenario-declared flags added to the base parser. The original argument list is read from ``parsed_args._raw_args`` (populated - by :func:`parse_args`). If no scenario-declared parameters are supplied but + by ``parse_args``). If no scenario-declared parameters are supplied but pass 1 left unknown args behind, surface the error now via strict re-parse. Returns: diff --git a/pyrit/common/cli_helpers.py b/pyrit/common/cli_helpers.py index ec5d9674f9..527c3a5a11 100644 --- a/pyrit/common/cli_helpers.py +++ b/pyrit/common/cli_helpers.py @@ -45,10 +45,10 @@ def validate_log_level(*, log_level: str) -> int: def validate_log_level_argparse(value: Any) -> int: """ - Argparse-compatible wrapper around :func:`validate_log_level`. + Argparse-compatible wrapper around ``validate_log_level``. Adapts the keyword-only validator to argparse's positional ``type=`` calling - convention and converts ``ValueError`` to :class:`argparse.ArgumentTypeError`. + convention and converts ``ValueError`` to ``argparse.ArgumentTypeError``. Args: value: Log level string supplied by argparse. diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 55edd0b68a..eb7d021ab0 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -650,8 +650,8 @@ async def _score_response_async(self, *, response: Message, objective: str) -> N used by the TAP algorithm to decide which branches to explore further. 
Blocked or errored responses are scored via the scorer's unified default behavior: - :class:`~pyrit.score.true_false.true_false_scorer.TrueFalseScorer` returns - ``Score(False)`` and :class:`~pyrit.score.float_scale.float_scale_scorer.FloatScaleScorer` + ``TrueFalseScorer`` returns + ``Score(False)`` and ``FloatScaleScorer`` returns ``Score(0.0)`` whenever no supported pieces remain after validator filtering (the normal outcome for a blocked piece). This keeps blocked branches at the bottom of the priority queue without needing attack-level error mapping. @@ -1338,7 +1338,7 @@ def __init__( Note: Blocked or errored target responses (e.g. content filter triggers from image generation targets) are scored ``0.0`` via the unified - :class:`~pyrit.score.float_scale.float_scale_scorer.FloatScaleScorer` default, + ``FloatScaleScorer`` default, which prevents premature pruning without any attack-level error mapping. To score partial content from blocked responses, set ``score_blocked_content=True`` on the objective scorer (requires diff --git a/pyrit/models/conversation_reference.py b/pyrit/models/conversation_reference.py index 33d5e2d880..0915a045c4 100644 --- a/pyrit/models/conversation_reference.py +++ b/pyrit/models/conversation_reference.py @@ -57,7 +57,7 @@ def to_dict(self) -> dict[str, str | None]: Serialize to a JSON-compatible dictionary. .. deprecated:: - Use :meth:`model_dump` with ``mode="json"`` instead. This method + Use ``model_dump`` with ``mode="json"`` instead. This method will be removed in version 0.16.0. Returns: @@ -76,7 +76,7 @@ def from_dict(cls, data: dict[str, str | None]) -> ConversationReference: Reconstruct a ConversationReference from a dictionary. .. deprecated:: - Use :meth:`model_validate` instead. This method will be removed + Use ``model_validate`` instead. This method will be removed in version 0.16.0. 
Args: diff --git a/pyrit/models/retry_event.py b/pyrit/models/retry_event.py index 79bb2bbb6c..46a6e79fcf 100644 --- a/pyrit/models/retry_event.py +++ b/pyrit/models/retry_event.py @@ -38,7 +38,7 @@ def to_dict(self) -> dict: Serialize to a dictionary suitable for JSON storage. .. deprecated:: - Use :meth:`model_dump` with ``mode="json"`` instead. This method + Use ``model_dump`` with ``mode="json"`` instead. This method will be removed in version 0.16.0. Returns: @@ -57,7 +57,7 @@ def from_dict(cls, data: dict) -> RetryEvent: Deserialize from a dictionary. .. deprecated:: - Use :meth:`model_validate` instead. This method will be removed + Use ``model_validate`` instead. This method will be removed in version 0.16.0. Args: diff --git a/pyrit/score/scorer_evaluation/scorer_metrics.py b/pyrit/score/scorer_evaluation/scorer_metrics.py index d44392cfd7..49e5252400 100644 --- a/pyrit/score/scorer_evaluation/scorer_metrics.py +++ b/pyrit/score/scorer_evaluation/scorer_metrics.py @@ -27,8 +27,8 @@ class ScorerMetrics: Base dataclass for storing scorer evaluation metrics. This class provides methods for serializing metrics to JSON strings (see - :meth:`to_json`) and loading them from JSON files on disk (see - :meth:`from_json_file`). + ``to_json``) and loading them from JSON files on disk (see + ``from_json_file``). Args: num_responses (int): Total number of responses evaluated. @@ -53,7 +53,7 @@ def to_json(self) -> str: Serialize this metrics instance to a JSON string. This is the canonical serialization entry point for ``ScorerMetrics`` and its - subclasses. Pair it with :meth:`from_json_file` (which reads a JSON file written + subclasses. Pair it with ``from_json_file`` (which reads a JSON file written from this string, optionally wrapped in a ``"metrics"`` key) for round-trip (de)serialization. 
@@ -98,10 +98,10 @@ def from_json_file(cls: type[T], file_path: Union[str, Path]) -> T: @classmethod def from_json(cls: type[T], file_path: Union[str, Path]) -> T: """ - Load a metrics instance from a JSON file (deprecated alias for :meth:`from_json_file`). + Load a metrics instance from a JSON file (deprecated alias for ``from_json_file``). The name ``from_json`` is misleading because it accepts a *file path*, not a JSON - string. Use :meth:`from_json_file` instead. + string. Use ``from_json_file`` instead. Args: file_path (Union[str, Path]): The path to the JSON file. diff --git a/tests/unit/build_scripts/test_check_no_rest_roles.py b/tests/unit/build_scripts/test_check_no_rest_roles.py new file mode 100644 index 0000000000..bbe413977d --- /dev/null +++ b/tests/unit/build_scripts/test_check_no_rest_roles.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pathlib import Path + +from build_scripts.check_no_rest_roles import main + + +def _write(tmp_path: Path, name: str, content: str) -> Path: + p = tmp_path / name + p.write_text(content, encoding="utf-8") + return p + + +def test_main_passes_on_clean_file(tmp_path: Path) -> None: + clean = _write( + tmp_path, + "clean.py", + '"""Returns a ``SeedPrompt`` instance and calls ``download_files_async``."""\n', + ) + assert main([str(clean)]) == 0 + + +def test_main_passes_when_no_python_files(tmp_path: Path) -> None: + md = _write(tmp_path, "notes.md", ":class:`Foo` is fine in markdown\n") + # Non-Python paths are skipped without inspection. 
+ assert main([str(md)]) == 0 + + +def test_main_flags_class_role(tmp_path: Path, capsys) -> None: + bad = _write(tmp_path, "bad.py", '"""Returns a :class:`SeedPrompt` instance."""\n') + rc = main([str(bad)]) + assert rc == 1 + err = capsys.readouterr().out + assert "bad.py:1" in err + assert ":class:`SeedPrompt`" in err + + +def test_main_flags_func_meth_and_py_prefixed_roles(tmp_path: Path) -> None: + bad = _write( + tmp_path, + "bad.py", + '"""\nSee :func:`do_thing` and :meth:`Foo.bar` and :py:class:`X`.\n"""\n', + ) + assert main([str(bad)]) == 1 + + +def test_main_ignores_bare_colon_in_code(tmp_path: Path) -> None: + # Plain "key: value" colons (e.g. a Google docstring section header) should not match. + clean = _write( + tmp_path, + "clean.py", + '"""\nArgs:\n foo (int): the foo value.\n"""\n', + ) + assert main([str(clean)]) == 0 + + +def test_main_returns_zero_when_called_without_args() -> None: + assert main([]) == 0 diff --git a/tests/unit/build_scripts/test_gen_api_md.py b/tests/unit/build_scripts/test_gen_api_md.py new file mode 100644 index 0000000000..db4472d300 --- /dev/null +++ b/tests/unit/build_scripts/test_gen_api_md.py @@ -0,0 +1,543 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from build_scripts.gen_api_md import ( + SymbolEntry, + _build_symbol_index, + _class_anchor, + _format_bases, + _format_reexport_alias, + _format_reexport_target, + _function_anchor, + _method_anchor, + _process_docstring_text, + _rewrite_symbol_refs, + render_class, + render_function, + render_module, +) + + +def _fake_class(name: str, methods: list[str] | None = None) -> dict: + return { + "name": name, + "kind": "class", + "methods": [{"name": m, "kind": "function"} for m in (methods or [])], + } + + +def _fake_function(name: str) -> dict: + return {"name": name, "kind": "function"} + + +def _fake_module(name: str, members: list[dict]) -> dict: + return {"name": name, "kind": "module", "members": members} + + +def test_anchor_helpers_produce_unique_labels() -> None: + assert _class_anchor("pyrit.prompt_target", "PromptTarget") == "api-pyrit_prompt_target-PromptTarget" + assert _function_anchor("pyrit.common", "validate_log_level") == "api-pyrit_common-validate_log_level" + assert ( + _method_anchor("pyrit.prompt_target", "PromptTarget", "send_prompt_async") + == "api-pyrit_prompt_target-PromptTarget-send_prompt_async" + ) + + +def test_build_symbol_index_registers_classes_functions_and_methods() -> None: + modules = [ + _fake_module( + "pyrit.prompt_target", + [ + _fake_class("PromptTarget", methods=["send_prompt_async", "apply_capabilities"]), + _fake_function("limit_requests_per_minute"), + ], + ), + ] + index = _build_symbol_index(modules) + + # Short-name lookup + assert len(index["PromptTarget"]) == 1 + assert index["PromptTarget"][0].kind == "class" + assert index["PromptTarget"][0].anchor == "api-pyrit_prompt_target-PromptTarget" + + # Class.method lookup + assert len(index["PromptTarget.send_prompt_async"]) == 1 + assert index["PromptTarget.send_prompt_async"][0].anchor == "api-pyrit_prompt_target-PromptTarget-send_prompt_async" + + # FQN lookup + assert index["pyrit.prompt_target.PromptTarget"][0].kind == "class" + assert 
index["pyrit.prompt_target.limit_requests_per_minute"][0].kind == "function" + + +def test_build_symbol_index_skips_private_members() -> None: + modules = [ + _fake_module( + "pyrit.example", + [ + _fake_class("Public", methods=["do_thing", "_internal_helper"]), + _fake_function("_private_func"), + ], + ), + ] + index = _build_symbol_index(modules) + + assert "_internal_helper" not in index + assert "Public._internal_helper" not in index + assert "_private_func" not in index + assert "do_thing" in index + + +def test_build_symbol_index_marks_duplicates_as_ambiguous() -> None: + modules = [ + _fake_module("pyrit.first", [_fake_class("Scorer")]), + _fake_module("pyrit.second", [_fake_class("Scorer")]), + ] + index = _build_symbol_index(modules) + + assert len(index["Scorer"]) == 2 + # FQN entries stay distinct + assert len(index["pyrit.first.Scorer"]) == 1 + assert len(index["pyrit.second.Scorer"]) == 1 + + +def test_rewrite_symbol_refs_links_unique_class() -> None: + index = { + "SeedPrompt": [ + SymbolEntry( + module="pyrit.models", + kind="class", + name="SeedPrompt", + qualname="SeedPrompt", + anchor="api-pyrit_models-SeedPrompt", + ) + ] + } + out = _rewrite_symbol_refs("Returns a ``SeedPrompt`` instance.", index) + assert out == "Returns a [``SeedPrompt``](#api-pyrit_models-SeedPrompt) instance." + + +def test_rewrite_symbol_refs_handles_single_backticks() -> None: + index = {"Foo": [SymbolEntry(module="pyrit.x", kind="class", name="Foo", qualname="Foo", anchor="api-pyrit_x-Foo")]} + out = _rewrite_symbol_refs("See `Foo` for details.", index) + assert out == "See [`Foo`](#api-pyrit_x-Foo) for details." 
+ + +def test_rewrite_symbol_refs_resolves_class_dot_method() -> None: + index = { + "PromptTarget.send_prompt_async": [ + SymbolEntry( + module="pyrit.prompt_target", + kind="method", + name="send_prompt_async", + qualname="PromptTarget.send_prompt_async", + anchor="api-pyrit_prompt_target-PromptTarget-send_prompt_async", + ) + ] + } + out = _rewrite_symbol_refs("Call ``PromptTarget.send_prompt_async`` to dispatch.", index) + assert "[``PromptTarget.send_prompt_async``]" in out + assert "#api-pyrit_prompt_target-PromptTarget-send_prompt_async" in out + + +def test_rewrite_symbol_refs_resolves_bare_method_with_current_class() -> None: + index = { + "PromptTarget.send_prompt_async": [ + SymbolEntry( + module="pyrit.prompt_target", + kind="method", + name="send_prompt_async", + qualname="PromptTarget.send_prompt_async", + anchor="api-pyrit_prompt_target-PromptTarget-send_prompt_async", + ) + ], + "send_prompt_async": [ + SymbolEntry( + module="pyrit.prompt_target", + kind="method", + name="send_prompt_async", + qualname="PromptTarget.send_prompt_async", + anchor="api-pyrit_prompt_target-PromptTarget-send_prompt_async", + ) + ], + } + out = _rewrite_symbol_refs("Then ``send_prompt_async`` is invoked.", index, current_class="PromptTarget") + assert "[``send_prompt_async``]" in out + + +def test_rewrite_symbol_refs_skips_ambiguous_names() -> None: + entry_a = SymbolEntry(module="pyrit.a", kind="class", name="Scorer", qualname="Scorer", anchor="api-pyrit_a-Scorer") + entry_b = SymbolEntry(module="pyrit.b", kind="class", name="Scorer", qualname="Scorer", anchor="api-pyrit_b-Scorer") + index = {"Scorer": [entry_a, entry_b]} + out = _rewrite_symbol_refs("Use ``Scorer``.", index) + assert out == "Use ``Scorer``." + + +def test_rewrite_symbol_refs_leaves_unknown_names_alone() -> None: + out = _rewrite_symbol_refs("This is ``True`` and ``None``.", {}) + assert out == "This is ``True`` and ``None``." 
+ + +def test_rewrite_symbol_refs_resolves_fully_qualified_name() -> None: + entry = SymbolEntry( + module="pyrit.models", + kind="class", + name="SeedPrompt", + qualname="SeedPrompt", + anchor="api-pyrit_models-SeedPrompt", + ) + index = {"SeedPrompt": [entry], "pyrit.models.SeedPrompt": [entry]} + out = _rewrite_symbol_refs("Use ``pyrit.models.SeedPrompt`` here.", index) + assert "[``pyrit.models.SeedPrompt``](#api-pyrit_models-SeedPrompt)" in out + + +def test_rewrite_symbol_refs_preserves_fenced_code_blocks() -> None: + index = { + "SeedPrompt": [ + SymbolEntry( + module="pyrit.models", + kind="class", + name="SeedPrompt", + qualname="SeedPrompt", + anchor="api-pyrit_models-SeedPrompt", + ) + ] + } + text = ( + "Outside: ``SeedPrompt``.\n" + "```python\n" + "x = SeedPrompt()\n" + "# ``SeedPrompt`` should not be linked here\n" + "```\n" + "After: ``SeedPrompt``." + ) + out = _rewrite_symbol_refs(text, index) + assert "[``SeedPrompt``](#api-pyrit_models-SeedPrompt)" in out.split("```")[0] + assert "# ``SeedPrompt`` should not be linked here" in out + # The closing "After" sentence should also be rewritten + assert out.endswith("After: [``SeedPrompt``](#api-pyrit_models-SeedPrompt).") + + +def test_rewrite_symbol_refs_skips_existing_links() -> None: + index = {"Foo": [SymbolEntry(module="pyrit.x", kind="class", name="Foo", qualname="Foo", anchor="api-pyrit_x-Foo")]} + text = "Already-linked: [``Foo``](#api-pyrit_x-Foo)." 
+ out = _rewrite_symbol_refs(text, index) + # No double-wrap + assert out == text + + +def test_rewrite_symbol_refs_handles_tilde_and_dotted_prefix() -> None: + entry = SymbolEntry( + module="pyrit.models", + kind="class", + name="SeedPrompt", + qualname="SeedPrompt", + anchor="api-pyrit_models-SeedPrompt", + ) + index = {"pyrit.models.SeedPrompt": [entry]} + out = _rewrite_symbol_refs("Tilde form ``~pyrit.models.SeedPrompt`` works.", index) + assert "(#api-pyrit_models-SeedPrompt)" in out + + +def test_rewrite_symbol_refs_empty_string_passthrough() -> None: + assert _rewrite_symbol_refs("", {}) == "" + assert _rewrite_symbol_refs(None, {}) is None # type: ignore[arg-type] + + +def test_process_docstring_text_protects_doctest_examples() -> None: + """The escape-then-rewrite order must wrap ``>>>`` blocks in fences + *before* the symbol rewriter runs, so a known PyRIT symbol that happens + to appear inside a doctest example stays as raw text instead of being + turned into a MyST link (which would break the code sample).""" + index = { + "SeedPrompt": [ + SymbolEntry( + module="pyrit.models", + kind="class", + name="SeedPrompt", + qualname="SeedPrompt", + anchor="api-pyrit_models-SeedPrompt", + ) + ] + } + text = ( + "Returns a ``SeedPrompt`` instance.\n" + "\n" + "Example:\n" + " >>> sp = SeedPrompt(value='hi')\n" + " >>> assert isinstance(sp, SeedPrompt)\n" + " >>> print(sp)\n" + "After the example, ``SeedPrompt`` is linkable again." + ) + out = _process_docstring_text(text, index, current_class=None) + assert out is not None + # Prose before the doctest is linked. + assert "[``SeedPrompt``](#api-pyrit_models-SeedPrompt) instance." in out + # Doctest contents are fenced and NOT turned into MyST links. + assert "```python" in out + assert ">>> sp = SeedPrompt(value='hi')" in out + assert "[SeedPrompt]" not in out # bare-word inside doctest stays bare + # Prose after the doctest is linked again. 
+ assert out.endswith("After the example, [``SeedPrompt``](#api-pyrit_models-SeedPrompt) is linkable again.") + + +def test_render_function_emits_anchor_and_links_docstring_fields() -> None: + """End-to-end render path: a function with a linkable name in its + description, parameter description, returns description, and raises + description should produce a unique anchor label and MyST links + everywhere the symbol appears.""" + index = { + "PromptTarget": [ + SymbolEntry( + module="pyrit.prompt_target", + kind="class", + name="PromptTarget", + qualname="PromptTarget", + anchor="api-pyrit_prompt_target-PromptTarget", + ) + ] + } + func = { + "name": "build_target", + "kind": "function", + "is_async": False, + "signature": [{"name": "name", "type": "str", "kind": "positional or keyword"}], + "returns_annotation": "PromptTarget", + "docstring": { + "text": "Construct a ``PromptTarget`` from a name.", + "params": [ + {"name": "name", "type": "str", "desc": "Identifier for the ``PromptTarget``."}, + ], + "returns": [{"type": "PromptTarget", "desc": "The constructed ``PromptTarget``."}], + "raises": [{"type": "ValueError", "desc": "If no ``PromptTarget`` matches the name."}], + }, + } + out = render_function(func, module="pyrit.factories", symbol_index=index) + + # Anchor label is emitted for the function heading. + assert "(api-pyrit_factories-build_target)=" in out + # The function name still appears in the heading. + assert "### `build_target`" in out + # Every docstring field has been rewritten to link to the known symbol. 
+ expected_link = "[``PromptTarget``](#api-pyrit_prompt_target-PromptTarget)" + assert out.count(expected_link) == 4 + + +def test_render_function_uses_method_anchor_when_class_name_given() -> None: + """Methods get a class-scoped anchor and the current_class context lets + the rewriter resolve bare same-class method references.""" + index = { + "PromptTarget.send_prompt_async": [ + SymbolEntry( + module="pyrit.prompt_target", + kind="method", + name="send_prompt_async", + qualname="PromptTarget.send_prompt_async", + anchor="api-pyrit_prompt_target-PromptTarget-send_prompt_async", + ) + ] + } + method = { + "name": "validate", + "kind": "function", + "signature": [], + "docstring": {"text": "Then ``send_prompt_async`` is invoked by the runtime."}, + } + out = render_function( + method, + heading_level="####", + module="pyrit.prompt_target", + class_name="PromptTarget", + symbol_index=index, + ) + + assert "(api-pyrit_prompt_target-PromptTarget-validate)=" in out + assert "#### `validate`" in out + assert "[``send_prompt_async``](#api-pyrit_prompt_target-PromptTarget-send_prompt_async)" in out + + +def _prompt_target_entry() -> SymbolEntry: + return SymbolEntry( + module="pyrit.prompt_target", + kind="class", + name="PromptTarget", + qualname="PromptTarget", + anchor="api-pyrit_prompt_target-PromptTarget", + ) + + +def test_format_bases_links_known_pyrit_base() -> None: + index = {"PromptTarget": [_prompt_target_entry()]} + out = _format_bases(["PromptTarget"], index) + assert out == "[`PromptTarget`](#api-pyrit_prompt_target-PromptTarget)" + + +def test_format_bases_keeps_external_base_as_plain_code_span() -> None: + """Bases not in the symbol index (stdlib types like ``str``/``Enum``) stay + as plain backtick code spans instead of being mangled into broken links.""" + out = _format_bases(["str", "Enum"], {}) + assert out == "`str`, `Enum`" + + +def test_format_bases_links_mixed_pyrit_and_external() -> None: + """A mix of resolvable and external bases produces a 
clean + comma-separated list with only the known names linked.""" + index = {"PromptTarget": [_prompt_target_entry()]} + out = _format_bases(["PromptTarget", "ABC", "Identifiable"], index) + assert out == "[`PromptTarget`](#api-pyrit_prompt_target-PromptTarget), `ABC`, `Identifiable`" + + +def test_format_bases_empty_or_none_returns_empty_string() -> None: + assert _format_bases([], {}) == "" + # Without a symbol index we still emit plain code spans. + assert _format_bases(["str"], None) == "`str`" + + +def test_render_class_emits_linked_bases_line() -> None: + """End-to-end: a class with a known PyRIT base renders the ``Bases:`` line + as a MyST link rather than a plain code span.""" + index = {"PromptTarget": [_prompt_target_entry()]} + cls = {"name": "MyTarget", "kind": "class", "bases": ["PromptTarget", "str"]} + out = render_class(cls, module="pyrit.factories", symbol_index=index) + + assert "(api-pyrit_factories-MyTarget)=" in out + assert "Bases: [`PromptTarget`](#api-pyrit_prompt_target-PromptTarget), `str`" in out + # No accidental wrapper backticks around the whole comma-joined list. 
+ assert "Bases: `PromptTarget" not in out + + +def test_render_class_without_bases_omits_bases_line() -> None: + cls = {"name": "Standalone", "kind": "class", "bases": []} + out = render_class(cls, module="pyrit.misc", symbol_index={}) + assert "Bases:" not in out + + +def test_format_reexport_alias_prefers_module_qualified_lookup() -> None: + """The alias usually lives on the re-exporting module, so the FQN form + (``mod_name.alias_name``) is tried before the short name.""" + canonical = SymbolEntry( + module="pyrit.models", + kind="class", + name="SeedPrompt", + qualname="SeedPrompt", + anchor="api-pyrit_models-SeedPrompt", + ) + re_exported = SymbolEntry( + module="pyrit", + kind="class", + name="SeedPrompt", + qualname="SeedPrompt", + anchor="api-pyrit-SeedPrompt", + ) + index = { + "SeedPrompt": [canonical, re_exported], + "pyrit.SeedPrompt": [re_exported], + "pyrit.models.SeedPrompt": [canonical], + } + out = _format_reexport_alias("pyrit", "SeedPrompt", index) + # Picks the alias's own page rather than the canonical definition page. 
+ assert out == "[`SeedPrompt`](#api-pyrit-SeedPrompt)" + + +def test_format_reexport_alias_falls_back_to_short_name() -> None: + """When no module-qualified entry exists, the short-name rewriter still + links unambiguous names so the re-export remains navigable.""" + entry = _prompt_target_entry() + index = {"PromptTarget": [entry], "pyrit.prompt_target.PromptTarget": [entry]} + out = _format_reexport_alias("pyrit", "PromptTarget", index) + assert out == "[`PromptTarget`](#api-pyrit_prompt_target-PromptTarget)" + + +def test_format_reexport_alias_leaves_unresolvable_name_plain() -> None: + out = _format_reexport_alias("pyrit.misc", "Mystery", {}) + assert out == "`Mystery`" + + +def test_format_reexport_target_links_fqn_when_indexed() -> None: + entry = _prompt_target_entry() + index = {"pyrit.prompt_target.PromptTarget": [entry]} + out = _format_reexport_target("pyrit.prompt_target.PromptTarget", index) + assert out == "[`pyrit.prompt_target.PromptTarget`](#api-pyrit_prompt_target-PromptTarget)" + + +def test_format_reexport_target_leaves_unresolvable_target_plain() -> None: + out = _format_reexport_target("pyrit.unknown.Symbol", {}) + assert out == "`pyrit.unknown.Symbol`" + + +def test_format_reexport_target_empty_returns_empty() -> None: + assert _format_reexport_target("", {}) == "" + + +def test_render_module_links_both_reexport_sides() -> None: + """End-to-end: a module with an alias whose FQN target is in the index + renders both the alias name and the target as MyST links.""" + canonical = _prompt_target_entry() + re_exported = SymbolEntry( + module="pyrit", + kind="class", + name="PromptTarget", + qualname="PromptTarget", + anchor="api-pyrit-PromptTarget", + ) + index = { + "PromptTarget": [canonical, re_exported], + "pyrit.PromptTarget": [re_exported], + "pyrit.prompt_target.PromptTarget": [canonical], + } + module = _fake_module( + "pyrit", + [{"name": "PromptTarget", "kind": "alias", "target": "pyrit.prompt_target.PromptTarget"}], + ) + out = 
render_module(module, symbol_index=index) + + assert "## Re-exports" in out + assert "[`PromptTarget`](#api-pyrit-PromptTarget)" in out + assert "[`pyrit.prompt_target.PromptTarget`](#api-pyrit_prompt_target-PromptTarget)" in out + assert " → " in out + + +def test_render_module_leaves_unresolvable_reexport_target_plain() -> None: + """When a re-export target points outside the index (e.g. a fake/external + path), it stays as a plain code span instead of becoming a broken link.""" + canonical = _prompt_target_entry() + index = { + "PromptTarget": [canonical], + "pyrit.prompt_target.PromptTarget": [canonical], + } + module = _fake_module( + "pyrit", + [{"name": "Mystery", "kind": "alias", "target": "pyrit.unknown.Mystery"}], + ) + out = render_module(module, symbol_index=index) + + assert "- `Mystery` → `pyrit.unknown.Mystery`" in out + # Plain code spans, not links. + assert "(#" not in out.split("## Re-exports")[1] + + +def test_render_module_emits_module_level_anchor_in_frontmatter() -> None: + """The page-level label is emitted as a frontmatter ``label:`` field so + cross-page references like ``[](#api-pyrit_prompt_target)`` target the + page itself. MyST consumes the H1 as the page title and discards any + label placed in the body before it, so frontmatter is the only reliable + place to bind a page-level anchor.""" + module = _fake_module("pyrit.prompt_target", [_fake_class("PromptTarget")]) + out = render_module(module, symbol_index={}) + + assert "label: api-pyrit_prompt_target" in out + # Heading still present. + assert "# pyrit.prompt_target" in out + # Frontmatter still wraps the page. + assert out.startswith("---") + assert "short_title: prompt_target" in out + # Label is inside the frontmatter, not after it. 
+ fm_end = out.index("---\n", 4) # skip the opening "---" + assert out.index("label: api-pyrit_prompt_target") < fm_end + + +def test_render_module_label_uses_module_slug_for_nested_packages() -> None: + module = _fake_module("pyrit.executor.attack", [_fake_class("AttackStrategy")]) + out = render_module(module, symbol_index={}) + + assert "label: api-pyrit_executor_attack" in out + assert "# pyrit.executor.attack" in out