Skip to content

Commit 1950445

Browse files
apboselanluo-nvidia
authored andcommitted
handle symbolic shape for non tensor inputs in symbolic shape extraction
1 parent 58d53c0 commit 1950445

File tree

1 file changed

+53
-39
lines changed

1 file changed

+53
-39
lines changed

py/torch_tensorrt/dynamo/conversion/_symbolic_shape_capture.py

Lines changed: 53 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -50,30 +50,39 @@ def extract_symbolic_shape_expressions(
5050
return None
5151

5252
input_val = input_node.meta["val"]
53-
if not isinstance(input_val, torch.Tensor):
53+
logger.debug(
54+
f"Input node '{input_node.name}': type={type(input_val)}, val={input_val}"
55+
)
56+
if isinstance(input_val, torch.Tensor):
57+
shape_exprs = []
58+
for dim_size in input_val.shape:
59+
if isinstance(dim_size, torch.SymInt):
60+
shape_exprs.append(dim_size.node.expr)
61+
else:
62+
shape_exprs.append(int(dim_size))
63+
64+
input_info.append(
65+
{
66+
"shape_exprs": shape_exprs,
67+
"dtype": input_val.dtype,
68+
"name": input_node.name,
69+
}
70+
)
71+
elif isinstance(input_val, (torch.SymInt, torch.SymFloat, int, float, bool)):
72+
input_info.append(
73+
{
74+
"shape_exprs": [],
75+
"dtype": None,
76+
"name": input_node.name,
77+
"is_scalar": True,
78+
}
79+
)
80+
else:
5481
logger.warning(
55-
"When processing symbolic shapes for TensorRT engine, input is not a tensor"
82+
f"When processing symbolic shapes for TensorRT engine, unsupported input type: {type(input_val)}"
5683
)
5784
return None
5885

59-
# Extract shape as sympy expressions (can be pickled)
60-
shape_exprs = []
61-
for dim_size in input_val.shape:
62-
if isinstance(dim_size, torch.SymInt):
63-
# Store the sympy expression, which can be pickled
64-
shape_exprs.append(dim_size.node.expr)
65-
else:
66-
# Store concrete integer
67-
shape_exprs.append(int(dim_size))
68-
69-
input_info.append(
70-
{
71-
"shape_exprs": shape_exprs,
72-
"dtype": input_val.dtype,
73-
"name": input_node.name,
74-
}
75-
)
76-
7786
# Extract output values from output node
7887
output_args = output_node.args[0]
7988
if not isinstance(output_args, (tuple, list)):
@@ -89,29 +98,34 @@ def extract_symbolic_shape_expressions(
8998
return None
9099

91100
out_val = out_arg.meta["val"]
92-
if not isinstance(out_val, torch.Tensor):
101+
if isinstance(out_val, torch.Tensor):
102+
shape_exprs = []
103+
for dim_size in out_val.shape:
104+
if isinstance(dim_size, torch.SymInt):
105+
shape_exprs.append(dim_size.node.expr)
106+
else:
107+
shape_exprs.append(int(dim_size))
108+
109+
output_info.append(
110+
{
111+
"shape_exprs": shape_exprs,
112+
"dtype": out_val.dtype,
113+
}
114+
)
115+
elif isinstance(out_val, (torch.SymInt, torch.SymFloat, int, float, bool)):
116+
output_info.append(
117+
{
118+
"shape_exprs": [],
119+
"dtype": None,
120+
"is_scalar": True,
121+
}
122+
)
123+
else:
93124
logger.warning(
94-
"When processing symbolic shapes for TensorRT engine, output is not a tensor"
125+
f"When processing symbolic shapes for TensorRT engine, unsupported output type: {type(out_val)}"
95126
)
96127
return None
97128

98-
# Extract shape as sympy expressions (can be pickled)
99-
shape_exprs = []
100-
for dim_size in out_val.shape:
101-
if isinstance(dim_size, torch.SymInt):
102-
# Store the sympy expression, which can be pickled
103-
shape_exprs.append(dim_size.node.expr)
104-
else:
105-
# Store concrete integer
106-
shape_exprs.append(int(dim_size))
107-
108-
output_info.append(
109-
{
110-
"shape_exprs": shape_exprs,
111-
"dtype": out_val.dtype,
112-
}
113-
)
114-
115129
if not output_info:
116130
return None
117131

0 commit comments

Comments
 (0)