Skip to content
This repository was archived by the owner on Oct 1, 2024. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 29 additions & 49 deletions tests/test_rewrite.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import ast

from tests.utilities import testable_test
from tests.utilities import testable_test, failing_assertion
from ward import fixture, test
from ward._rewrite import (
RewriteAssert,
get_assertion_msg,
is_binary_comparison,
is_comparison_type,
make_call_node,
rewrite_assertions_in_tests,
)
from ward.expect import TestAssertionFailure, raises
from ward.testing import Test, each


Expand All @@ -34,37 +34,6 @@ def as_dict(node):
return node


@testable_test
def passing_fn():
assert 1 == 1


@testable_test
def failing_fn():
assert 1 == 2


@fixture
def passing():
yield Test(fn=passing_fn, module_name="m", id="id-pass")


@fixture
def failing():
yield Test(fn=failing_fn, module_name="m", id="id-fail")


@test("rewrite_assertions_in_tests returns all tests, keeping metadata")
def _(p=passing, f=failing):
in_tests = [p, f]
out_tests = rewrite_assertions_in_tests(in_tests)

def meta(test):
return test.description, test.id, test.module_name, test.fn.ward_meta

assert [meta(test) for test in in_tests] == [meta(test) for test in out_tests]


@test("RewriteAssert.visit_Assert doesn't transform `{src}`")
def _(
src=each(
Expand Down Expand Up @@ -121,6 +90,33 @@ def _(
assert out_tree.value.args[1].id == "y"
assert out_tree.value.args[2].s == ""

@test("This test suite's assertions are themselves rewritten")
def _():
with raises(TestAssertionFailure):
assert 1 == 2
with raises(TestAssertionFailure):
assert 1 != 1
with raises(TestAssertionFailure):
assert 1 in ()
with raises(TestAssertionFailure):
assert 1 not in (1,)
with raises(TestAssertionFailure):
assert None is Ellipsis
with raises(TestAssertionFailure):
assert None is not None
with raises(TestAssertionFailure):
assert 2 < 1
with raises(TestAssertionFailure):
assert 2 <= 1
with raises(TestAssertionFailure):
assert 1 > 2
with raises(TestAssertionFailure):
assert 1 >= 2

@test("Non-test modules' assertions aren't rewritten")
def _():
with raises(AssertionError):
failing_assertion()

@test("RewriteAssert.visit_Assert transforms `{src}`")
def _(src="assert 1 == 2, 'msg'"):
Expand Down Expand Up @@ -210,19 +206,3 @@ def _():
@test("test with indentation level of 2")
def _():
assert 2 + 3 == 5


@test("rewriter finds correct function when there is a lambda in an each")
def _():
@testable_test
def _(x=each(lambda: 5)):
assert x == 5

t = Test(fn=_, module_name="m")

rewritten = rewrite_assertions_in_tests([t])[0]

# https://github.com/darrenburns/ward/issues/169
# The assertion rewriter thought the lambda function stored in co_consts was the test function,
# so it was rebuilding the test function using the lambda as the test instead of the original function.
assert rewritten.fn.__code__.co_name != "<lambda>"
4 changes: 4 additions & 0 deletions tests/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def testable_test(func):
testable_test.path = FORCE_TEST_PATH # type: ignore[attr-defined]


def failing_assertion():
assert 1 == 2


@fixture
def dummy_fixture():
"""
Expand Down
3 changes: 2 additions & 1 deletion ward/_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from cucumber_tag_expressions.model import Expression

from ward._errors import CollectionError
from ward._rewrite import exec_module
from ward._testing import COLLECTED_TESTS, is_test_module_name
from ward._utilities import get_absolute_path
from ward.fixtures import Fixture
Expand Down Expand Up @@ -149,7 +150,7 @@ def load_modules(modules: Iterable[pkgutil.ModuleInfo]) -> List[ModuleType]:
if pkg_data.pkg_root not in sys.path:
sys.path.append(str(pkg_data.pkg_root))
m.__package__ = pkg_data.pkg_name
m.__loader__.exec_module(m)
exec_module(m)
loaded_modules.append(m)

return loaded_modules
Expand Down
61 changes: 8 additions & 53 deletions ward/_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import textwrap
import types
from pathlib import Path
from typing import Iterable, List

from ward.expect import (
Expand Down Expand Up @@ -87,57 +88,11 @@ def visit_Assert(self, node): # noqa: C901 - no chance to reduce complexity
return node


def rewrite_assertions_in_tests(tests: Iterable[Test]) -> List[Test]:
return [rewrite_assertion(test) for test in tests]


def rewrite_assertion(test: Test) -> Test:
# Get the old code and code object
code_lines, line_no = inspect.getsourcelines(test.fn)

code = "".join(code_lines)
indents = textwrap._leading_whitespace_re.findall(code)
col_offset = len(indents[0]) if len(indents) > 0 else 0
code = textwrap.dedent(code)
code_obj = test.fn.__code__

# Rewrite the AST of the code
tree = ast.parse(code)
ast.increment_lineno(tree, line_no - 1)

def exec_module(module: types.ModuleType):
filename = module.__spec__.origin
code = module.__loader__.get_source(module.__name__)
tree = ast.parse(code, filename=filename)
new_tree = RewriteAssert().visit(tree)

if sys.version_info[:2] < (3, 11):
# We dedented the code so that it was a valid tree, now re-apply the indent
for child in ast.walk(new_tree):
if hasattr(child, "col_offset"):
child.col_offset = getattr(child, "col_offset", 0) + col_offset

# Reconstruct the test function
new_mod_code_obj = compile(new_tree, code_obj.co_filename, "exec")

# TODO: This probably isn't correct for nested closures
clo_glob = {}
if test.fn.__closure__:
clo_glob = test.fn.__closure__[0].cell_contents.__globals__

# Look through the new module,
# find the code object with the same name as the original code object,
# and build a new function with the injected assert functions added to the global namespace.
# Filtering on the code object name prevents finding other kinds of code objects,
# like lambdas stored directly in test function arguments.
for const in new_mod_code_obj.co_consts:
if isinstance(const, types.CodeType) and const.co_name == code_obj.co_name:
new_test_func = types.FunctionType(
const,
{**assert_func_namespace, **test.fn.__globals__, **clo_glob},
test.fn.__name__,
test.fn.__defaults__,
)
new_test_func.ward_meta = test.fn.ward_meta
return Test(
**{k: vars(test)[k] for k in vars(test) if k != "fn"},
fn=new_test_func,
)

return test
code = compile(new_tree, filename, "exec", dont_inherit=True)
module.__dict__.update(assert_func_namespace)
exec(code, module.__dict__)
5 changes: 1 addition & 4 deletions ward/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
)
from ward._config import set_defaults_from_config
from ward._debug import init_breakpointhooks
from ward._rewrite import rewrite_assertions_in_tests
from ward._suite import Suite
from ward._terminal import (
SessionPrelude,
Expand Down Expand Up @@ -204,11 +203,9 @@ def test(
if config.order == "random":
shuffle(filtered_tests)

tests = rewrite_assertions_in_tests(filtered_tests)

time_to_collect_secs = default_timer() - start_run

suite = Suite(tests=tests)
suite = Suite(tests=filtered_tests)
test_results = suite.generate_test_runs(
dry_run=dry_run, capture_output=capture_output
)
Expand Down