Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ test = [
]

[project.scripts]
flash = "runpod_flash.cli.main:app"
flash = "runpod_flash.cli.entrypoint:main"

[build-system]
requires = ["setuptools>=42", "wheel"]
Expand Down
36 changes: 36 additions & 0 deletions src/runpod_flash/cli/entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Thin CLI entrypoint that catches corrupted credentials at import time.

The runpod package reads ~/.runpod/config.toml at import time (in its
__init__.py). If that file contains invalid TOML, the import raises a
TOMLDecodeError before any Flash error handling can run. This wrapper
catches that and surfaces a clean error message.
"""

import sys


def main() -> None:
"""Entry point for the ``flash`` console script."""
try:
from runpod_flash.cli.main import app
except ValueError as exc:
# TOML decode errors from toml/tomli/tomllib are ValueError subclasses.
# The runpod package calls a TOML loader at import time; a corrupted
# ~/.runpod/config.toml triggers this before Flash code executes.
exc_type = type(exc)
exc_module = getattr(exc_type, "__module__", "").lower()
is_toml_decode_error = exc_type.__name__ == "TOMLDecodeError" and (
exc_module.startswith("toml")
or exc_module.startswith("tomli")
or exc_module.startswith("tomllib")
)
if is_toml_decode_error:
print(
"Error: ~/.runpod/config.toml is corrupted and cannot be parsed.\n"
"Run 'flash login' to re-authenticate, or delete the file and retry.",
file=sys.stderr,
)
raise SystemExit(1) from None
raise

app()
77 changes: 77 additions & 0 deletions tests/unit/cli/test_entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Tests for the CLI entrypoint wrapper that catches corrupted credentials."""

import sys
from unittest.mock import MagicMock, patch

import pytest

from runpod_flash.cli.entrypoint import main


class TestEntrypoint:
"""Tests for runpod_flash.cli.entrypoint.main."""

def test_normal_import_runs_app(self):
"""When import succeeds, the Typer app is invoked."""
mock_app = MagicMock()
mock_module = MagicMock()
mock_module.app = mock_app

with patch.dict(sys.modules, {"runpod_flash.cli.main": mock_module}):
main()

mock_app.assert_called_once()

def test_corrupted_toml_shows_clean_error(self, capsys):
"""Import-time TOMLDecodeError surfaces a clean message, not a traceback."""
# Create a ValueError whose class looks like a TOML decode error.
# tomli.TOMLDecodeError is a ValueError subclass with module "tomli._parser".
toml_exc_cls = type(
"TOMLDecodeError", (ValueError,), {"__module__": "tomli._parser"}
)
toml_error = toml_exc_cls("Invalid value at line 1 col 9")

# Remove the module from cache so the import inside main() re-executes
saved = sys.modules.pop("runpod_flash.cli.main", None)
try:
with patch.dict(sys.modules, {"runpod_flash.cli.main": None}):
# Patch __import__ to raise when the entrypoint tries to import main
real_import = __import__

def patched_import(name, *args, **kwargs):
if name == "runpod_flash.cli.main":
raise toml_error
return real_import(name, *args, **kwargs)

with patch("builtins.__import__", side_effect=patched_import):
with pytest.raises(SystemExit) as exc_info:
main()

assert exc_info.value.code == 1
captured = capsys.readouterr()
assert "corrupted" in captured.err
assert "flash login" in captured.err
finally:
if saved is not None:
sys.modules["runpod_flash.cli.main"] = saved

def test_non_toml_value_error_propagates(self):
"""A ValueError unrelated to TOML is not caught."""
saved = sys.modules.pop("runpod_flash.cli.main", None)
try:
with patch.dict(sys.modules, {"runpod_flash.cli.main": None}):
real_import = __import__

def patched_import(name, *args, **kwargs):
if name == "runpod_flash.cli.main":
raise ValueError("something completely different")
return real_import(name, *args, **kwargs)

with patch("builtins.__import__", side_effect=patched_import):
with pytest.raises(
ValueError, match="something completely different"
):
main()
finally:
if saved is not None:
sys.modules["runpod_flash.cli.main"] = saved
Loading