Skip to content

Commit f91806e

Browse files
ryan-monroe authored and facebook-github-bot committed
Add FuseConcatPass to eliminate redundant concat ops
Summary: Concat (torch.cat) in the Gen2 Executorch ARM/Ethos-U stack is lowered to TOSA CONCAT, which Vela then converts to N x MemoryCopy operations — real DMA data movement on the NPU. This pass eliminates concat operations that can be proven unnecessary at the FX graph level, preventing Vela from generating MemoryCopy ops entirely. Inspired by Espresso's concat elimination techniques (bolt/nn/espresso/transforms/remove_nops.py), five patterns are handled: 1. Single-input concat: cat([x]) is a no-op, replaced with x. 2. Concat-then-slice: if every consumer of cat([a, b, ...]) is a slice_copy that extracts exactly one original input, bypass both. 3. Slice-then-concat: if contiguous slices of the same tensor are concatenated back, the result is the original tensor. 4. Concat-then-sub-slice: a slice_copy consumer whose range falls entirely within one concat input is redirected to an adjusted slice on that input. 5. Slice-then-concat (partial): contiguous slices covering only a sub-range of the source are replaced with a single slice on the source. Differential Revision: D97667069
1 parent 411ede2 commit f91806e

File tree

4 files changed

+466
-0
lines changed

4 files changed

