Skip to content

Commit 5b8b2ea

Browse files
apboselanluo-nvidia
authored and committed
fix use of inputShapeTensorValues after freeing, which led to garbage memory being read in the C++ runtime; add test cases
1 parent 1950445 commit 5b8b2ea

File tree

3 files changed

+231
-12
lines changed

3 files changed

+231
-12
lines changed

core/runtime/execute_engine.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,8 @@ void setup_input_tensors(
9696
std::vector<at::Tensor> inputs,
9797
c10::intrusive_ptr<TRTEngine> compiled_engine,
9898
bool cudagraphs_enabled,
99-
bool need_cudagraphs_record) {
100-
// this is a buffer to store shape tensor input addresses throughout the runtime scope
101-
std::list<std::vector<int64_t>> inputShapeTensorValues;
99+
bool need_cudagraphs_record,
100+
std::list<std::vector<int64_t>>& inputShapeTensorValues) {
102101
std::list<at::Tensor> formatted_inputs(compiled_engine->num_io.first);
103102

104103
for (size_t i = 0; i < inputs.size(); i++) {
@@ -115,12 +114,10 @@ void setup_input_tensors(
115114

116115
auto dims = core::util::toDims(inputs[i].sizes());
117116
auto shape = core::util::toVec(dims);
118-
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
117+
bool is_shape_tensor = compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str());
118+
LOG_DEBUG("Input Name: " << name << " Shape: " << dims << " isShapeInferenceIO: " << is_shape_tensor);
119119

120-
if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
121-
// Shape tensor inputs are casted to int64 explicitly.
122-
// Refer to
123-
// https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
120+
if (is_shape_tensor) {
124121
auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt64);
125122
std::vector<int64_t> inputs_cpu_vec(
126123
input_cpu.data_ptr<int64_t>(), input_cpu.data_ptr<int64_t>() + input_cpu.numel());
@@ -233,6 +230,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
233230

234231
std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
235232

233+
// Shape tensor CPU buffers must outlive inferShapes() and enqueueV3()
234+
std::list<std::vector<int64_t>> inputShapeTensorValues;
235+
236236
// Intialize inputs and outputs to be available throughout the succeeding scopes
237237
{ // Input Setup
238238
std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
@@ -241,7 +241,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
241241
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
242242
}
243243

244-
setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
244+
setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, inputShapeTensorValues);
245245
// Check if input shapes can be inferred.
246246
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
247247
std::vector<char const*> names(io_size);
@@ -364,14 +364,17 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
364364
};
365365

