Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
QuantizeClampArgumentsPass,
)
from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa
from .fuse_concat_pass import FuseConcatPass # noqa
from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass # noqa
from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa
from .fuse_constant_ops_pass import ( # noqa
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
DecorateFp32toInt32CastingPass,
FoldAndAnnotateQParamsPass,
FuseBatchNorm2dPass,
FuseConcatPass,
FuseConsecutiveConcatShapesPass,
FuseConsecutiveRescalesPass,
FuseConstantArgsPass,
Expand Down Expand Up @@ -486,6 +487,7 @@ def _tosa_pipeline(
# Aten -> TOSA transformation passes
self.add_passes(
[
FuseConcatPass(),
RewriteUpsamplePass(),
RewriteConvPass(exported_program),
RewriteMatmulPass(),
Expand Down
270 changes: 270 additions & 0 deletions backends/arm/_passes/fuse_concat_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Set, Type

import torch.fx
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

logger = logging.getLogger(__name__)


class FuseConcatPass(ArmPass):
    """Eliminate redundant concat (cat) operations via graph pattern matching.

    Inspired by Espresso's concat elimination techniques
    (bolt/nn/espresso/transforms/remove_nops.py), this pass recognizes and
    removes concat operations that can be proven to produce no useful data
    movement. Eliminating these at the FX/TOSA level prevents Vela from
    generating MemoryCopy operations on the Ethos-U NPU.

    Five patterns are handled:

    1. Single-input concat: cat([x], dim) is a no-op; replace with x.
    2. Concat-then-slice (exact): if a consumer of cat([a, b, ...], dim) is
       a slice_copy that extracts exactly one original input, replace it
       with the corresponding concat input directly.
    3. Slice-then-concat (full): if cat([slice(x, d, s0, e0),
       slice(x, d, s1, e1), ...], dim) reconstructs x exactly (contiguous
       slices covering the full source dimension), replace with x.
    4. Concat-then-sub-slice: if a consumer of cat([a, b, ...], dim) is a
       slice_copy whose range falls entirely within one original input,
       replace it with an adjusted slice on that input directly.
    5. Slice-then-concat (partial): if contiguous slices of the same tensor
       are concatenated but cover only a sub-range of the source dimension,
       replace with a single slice on the source.

    All slice bounds are normalized to non-negative, clamped values
    (mirroring aten slice semantics) before they are compared against
    input offsets, so negative bounds such as ``end=-1`` are handled
    correctly.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    # Concat targets this pass rewrites.
    cat_ops = {
        exir_ops.edge.aten.cat.default,
    }
    # The slice op recognized on either side of a cat.
    slice_op = exir_ops.edge.aten.slice_copy.Tensor

    def call(self, graph_module: torch.fx.GraphModule):
        """Apply all concat-elimination patterns to ``graph_module``.

        Returns:
            PassResult with ``modified=True`` if any pattern fired. When
            modified, dead nodes are pruned and the module is re-traced so
            fake-tensor metadata stays consistent for downstream passes.
        """
        modified = False
        graph = graph_module.graph

        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.cat_ops:
                continue
            if node.graph is None:
                continue

            if self._eliminate_single_input_cat(node):
                modified = True
                continue

            if self._eliminate_cat_then_slice(node):
                modified = True
                continue

            if self._eliminate_slice_then_cat(node):
                modified = True
                continue

        if modified:
            graph.eliminate_dead_code()
            graph_module.recompile()
            # Re-trace via the parent pass so node metadata (fake tensors)
            # reflects the rewritten graph.
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

    # ------------------------------------------------------------------
    # Shared helper: slice-bound normalization
    # ------------------------------------------------------------------
    @staticmethod
    def _normalize_slice_bounds(start: int, end: int, length: int):
        """Map possibly-negative slice bounds onto ``[0, length]``.

        Mirrors aten slice semantics: a negative index counts from the end
        of the dimension, and both bounds are clamped into the valid range.
        Without this, a slice such as ``slice_copy(cat, dim, 0, -1)`` would
        compare its raw ``-1`` end against non-negative input offsets,
        wrongly satisfy the sub-range test for the first concat input, and
        be rewritten to a slice of the wrong extent (miscompilation).

        Args:
            start: raw slice start (may be negative or out of range).
            end: raw slice end (may be negative or a huge sentinel).
            length: size of the sliced dimension.

        Returns:
            ``(start, end)`` as non-negative ints clamped to ``[0, length]``.
        """
        if start < 0:
            start += length
        if end < 0:
            end += length
        start = max(0, min(start, length))
        end = max(0, min(end, length))
        return start, end

    # ------------------------------------------------------------------
    # Pattern 1: single-input cat
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_single_input_cat(cat_node: torch.fx.Node) -> bool:
        """Replace ``cat([x], dim)`` with ``x``. Returns True if rewritten."""
        inputs = cat_node.args[0]
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 1:
            return False
        cat_node.replace_all_uses_with(inputs[0])
        logger.debug("Eliminated single-input cat: %s", cat_node.name)
        return True

    # ------------------------------------------------------------------
    # Patterns 2 & 4: cat -> slice (exact input or sub-range of input)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_cat_then_slice(
        cat_node: torch.fx.Node,
    ) -> bool:
        """Redirect slice_copy users of a cat to the cat's inputs.

        Only fires when *every* user of the cat is a step-1 slice_copy on
        the cat dimension. Users whose (normalized) range matches one input
        exactly are replaced by that input (Pattern 2); users fully inside
        one input get an adjusted slice on that input (Pattern 4). Users
        crossing input boundaries are left alone; the cat itself survives
        until dead-code elimination if any user remains.

        Returns:
            True if at least one user was redirected.
        """
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        cat_dim = cat_node.args[1] if len(cat_node.args) > 1 else 0
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        users = list(cat_node.users.keys())
        if not users:
            return False

        # Every user must be a slice_copy on the same dim with step=1.
        for user in users:
            if user.target != FuseConcatPass.slice_op:
                return False
            if user.args[0] is not cat_node:
                return False
            slice_dim = user.args[1] if len(user.args) > 1 else 0
            slice_dim = (slice_dim + output_rank) % output_rank
            if slice_dim != cat_dim:
                return False
            slice_step = user.args[4] if len(user.args) > 4 else 1
            if slice_step != 1:
                return False

        # Build the offset map for each concat input along cat_dim.
        offsets = []
        offset = 0
        for inp in cat_inputs:
            inp_shape = get_first_fake_tensor(inp).shape
            size = inp_shape[cat_dim]
            offsets.append((offset, offset + size, inp))
            offset += size
        total = offset  # full extent of the cat output along cat_dim

        # For each user, try exact match (Pattern 2) then sub-range (Pattern 4).
        # Users that cross input boundaries are skipped.
        replacements: list[tuple[torch.fx.Node, torch.fx.Node]] = []

        for user in users:
            s_start = user.args[2] if len(user.args) > 2 else 0
            s_end = user.args[3] if len(user.args) > 3 else total
            # Normalize before comparing against the non-negative offsets;
            # raw negative bounds would otherwise pass the sub-range test
            # against the first input and yield a wrong-sized slice.
            s_start, s_end = FuseConcatPass._normalize_slice_bounds(
                s_start, s_end, total
            )

            for o_start, o_end, inp in offsets:
                if s_start == o_start and s_end == o_end:
                    # Pattern 2: exact match — replace slice with input.
                    replacements.append((user, inp))
                    break
                if s_start >= o_start and s_end <= o_end:
                    # Pattern 4: sub-range — replace with slice on original.
                    adj_start = s_start - o_start
                    adj_end = s_end - o_start
                    graph = cat_node.graph
                    with graph.inserting_before(user):
                        new_slice = graph.call_function(
                            FuseConcatPass.slice_op,
                            (inp, cat_dim, adj_start, adj_end),
                        )
                    # Same output extent, so the user's meta stays valid.
                    new_slice.meta = user.meta.copy()
                    replacements.append((user, new_slice))
                    break

        if not replacements:
            return False

        for old_node, new_node in replacements:
            old_node.replace_all_uses_with(new_node)

        logger.debug(
            "Eliminated cat-then-slice pattern: %s (%d slices redirected)",
            cat_node.name,
            len(replacements),
        )
        return True

    # ------------------------------------------------------------------
    # Patterns 3 & 5: slice -> cat (contiguous slices, full or partial)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_slice_then_cat(
        cat_node: torch.fx.Node,
    ) -> bool:
        """Collapse a cat of contiguous slices of one tensor.

        If the slices reconstruct the source exactly the cat is replaced by
        the source (Pattern 3); if they cover a contiguous sub-range, the
        cat is replaced by a single slice on the source (Pattern 5).

        Returns:
            True if the cat was replaced.
        """
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        cat_dim = cat_node.args[1] if len(cat_node.args) > 1 else 0
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        # All inputs must be slice_copy on the same source tensor and dim,
        # with step=1.
        source_node = None
        for inp in cat_inputs:
            if not isinstance(inp, torch.fx.Node):
                return False
            if inp.target != FuseConcatPass.slice_op:
                return False
            slice_source = inp.args[0]
            slice_dim = inp.args[1] if len(inp.args) > 1 else 0
            inp_rank = len(get_first_fake_tensor(inp).shape)
            slice_dim = (slice_dim + inp_rank) % inp_rank
            if slice_dim != cat_dim:
                return False
            slice_step = inp.args[4] if len(inp.args) > 4 else 1
            if slice_step != 1:
                return False
            if source_node is None:
                source_node = slice_source
            elif slice_source is not source_node:
                return False

        if source_node is None:
            return False

        source_shape = get_first_fake_tensor(source_node).shape
        source_dim_size = source_shape[cat_dim]

        # Verify slices are contiguous (but not necessarily starting at 0).
        # Bounds are normalized so equivalent negative/positive spellings
        # (e.g. end=-1 followed by start=size-1) still chain correctly.
        first_start = None
        expected_start = None
        for inp in cat_inputs:
            s_start = inp.args[2] if len(inp.args) > 2 else 0
            s_end = inp.args[3] if len(inp.args) > 3 else source_dim_size
            s_start, s_end = FuseConcatPass._normalize_slice_bounds(
                s_start, s_end, source_dim_size
            )
            if first_start is None:
                first_start = s_start
                expected_start = s_start
            if s_start != expected_start:
                return False
            expected_start = s_end
        last_end = expected_start

        # Verify output shape matches expectations.
        cat_shape = get_first_fake_tensor(cat_node).shape

        if first_start == 0 and last_end == source_dim_size:
            # Pattern 3: full coverage — replace with source tensor.
            if list(cat_shape) != list(source_shape):
                return False
            cat_node.replace_all_uses_with(source_node)
            logger.debug(
                "Eliminated slice-then-cat (full): %s -> %s",
                cat_node.name,
                source_node.name,
            )
        else:
            # Pattern 5: partial coverage — replace with single slice.
            expected_dim_size = last_end - first_start
            if cat_shape[cat_dim] != expected_dim_size:
                return False
            for i, (cs, ss) in enumerate(zip(cat_shape, source_shape)):
                if i != cat_dim and cs != ss:  # dims must match except for cat_dim
                    return False
            graph = cat_node.graph
            with graph.inserting_before(cat_node):
                new_slice = graph.call_function(
                    FuseConcatPass.slice_op,
                    (source_node, cat_dim, first_start, last_end),
                )
            new_slice.meta = cat_node.meta.copy()
            cat_node.replace_all_uses_with(new_slice)
            logger.debug(
                "Eliminated slice-then-cat (partial): %s -> slice(%s, %d, %d:%d)",
                cat_node.name,
                source_node.name,
                cat_dim,
                first_start,
                last_end,
            )
        return True
Loading
Loading