Skip to content

Commit e11eaa3

Browse files
fix the llm test failing issue + add back cu126/cu128 (#4118)
1 parent 0b92ffb commit e11eaa3

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

.github/scripts/filter-matrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# rtx 1.2 currently only supports cu129 and cu130
1717
rtx_cuda_versions: List[str] = ["cu129", "cu130"]
1818
# trt 10.14.1 currently only supports cu129 and cu130
19-
trt_cuda_versions: List[str] = ["cu129", "cu130"]
19+
trt_cuda_versions: List[str] = ["cu126", "cu128", "cu129", "cu130"]
2020

2121
jetpack_container_image: str = "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
2222
sbsa_container_image: str = "quay.io/pypa/manylinux_2_39_aarch64"

py/torch_tensorrt/dynamo/lowering/passes/_FakeTensorUpdater.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@
1717
)
1818
from torch.utils._ordered_set import OrderedSet
1919

20+
# Try to import reinplace module - may not be available in all PyTorch builds
21+
try:
22+
import importlib
23+
24+
reinplace_module = importlib.import_module("torch._inductor.fx_passes.reinplace")
25+
_generalized_scatter = getattr(reinplace_module, "_generalized_scatter", None)
26+
except Exception:
27+
_generalized_scatter = None
28+
2029

2130
# Adapted from torch._inductor.fx_utils.FakeTensorUpdater
2231
class FakeTensorUpdater:
@@ -159,8 +168,10 @@ def should_process_node(node: torch.fx.Node) -> bool:
159168
return node.op == "call_function" and (
160169
isinstance(node.target, torch._ops.OpOverload)
161170
or node.target is operator.getitem
162-
or node.target
163-
is torch._inductor.fx_passes.reinplace._generalized_scatter
171+
or (
172+
_generalized_scatter is not None
173+
and node.target is _generalized_scatter
174+
)
164175
)
165176

166177
to_process = OrderedSet[int]()

0 commit comments

Comments
 (0)