366366
auto run_output_allocator = [&]() {
367+
// Shape tensor CPU buffers must outlive inferShapes() and enqueueV3()
368+
std::list<std::vector<int64_t>> inputShapeTensorValues;
369+
367370
{ // Input Setup
368371
std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
369372
if (compiled_engine->profile_execution) {
370373
input_profiler_guard =
371374
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
372375
}
373376

374-
setup_input_tensors(inputs, compiled_engine, false, false);
377+
setup_input_tensors(inputs, compiled_engine, false, false, inputShapeTensorValues);
375378
// Check if input shapes can be inferred.
376379
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
377380
std::vector<char const*> names(io_size);

py/torch_tensorrt/dynamo/conversion/_symbolic_shape_capture.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,16 @@ def extract_symbolic_shape_expressions(
6969
}
7070
)
7171
elif isinstance(input_val, (torch.SymInt, torch.SymFloat, int, float, bool)):
72+
if isinstance(input_val, (torch.SymInt, int)):
73+
scalar_dtype = torch.int64
74+
elif isinstance(input_val, (torch.SymFloat, float)):
75+
scalar_dtype = torch.float64
76+
else:
77+
scalar_dtype = torch.bool
7278
input_info.append(
7379
{
7480
"shape_exprs": [],
75-
"dtype": None,
81+
"dtype": scalar_dtype,
7682
"name": input_node.name,
7783
"is_scalar": True,
7884
}
@@ -113,10 +119,16 @@ def extract_symbolic_shape_expressions(
113119
}
114120
)
115121
elif isinstance(out_val, (torch.SymInt, torch.SymFloat, int, float, bool)):
122+
if isinstance(out_val, (torch.SymInt, int)):
123+
scalar_dtype = torch.int64
124+
elif isinstance(out_val, (torch.SymFloat, float)):
125+
scalar_dtype = torch.float64
126+
else:
127+
scalar_dtype = torch.bool
116128
output_info.append(
117129
{
118130
"shape_exprs": [],
119-
"dtype": None,
131+
"dtype": scalar_dtype,
120132
"is_scalar": True,
121133
}
122134
)
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
"""
2+
Tests for SymInt scalar input handling in symbolic shape capture and TRT compilation.
3+
4+
These tests verify that when Dynamo partitions an FX graph such that a SymInt
5+
(e.g., from targets.size(0)) becomes a bare scalar placeholder input to the TRT
6+
subgraph, the symbolic shape extraction and compilation succeed.
7+
8+
This covers the fix in _symbolic_shape_capture.py where non-tensor inputs
9+
(SymInt, int, float, bool) are handled gracefully instead of aborting extraction.
10+
"""
11+
12+
import unittest
13+
14+
import pytest
15+
import torch
16+
import torch_tensorrt as torchtrt
17+
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
18+
19+
assertions = unittest.TestCase()
20+
21+
22+
@pytest.mark.unit
@pytest.mark.parametrize("use_python_runtime", [True, False])
def test_symint_from_size_used_in_reshape(use_python_runtime):
    """
    Verify that a SymInt obtained via tensor.size(0) works when it becomes a
    bare scalar placeholder input to the TRT subgraph and feeds a reshape.

    This exercises the core pattern from issue #4107: targets.size(0) yields
    a SymInt that Dynamo forwards as a scalar input into the TRT partition,
    where it parameterizes a reshape operation.
    """

    class Model(torch.nn.Module):
        def forward(self, x, targets):
            batch = targets.size(0)
            return x.reshape(batch, -1)

    model = Model().eval().cuda()

    feats = torch.randn(16, 64).cuda()
    labels = torch.randint(0, 10, (16, 1), dtype=torch.int64).cuda()

    # Mark the batch dimension dynamic on both inputs so size(0) is symbolic.
    for tensor in (feats, labels):
        torch._dynamo.mark_dynamic(tensor, 0, min=1, max=2048)

    trt_model = torch.compile(
        model,
        backend="tensorrt",
        options={
            "enabled_precisions": {torch.float},
            "min_block_size": 1,
            "pass_through_build_failures": True,
            "use_python_runtime": use_python_runtime,
        },
    )

    output_ref = model(feats, labels)
    output_trt = trt_model(feats, labels)

    cos_sim = cosine_similarity(output_ref, output_trt)
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"SymInt reshape test (python_runtime={use_python_runtime}) failed. Cosine sim: {cos_sim}",
    )

    torch._dynamo.reset()
67+
68+
69+
@pytest.mark.unit
@pytest.mark.parametrize("use_python_runtime", [True, False])
def test_scalar_tensor_input(use_python_runtime):
    """
    Check that a 0-dim scalar tensor input (e.g. a cache_length value) is
    handled correctly by symbolic shape extraction and TRT compilation.
    """

    class Model(torch.nn.Module):
        def forward(self, x, offset):
            return x + offset

    model = Model().eval().cuda()

    data = torch.randn(16, 64).cuda()
    # 0-dim (scalar) tensor input — the case under test.
    shift = torch.tensor(5.0).cuda()

    trt_model = torch.compile(
        model,
        backend="tensorrt",
        options={
            "enabled_precisions": {torch.float},
            "min_block_size": 1,
            "pass_through_build_failures": True,
            "use_python_runtime": use_python_runtime,
        },
    )

    output_ref = model(data, shift)
    output_trt = trt_model(data, shift)

    cos_sim = cosine_similarity(output_ref, output_trt)
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"Scalar tensor input test (python_runtime={use_python_runtime}) failed. Cosine sim: {cos_sim}",
    )

    torch._dynamo.reset()
