Skip to content

Commit 345a3dc

Browse files
committed
Integrate Automated QDQ placement tool - part 2
Signed-off-by: Will Guo <[email protected]>
1 parent 5cc2a54 commit 345a3dc

File tree

11 files changed

+5880
-0
lines changed

11 files changed

+5880
-0
lines changed

modelopt/onnx/op_types.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,22 @@ def is_data_dependent_shape_op(op_type: str):
303303
"NonZero",
304304
"RoiAlign",
305305
]
306+
307+
308+
def get_symmetric_ops():
309+
"""Returns set of commutative/symmetric operations where operand order doesn't matter."""
310+
return {
311+
"Add",
312+
"Mul",
313+
"And",
314+
"Or",
315+
"Xor",
316+
"Equal",
317+
"Max",
318+
"Min",
319+
"Sum",
320+
"Mean",
321+
"BitwiseAnd",
322+
"BitwiseOr",
323+
"BitwiseXor",
324+
}
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Pattern-Based Q/DQ Autotuning for ONNX Models.
17+
18+
This package provides automated optimization of Quantize/Dequantize (Q/DQ) node placement
19+
in ONNX computation graphs to minimize TensorRT inference latency. It uses pattern-based
20+
region analysis to efficiently explore and optimize Q/DQ insertion strategies.
21+
22+
**Key Features:**
23+
24+
- **Automated Region Discovery**: Hierarchical decomposition of computation graphs into
25+
LEAF and COMPOSITE regions with automatic pattern identification
26+
27+
- **Pattern-Based Optimization**: Groups structurally-similar regions and optimizes them
28+
together, making the process efficient and consistent
29+
30+
- **TensorRT Performance Measurement**: Direct integration with TensorRT Python API for
31+
accurate latency profiling of each Q/DQ configuration
32+
33+
- **State Management**: Checkpoint/resume capability for long-running optimizations with
34+
incremental state saving after each region
35+
36+
- **Pattern Cache**: Warm-start optimization using learned schemes from previous runs,
37+
enabling transfer learning across models
38+
39+
**Core Components:**
40+
41+
Autotuner Classes:
42+
- QDQAutotuner: Main autotuner with automatic hierarchical region discovery
43+
- QDQAutotunerBase: Base class for custom region identification strategies
44+
45+
Region Management:
46+
- Region: Hierarchical subgraph representation (nodes + children)
47+
- RegionType: Enumeration (LEAF, COMPOSITE, ROOT)
48+
- CombinedRegionSearch: Two-phase region discovery (partitioning + refinement)
49+
- RegionPattern: Structural pattern analysis and matching for region grouping
50+
51+
Q/DQ Insertion Points:
52+
- InsertionScheme: Collection of Q/DQ insertion points for a region pattern
53+
- NodeInputInsertionPoint: Q/DQ insertion at specific node inputs
54+
- ChildRegionInputInsertionPoint: Q/DQ insertion at child region input boundaries
55+
- RegionOutputInsertionPoint: Q/DQ insertion at region output boundaries
56+
57+
Configuration & State:
58+
- Config: Autotuning parameters (quant type, thresholds, verbosity)
59+
- PatternCache: Top-performing schemes indexed by pattern (warm-start)
60+
- PatternSchemes: Scheme collection and measurement results for a pattern
61+
62+
Benchmarking:
63+
- Benchmark: Abstract base class for model benchmarking
64+
- TensorRTPyBenchmark: Benchmark using TensorRT Python API (recommended)
65+
- TrtExecBenchmark: Benchmark using trtexec command-line tool (legacy)
66+
67+
**Quick Start:**
68+
69+
>>> from modelopt.onnx.quantization.autotune import QDQAutotuner, Config
70+
>>> import onnx
71+
>>> # Load model and initialize autotuner
72+
>>> model = onnx.load("model.onnx")
73+
>>> autotuner = QDQAutotuner(model)
74+
>>> # Configure autotuning parameters
75+
>>> config = Config(default_quant_type="int8")
76+
>>> autotuner.initialize(config)
77+
>>> # Generate and test Q/DQ schemes
78+
>>> # (see workflows.region_pattern_autotuning_workflow for complete example)
79+
80+
**Command-Line Interface:**
81+
82+
The package can be run directly as a module:
83+
84+
$ python -m modelopt.onnx.quantization.autotune --model model.onnx --output ./output
85+
$ python -m modelopt.onnx.quantization.autotune --model model.onnx --quant-type fp8
86+
87+
**See Also:**
88+
89+
- workflows.region_pattern_autotuning_workflow: Complete end-to-end optimization
90+
- QDQAutotuner: Main autotuner class documentation
91+
- RegionPattern: Pattern matching and signature computation
92+
"""
93+
94+
# Core data structures
95+
from .common import (
96+
AutotunerError,
97+
AutotunerNotInitializedError,
98+
Config,
99+
InsertionScheme,
100+
InvalidSchemeError,
101+
PatternCache,
102+
PatternSchemes,
103+
Region,
104+
RegionError,
105+
RegionType,
106+
)
107+
108+
# Insertion points (from dedicated module)
109+
from .insertion_points import (
110+
ChildRegionInputInsertionPoint,
111+
NodeInputInsertionPoint,
112+
RegionOutputInsertionPoint,
113+
ResolvedInsertionPoint,
114+
)
115+
116+
# Pattern analysis
117+
from .region_pattern import RegionPattern
118+
119+
# Region search
120+
from .region_search import CombinedRegionSearch
121+
122+
# Public API
123+
__all__ = [
124+
# Exceptions
125+
"AutotunerError",
126+
"AutotunerNotInitializedError",
127+
"ChildRegionInputInsertionPoint",
128+
"CombinedRegionSearch",
129+
# Configuration and state
130+
"Config",
131+
# Q/DQ insertion
132+
"InsertionScheme",
133+
"InvalidSchemeError",
134+
"NodeInputInsertionPoint",
135+
"ResolvedInsertionPoint",
136+
"PatternCache",
137+
"PatternSchemes",
138+
# Region classes
139+
"Region",
140+
"RegionError",
141+
"RegionOutputInsertionPoint",
142+
"RegionPattern",
143+
"RegionType",
144+
]

0 commit comments

Comments
 (0)