diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index a72a7b70f7..22ed82ac8b 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -6,7 +6,6 @@ import numpy as np import torch -from tensorrt import ITensor as TRTTensor from torch.fx.node import Argument, Node, Target from torch_tensorrt import ENABLED_FEATURES from torch_tensorrt._features import needs_not_tensorrt_rtx @@ -28,6 +27,8 @@ ) from torch_tensorrt.dynamo.utils import DYNAMIC_DIM +from tensorrt import ITensor as TRTTensor + _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -577,6 +578,16 @@ def index_has_bool_indices( for ind in index: if ind is not None: val = ind.meta.get("val") + if val is None and ind.op == "get_attr": + # fx.symbolic_trace embeds constant tensors as get_attr nodes + # without meta["val"]; fetch the actual tensor from the module. + try: + attr = ind.graph.owning_module + for part in ind.target.split("."): + attr = getattr(attr, part) + val = attr + except AttributeError: + pass if val is not None and val.dtype == torch.bool: return True return False diff --git a/tests/py/dynamo/conversion/test_index_bool_split_aten.py b/tests/py/dynamo/conversion/test_index_bool_split_aten.py index 466260fc72..3d8b9e66eb 100644 --- a/tests/py/dynamo/conversion/test_index_bool_split_aten.py +++ b/tests/py/dynamo/conversion/test_index_bool_split_aten.py @@ -6,6 +6,7 @@ 3. Boolean-indexed `aten.index.Tensor` routes to the converter WITH output allocator. 4. Both paths produce correct results. """ + import unittest from unittest.mock import MagicMock @@ -60,9 +61,7 @@ def test_none_with_bool_indices_returns_true(self): def test_mixed_int_and_bool_returns_true(self): """If any index is bool, the function should return True.""" - node = _make_index_node( - [torch.tensor([0, 1]), torch.tensor([True, False])] - ) + node = _make_index_node([torch.tensor([0, 1]), torch.tensor([True, False])]) self.assertTrue(index_has_bool_indices(node)) def test_all_none_returns_false(self): diff --git a/uv.lock b/uv.lock index 82459b609a..f37e1184e3 100644 --- a/uv.lock +++ b/uv.lock @@ -32,9 +32,6 @@ required-markers = [ "python_full_version < '3.14' and platform_machine == 'AMD64' and sys_platform == 'win32'", ] -[options] -prerelease-mode = "allow" - [[package]] name = "accelerate" version = "1.12.0"