105+
106+
107+
@pytest.mark.unit
@pytest.mark.parametrize("use_python_runtime", [True, False])
def test_symint_with_index_and_reshape(use_python_runtime):
    """
    Full reproduction of the issue #4107 pattern: a symbolic size taken from
    an int64 tensor, combined with an index operation and a reshape.

    The module performs:
      1. B = targets.size(0)              -> SymInt
      2. idx = cache_length + arange(1)   -> int64 index tensor
      3. y = x[:, idx, :]                 -> gather with int64 index
      4. z = y.reshape(B, 1, -1, 2)       -> reshape using the SymInt
    """

    class Repro(torch.nn.Module):
        def forward(self, x, targets, cache_length):
            batch = targets.size(0)
            gather_idx = cache_length + torch.arange(1, device=x.device)
            gathered = x[:, gather_idx, :]
            return gathered.reshape(batch, 1, -1, 2)

    model = Repro().eval().cuda()

    batch, seq, dim = 16, 128, 1024
    x = torch.randn(batch, seq, dim).cuda()
    targets = torch.randint(0, 10, (batch, 1), dtype=torch.int64).cuda()
    cache_length = torch.tensor(0, dtype=torch.int64).cuda()

    # Batch dimensions are dynamic so targets.size(0) becomes symbolic.
    torch._dynamo.mark_dynamic(targets, 0, min=1, max=2048)
    torch._dynamo.mark_dynamic(x, 0, min=1, max=2048)

    trt_model = torch.compile(
        model,
        backend="tensorrt",
        options={
            "enabled_precisions": {torch.float, torch.half},
            "min_block_size": 1,
            "truncate_double": True,
            "pass_through_build_failures": True,
            "use_python_runtime": use_python_runtime,
        },
    )

    output_ref = model(x, targets, cache_length)
    output_trt = trt_model(x, targets, cache_length)

    cos_sim = cosine_similarity(output_ref, output_trt)
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"Issue 4107 repro test (python_runtime={use_python_runtime}) failed. Cosine sim: {cos_sim}",
    )

    torch._dynamo.reset()
159+
160+
161+
@pytest.mark.unit
@pytest.mark.parametrize("use_python_runtime", [True, False])
def test_symint_with_different_batch_sizes(use_python_runtime):
    """
    After compiling with a SymInt scalar input, the compiled module must
    keep producing correct results across several batch sizes.
    """

    class Model(torch.nn.Module):
        def forward(self, x, targets):
            batch = targets.size(0)
            return x.reshape(batch, 2, -1)

    model = Model().eval().cuda()

    x = torch.randn(8, 64).cuda()
    targets = torch.randint(0, 10, (8, 1), dtype=torch.int64).cuda()

    torch._dynamo.mark_dynamic(x, 0, min=1, max=2048)
    torch._dynamo.mark_dynamic(targets, 0, min=1, max=2048)

    trt_model = torch.compile(
        model,
        backend="tensorrt",
        options={
            "enabled_precisions": {torch.float},
            "min_block_size": 1,
            "pass_through_build_failures": True,
            "use_python_runtime": use_python_runtime,
        },
    )

    # Re-run the compiled module at several batch sizes to exercise the
    # dynamic dimension end to end.
    for batch_size in (4, 8, 16):
        sample = torch.randn(batch_size, 64).cuda()
        sample_targets = torch.randint(0, 10, (batch_size, 1), dtype=torch.int64).cuda()

        output_ref = model(sample, sample_targets)
        output_trt = trt_model(sample, sample_targets)

        cos_sim = cosine_similarity(output_ref, output_trt)
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"Varying batch size test (python_runtime={use_python_runtime}) failed at B={batch_size}. Cosine sim: {cos_sim}",
        )

    torch._dynamo.reset()

0 commit comments

Comments
 (0)