+466
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
QuantizeClampArgumentsPass,
103103
)
104104
from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa
105+
from .fuse_concat_pass import FuseConcatPass # noqa
105106
from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass # noqa
106107
from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa
107108
from .fuse_constant_ops_pass import ( # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
DecorateFp32toInt32CastingPass,
9999
FoldAndAnnotateQParamsPass,
100100
FuseBatchNorm2dPass,
101+
FuseConcatPass,
101102
FuseConsecutiveConcatShapesPass,
102103
FuseConsecutiveRescalesPass,
103104
FuseConstantArgsPass,
@@ -486,6 +487,7 @@ def _tosa_pipeline(
486487
# Aten -> TOSA transformation passes
487488
self.add_passes(
488489
[
490+
FuseConcatPass(),
489491
RewriteUpsamplePass(),
490492
RewriteConvPass(exported_program),
491493
RewriteMatmulPass(),
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import logging
9+
from typing import Set, Type
10+
11+
import torch.fx
12+
from executorch.backends.arm._passes import ArmPass
13+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.pass_base import ExportPass, PassResult
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class FuseConcatPass(ArmPass):
    """Eliminate redundant concat (cat) operations via graph pattern matching.

    Inspired by Espresso's concat elimination techniques
    (bolt/nn/espresso/transforms/remove_nops.py), this pass recognizes and
    removes concat operations that can be proven to produce no useful data
    movement. Eliminating these at the FX/TOSA level prevents Vela from
    generating MemoryCopy operations on the Ethos-U NPU.

    Five patterns are handled:

    1. Single-input concat: cat([x], dim) is a no-op; replace with x.
    2. Concat-then-slice (exact): if a consumer of cat([a, b, ...], dim) is
       a slice_copy that extracts exactly one original input, replace it
       with the corresponding concat input directly.
    3. Slice-then-concat (full): if cat([slice(x, d, s0, e0),
       slice(x, d, s1, e1), ...], dim) reconstructs x exactly (contiguous
       slices covering the full source dimension), replace with x.
    4. Concat-then-sub-slice: if a consumer of cat([a, b, ...], dim) is a
       slice_copy whose range falls entirely within one original input,
       replace it with an adjusted slice on that input directly.
    5. Slice-then-concat (partial): if contiguous slices of the same tensor
       are concatenated but cover only a sub-range of the source dimension,
       replace with a single slice on the source.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    # Concat targets this pass rewrites.
    cat_ops = {
        exir_ops.edge.aten.cat.default,
    }
    # Slice target participating in the cat/slice fusion patterns.
    slice_op = exir_ops.edge.aten.slice_copy.Tensor

    @staticmethod
    def _normalize_slice_bounds(start: int, end: int, dim_size: int):
        """Resolve Python-style slice bounds to absolute offsets in [0, dim_size].

        slice_copy follows Python slicing semantics: a negative start/end
        counts backwards from the end of the dimension, and out-of-range
        bounds (e.g. an INT64_MAX "to the end" sentinel) are clamped. The
        pattern matching below compares bounds against absolute input
        offsets, so they must be canonicalized first — otherwise e.g.
        slice(cat, 0, 0, -1) ("all but the last element") would be misread
        as the absolute range [0, -1) and matched against the wrong input.
        """
        if start < 0:
            start += dim_size
        if end < 0:
            end += dim_size
        start = min(max(start, 0), dim_size)
        end = min(max(end, 0), dim_size)
        return start, end

    def call(self, graph_module: torch.fx.GraphModule):
        """Apply all concat-elimination patterns to ``graph_module``.

        Returns a PassResult whose ``modified`` flag reports whether any
        pattern fired.
        """
        modified = False
        graph = graph_module.graph

        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.cat_ops:
                continue
            if node.graph is None:
                continue

            if self._eliminate_single_input_cat(node):
                modified = True
                continue

            if self._eliminate_cat_then_slice(node):
                modified = True
                continue

            if self._eliminate_slice_then_cat(node):
                modified = True
                continue

        if modified:
            # Bypassed cat/slice nodes are left dangling; sweep them, then
            # recompile and re-run the parent pass machinery so node
            # metadata stays consistent.
            graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

    # ------------------------------------------------------------------
    # Pattern 1: single-input cat
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_single_input_cat(cat_node: torch.fx.Node) -> bool:
        """Replace cat([x], dim) with x; return True if rewritten."""
        inputs = cat_node.args[0]
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 1:
            return False
        cat_node.replace_all_uses_with(inputs[0])
        logger.debug("Eliminated single-input cat: %s", cat_node.name)
        return True

    # ------------------------------------------------------------------
    # Patterns 2 & 4: cat -> slice (exact input or sub-range of input)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_cat_then_slice(
        cat_node: torch.fx.Node,
    ) -> bool:
        """Redirect slice_copy consumers of a cat to the concat inputs.

        Requires every user to be a step-1 slice_copy on the cat dimension;
        users whose range crosses an input boundary are left untouched (the
        cat then stays alive for them). Returns True if any user was
        redirected.
        """
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        cat_dim = cat_node.args[1] if len(cat_node.args) > 1 else 0
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        users = list(cat_node.users.keys())
        if not users:
            return False

        # Every user must be a slice_copy on the same dim with step=1.
        for user in users:
            if user.target != FuseConcatPass.slice_op:
                return False
            if user.args[0] is not cat_node:
                return False
            slice_dim = user.args[1] if len(user.args) > 1 else 0
            slice_dim = (slice_dim + output_rank) % output_rank
            if slice_dim != cat_dim:
                return False
            slice_step = user.args[4] if len(user.args) > 4 else 1
            if slice_step != 1:
                return False

        # Build the offset map for each concat input along cat_dim.
        offsets = []
        offset = 0
        for inp in cat_inputs:
            inp_shape = get_first_fake_tensor(inp).shape
            size = inp_shape[cat_dim]
            offsets.append((offset, offset + size, inp))
            offset += size

        # For each user, try exact match (Pattern 2) then sub-range (Pattern 4).
        # Users that cross input boundaries are skipped.
        replacements: list[tuple[torch.fx.Node, torch.fx.Node]] = []

        for user in users:
            s_start = user.args[2] if len(user.args) > 2 else 0
            s_end = user.args[3] if len(user.args) > 3 else offset
            # Fix: resolve negative / sentinel bounds against the total
            # concat length before comparing with absolute input offsets.
            # Previously slice(cat, d, 0, -1) was treated as [0, -1) and
            # wrongly matched as a sub-range of the first input.
            s_start, s_end = FuseConcatPass._normalize_slice_bounds(
                s_start, s_end, offset
            )

            for o_start, o_end, inp in offsets:
                if s_start == o_start and s_end == o_end:
                    # Pattern 2: exact match — replace slice with input.
                    replacements.append((user, inp))
                    break
                if s_start >= o_start and s_end <= o_end:
                    # Pattern 4: sub-range — replace with slice on original.
                    adj_start = s_start - o_start
                    adj_end = s_end - o_start
                    graph = cat_node.graph
                    with graph.inserting_before(user):
                        new_slice = graph.call_function(
                            FuseConcatPass.slice_op,
                            (inp, cat_dim, adj_start, adj_end),
                        )
                    # Output shape/dtype are unchanged, so the old slice's
                    # metadata carries over.
                    new_slice.meta = user.meta.copy()
                    replacements.append((user, new_slice))
                    break

        if not replacements:
            return False

        for old_node, new_node in replacements:
            old_node.replace_all_uses_with(new_node)

        logger.debug(
            "Eliminated cat-then-slice pattern: %s (%d slices redirected)",
            cat_node.name,
            len(replacements),
        )
        return True

    # ------------------------------------------------------------------
    # Patterns 3 & 5: slice -> cat (contiguous slices, full or partial)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_slice_then_cat(
        cat_node: torch.fx.Node,
    ) -> bool:
        """Collapse a cat of contiguous slices of one tensor.

        Full coverage of the source dimension replaces the cat with the
        source tensor (Pattern 3); partial coverage replaces it with a
        single slice (Pattern 5). Returns True if rewritten.
        """
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        cat_dim = cat_node.args[1] if len(cat_node.args) > 1 else 0
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        # All inputs must be slice_copy on the same source tensor and dim,
        # with step=1.
        source_node = None
        for inp in cat_inputs:
            if not isinstance(inp, torch.fx.Node):
                return False
            if inp.target != FuseConcatPass.slice_op:
                return False
            slice_source = inp.args[0]
            slice_dim = inp.args[1] if len(inp.args) > 1 else 0
            inp_rank = len(get_first_fake_tensor(inp).shape)
            slice_dim = (slice_dim + inp_rank) % inp_rank
            if slice_dim != cat_dim:
                return False
            slice_step = inp.args[4] if len(inp.args) > 4 else 1
            if slice_step != 1:
                return False
            if source_node is None:
                source_node = slice_source
            elif slice_source is not source_node:
                return False

        if source_node is None:
            return False

        source_shape = get_first_fake_tensor(source_node).shape
        source_dim_size = source_shape[cat_dim]

        # Canonicalize every slice's bounds so that negative indices and
        # "to the end" sentinels compare correctly as absolute offsets.
        bounds = []
        for inp in cat_inputs:
            raw_start = inp.args[2] if len(inp.args) > 2 else 0
            raw_end = inp.args[3] if len(inp.args) > 3 else source_dim_size
            bounds.append(
                FuseConcatPass._normalize_slice_bounds(
                    raw_start, raw_end, source_dim_size
                )
            )

        # Verify slices are contiguous (but not necessarily starting at 0).
        first_start = bounds[0][0]
        expected_start = first_start
        for s_start, s_end in bounds:
            if s_start != expected_start:
                return False
            expected_start = s_end
        last_end = expected_start

        # Verify output shape matches expectations.
        cat_shape = get_first_fake_tensor(cat_node).shape

        if first_start == 0 and last_end == source_dim_size:
            # Pattern 3: full coverage — replace with source tensor.
            if list(cat_shape) != list(source_shape):
                return False
            cat_node.replace_all_uses_with(source_node)
            logger.debug(
                "Eliminated slice-then-cat (full): %s -> %s",
                cat_node.name,
                source_node.name,
            )
        else:
            # Pattern 5: partial coverage — replace with single slice.
            expected_dim_size = last_end - first_start
            if cat_shape[cat_dim] != expected_dim_size:
                return False
            for i, (cs, ss) in enumerate(zip(cat_shape, source_shape)):
                if i != cat_dim and cs != ss:  # dims must match except for cat_dim
                    return False
            graph = cat_node.graph
            with graph.inserting_before(cat_node):
                new_slice = graph.call_function(
                    FuseConcatPass.slice_op,
                    (source_node, cat_dim, first_start, last_end),
                )
            new_slice.meta = cat_node.meta.copy()
            cat_node.replace_all_uses_with(new_slice)
            logger.debug(
                "Eliminated slice-then-cat (partial): %s -> slice(%s, %d, %d:%d)",
                cat_node.name,
                source_node.name,
                cat_dim,
                first_start,
                last_end,
            )
        return True

0 commit comments

Comments
 (0)