@@ -50,30 +50,39 @@ def extract_symbolic_shape_expressions(
5050 return None
5151
5252 input_val = input_node .meta ["val" ]
53- if not isinstance (input_val , torch .Tensor ):
53+ logger .debug (
54+ f"Input node '{ input_node .name } ': type={ type (input_val )} , val={ input_val } "
55+ )
56+ if isinstance (input_val , torch .Tensor ):
57+ shape_exprs = []
58+ for dim_size in input_val .shape :
59+ if isinstance (dim_size , torch .SymInt ):
60+ shape_exprs .append (dim_size .node .expr )
61+ else :
62+ shape_exprs .append (int (dim_size ))
63+
64+ input_info .append (
65+ {
66+ "shape_exprs" : shape_exprs ,
67+ "dtype" : input_val .dtype ,
68+ "name" : input_node .name ,
69+ }
70+ )
71+ elif isinstance (input_val , (torch .SymInt , torch .SymFloat , int , float , bool )):
72+ input_info .append (
73+ {
74+ "shape_exprs" : [],
75+ "dtype" : None ,
76+ "name" : input_node .name ,
77+ "is_scalar" : True ,
78+ }
79+ )
80+ else :
5481 logger .warning (
55- "When processing symbolic shapes for TensorRT engine, input is not a tensor "
82+ f "When processing symbolic shapes for TensorRT engine, unsupported input type: { type ( input_val ) } "
5683 )
5784 return None
5885
59- # Extract shape as sympy expressions (can be pickled)
60- shape_exprs = []
61- for dim_size in input_val .shape :
62- if isinstance (dim_size , torch .SymInt ):
63- # Store the sympy expression, which can be pickled
64- shape_exprs .append (dim_size .node .expr )
65- else :
66- # Store concrete integer
67- shape_exprs .append (int (dim_size ))
68-
69- input_info .append (
70- {
71- "shape_exprs" : shape_exprs ,
72- "dtype" : input_val .dtype ,
73- "name" : input_node .name ,
74- }
75- )
76-
7786 # Extract output values from output node
7887 output_args = output_node .args [0 ]
7988 if not isinstance (output_args , (tuple , list )):
@@ -89,29 +98,34 @@ def extract_symbolic_shape_expressions(
8998 return None
9099
91100 out_val = out_arg .meta ["val" ]
92- if not isinstance (out_val , torch .Tensor ):
101+ if isinstance (out_val , torch .Tensor ):
102+ shape_exprs = []
103+ for dim_size in out_val .shape :
104+ if isinstance (dim_size , torch .SymInt ):
105+ shape_exprs .append (dim_size .node .expr )
106+ else :
107+ shape_exprs .append (int (dim_size ))
108+
109+ output_info .append (
110+ {
111+ "shape_exprs" : shape_exprs ,
112+ "dtype" : out_val .dtype ,
113+ }
114+ )
115+ elif isinstance (out_val , (torch .SymInt , torch .SymFloat , int , float , bool )):
116+ output_info .append (
117+ {
118+ "shape_exprs" : [],
119+ "dtype" : None ,
120+ "is_scalar" : True ,
121+ }
122+ )
123+ else :
93124 logger .warning (
94- "When processing symbolic shapes for TensorRT engine, output is not a tensor "
125+ f "When processing symbolic shapes for TensorRT engine, unsupported output type: { type ( out_val ) } "
95126 )
96127 return None
97128
98- # Extract shape as sympy expressions (can be pickled)
99- shape_exprs = []
100- for dim_size in out_val .shape :
101- if isinstance (dim_size , torch .SymInt ):
102- # Store the sympy expression, which can be pickled
103- shape_exprs .append (dim_size .node .expr )
104- else :
105- # Store concrete integer
106- shape_exprs .append (int (dim_size ))
107-
108- output_info .append (
109- {
110- "shape_exprs" : shape_exprs ,
111- "dtype" : out_val .dtype ,
112- }
113- )
114-
115129 if not output_info :
116130 return None
117131
0 commit comments