From f91806e651cda8908265154ed5cbc06b98282a4c Mon Sep 17 00:00:00 2001 From: Ryan Monroe Date: Fri, 10 Apr 2026 17:06:15 -0700 Subject: [PATCH] Add FuseConcatPass to eliminate redundant concat ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Concat (torch.cat) in the Gen2 Executorch ARM/Ethos-U stack is lowered to TOSA CONCAT, which Vela then converts to N x MemoryCopy operations — real DMA data movement on the NPU. This pass eliminates concat operations that can be proven unnecessary at the FX graph level, preventing Vela from generating MemoryCopy ops entirely. Inspired by Espresso's concat elimination techniques (bolt/nn/espresso/transforms/remove_nops.py), three patterns are handled: 1. Single-input concat: cat([x]) is a no-op, replaced with x. 2. Concat-then-slice: if every consumer of cat([a, b, ...]) is a slice_copy that extracts exactly one original input, bypass both. 3. Slice-then-concat: if contiguous slices of the same tensor are concatenated back, the result is the original tensor. Differential Revision: D97667069 --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/_passes/fuse_concat_pass.py | 270 ++++++++++++++++++ .../arm/test/passes/test_fuse_concat_pass.py | 193 +++++++++++++ 4 files changed, 466 insertions(+) create mode 100644 backends/arm/_passes/fuse_concat_pass.py create mode 100644 backends/arm/test/passes/test_fuse_concat_pass.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 1bd18de581d..b6689658f64 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -102,6 +102,7 @@ QuantizeClampArgumentsPass, ) from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa +from .fuse_concat_pass import FuseConcatPass # noqa from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass # noqa from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa from .fuse_constant_ops_pass import ( # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 31cb7a2e2c7..1af04b71084 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -98,6 +98,7 @@ DecorateFp32toInt32CastingPass, FoldAndAnnotateQParamsPass, FuseBatchNorm2dPass, + FuseConcatPass, FuseConsecutiveConcatShapesPass, FuseConsecutiveRescalesPass, FuseConstantArgsPass, @@ -486,6 +487,7 @@ def _tosa_pipeline( # Aten -> TOSA transformation passes self.add_passes( [ + FuseConcatPass(), RewriteUpsamplePass(), RewriteConvPass(exported_program), RewriteMatmulPass(), diff --git a/backends/arm/_passes/fuse_concat_pass.py b/backends/arm/_passes/fuse_concat_pass.py new file mode 100644 index 00000000000..d18b0014e8f --- /dev/null +++ b/backends/arm/_passes/fuse_concat_pass.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Set, Type + +import torch.fx +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +logger = logging.getLogger(__name__) + + +class FuseConcatPass(ArmPass): + """Eliminate redundant concat (cat) operations via graph pattern matching. + + Inspired by Espresso's concat elimination techniques + (bolt/nn/espresso/transforms/remove_nops.py), this pass recognizes and + removes concat operations that can be proven to produce no useful data + movement. Eliminating these at the FX/TOSA level prevents Vela from + generating MemoryCopy operations on the Ethos-U NPU. + + Five patterns are handled: + + 1. Single-input concat: cat([x], dim) is a no-op; replace with x. + 2. Concat-then-slice (exact): if a consumer of cat([a, b, ...], dim) is + a slice_copy that extracts exactly one original input, replace it + with the corresponding concat input directly. + 3. Slice-then-concat (full): if cat([slice(x, d, s0, e0), + slice(x, d, s1, e1), ...], dim) reconstructs x exactly (contiguous + slices covering the full source dimension), replace with x. + 4. Concat-then-sub-slice: if a consumer of cat([a, b, ...], dim) is a + slice_copy whose range falls entirely within one original input, + replace it with an adjusted slice on that input directly. + 5. Slice-then-concat (partial): if contiguous slices of the same tensor + are concatenated but cover only a sub-range of the source dimension, + replace with a single slice on the source. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + cat_ops = { + exir_ops.edge.aten.cat.default, + } + slice_op = exir_ops.edge.aten.slice_copy.Tensor + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + graph = graph_module.graph + + for node in list(graph.nodes): + if node.op != "call_function" or node.target not in self.cat_ops: + continue + if node.graph is None: + continue + + if self._eliminate_single_input_cat(node): + modified = True + continue + + if self._eliminate_cat_then_slice(node): + modified = True + continue + + if self._eliminate_slice_then_cat(node): + modified = True + continue + + if modified: + graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) + + # ------------------------------------------------------------------ + # Pattern 1: single-input cat + # ------------------------------------------------------------------ + @staticmethod + def _eliminate_single_input_cat(cat_node: torch.fx.Node) -> bool: + inputs = cat_node.args[0] + if not isinstance(inputs, (list, tuple)) or len(inputs) != 1: + return False + cat_node.replace_all_uses_with(inputs[0]) + logger.debug("Eliminated single-input cat: %s", cat_node.name) + return True + + # ------------------------------------------------------------------ + # Patterns 2 & 4: cat -> slice (exact input or sub-range of input) + # ------------------------------------------------------------------ + @staticmethod + def _eliminate_cat_then_slice( + cat_node: torch.fx.Node, + ) -> bool: + cat_inputs = cat_node.args[0] + if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2: + return False + + cat_dim = cat_node.args[1] if len(cat_node.args) > 1 else 0 + output_rank = len(get_first_fake_tensor(cat_node).shape) + cat_dim = (cat_dim + output_rank) % output_rank + + users = list(cat_node.users.keys()) + if not users: + return False + + # Every user must be a slice_copy on the same dim with step=1. + for user in users: + if user.target != FuseConcatPass.slice_op: + return False + if user.args[0] is not cat_node: + return False + slice_dim = user.args[1] if len(user.args) > 1 else 0 + slice_dim = (slice_dim + output_rank) % output_rank + if slice_dim != cat_dim: + return False + slice_step = user.args[4] if len(user.args) > 4 else 1 + if slice_step != 1: + return False + + # Build the offset map for each concat input along cat_dim. + offsets = [] + offset = 0 + for inp in cat_inputs: + inp_shape = get_first_fake_tensor(inp).shape + size = inp_shape[cat_dim] + offsets.append((offset, offset + size, inp)) + offset += size + + # For each user, try exact match (Pattern 2) then sub-range (Pattern 4). + # Users that cross input boundaries are skipped. + replacements: list[tuple[torch.fx.Node, torch.fx.Node]] = [] + + for user in users: + s_start = user.args[2] if len(user.args) > 2 else 0 + s_end = user.args[3] if len(user.args) > 3 else offset + s_end = min(s_end, offset) + + for o_start, o_end, inp in offsets: + if s_start == o_start and s_end == o_end: + # Pattern 2: exact match — replace slice with input. + replacements.append((user, inp)) + break + if s_start >= o_start and s_end <= o_end: + # Pattern 4: sub-range — replace with slice on original. + adj_start = s_start - o_start + adj_end = s_end - o_start + graph = cat_node.graph + with graph.inserting_before(user): + new_slice = graph.call_function( + FuseConcatPass.slice_op, + (inp, cat_dim, adj_start, adj_end), + ) + new_slice.meta = user.meta.copy() + replacements.append((user, new_slice)) + break + + if not replacements: + return False + + for old_node, new_node in replacements: + old_node.replace_all_uses_with(new_node) + + logger.debug( + "Eliminated cat-then-slice pattern: %s (%d slices redirected)", + cat_node.name, + len(replacements), + ) + return True + + # ------------------------------------------------------------------ + # Patterns 3 & 5: slice -> cat (contiguous slices, full or partial) + # ------------------------------------------------------------------ + @staticmethod + def _eliminate_slice_then_cat( + cat_node: torch.fx.Node, + ) -> bool: + cat_inputs = cat_node.args[0] + if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2: + return False + + cat_dim = cat_node.args[1] if len(cat_node.args) > 1 else 0 + output_rank = len(get_first_fake_tensor(cat_node).shape) + cat_dim = (cat_dim + output_rank) % output_rank + + # All inputs must be slice_copy on the same source tensor and dim, + # with step=1. + source_node = None + for inp in cat_inputs: + if not isinstance(inp, torch.fx.Node): + return False + if inp.target != FuseConcatPass.slice_op: + return False + slice_source = inp.args[0] + slice_dim = inp.args[1] if len(inp.args) > 1 else 0 + inp_rank = len(get_first_fake_tensor(inp).shape) + slice_dim = (slice_dim + inp_rank) % inp_rank + if slice_dim != cat_dim: + return False + slice_step = inp.args[4] if len(inp.args) > 4 else 1 + if slice_step != 1: + return False + if source_node is None: + source_node = slice_source + elif slice_source is not source_node: + return False + + if source_node is None: + return False + + source_shape = get_first_fake_tensor(source_node).shape + source_dim_size = source_shape[cat_dim] + + # Verify slices are contiguous (but not necessarily starting at 0). + first_inp = cat_inputs[0] + first_start = first_inp.args[2] if len(first_inp.args) > 2 else 0 + expected_start = first_start + for inp in cat_inputs: + s_start = inp.args[2] if len(inp.args) > 2 else 0 + s_end = inp.args[3] if len(inp.args) > 3 else source_dim_size + s_end = min(s_end, source_dim_size) + if s_start != expected_start: + return False + expected_start = s_end + last_end = expected_start + + # Verify output shape matches expectations. + cat_shape = get_first_fake_tensor(cat_node).shape + + if first_start == 0 and last_end == source_dim_size: + # Pattern 3: full coverage — replace with source tensor. + if list(cat_shape) != list(source_shape): + return False + cat_node.replace_all_uses_with(source_node) + logger.debug( + "Eliminated slice-then-cat (full): %s -> %s", + cat_node.name, + source_node.name, + ) + else: + # Pattern 5: partial coverage — replace with single slice. + expected_dim_size = last_end - first_start + if cat_shape[cat_dim] != expected_dim_size: + return False + for i, (cs, ss) in enumerate(zip(cat_shape, source_shape)): + if i != cat_dim and cs != ss: # dims must match except for cat_dim + return False + graph = cat_node.graph + with graph.inserting_before(cat_node): + new_slice = graph.call_function( + FuseConcatPass.slice_op, + (source_node, cat_dim, first_start, last_end), + ) + new_slice.meta = cat_node.meta.copy() + cat_node.replace_all_uses_with(new_slice) + logger.debug( + "Eliminated slice-then-cat (partial): %s -> slice(%s, %d, %d:%d)", + cat_node.name, + source_node.name, + cat_dim, + first_start, + last_end, + ) + return True diff --git a/backends/arm/test/passes/test_fuse_concat_pass.py b/backends/arm/test/passes/test_fuse_concat_pass.py new file mode 100644 index 00000000000..1f8f77c17a5 --- /dev/null +++ b/backends/arm/test/passes/test_fuse_concat_pass.py @@ -0,0 +1,193 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm._passes.fuse_concat_pass import FuseConcatPass +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + + +cat_op = "executorch_exir_dialects_edge__ops_aten_cat_default" +slice_op = "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor" + + +class SingleInputCat(torch.nn.Module): + """Pattern 1: cat with a single input is a no-op.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.cat([x], dim=0) + + data = (torch.randn(2, 3, 4),) + ops_before_pass = {cat_op: 1} + ops_after_pass: dict = {} + ops_not_after_pass = [cat_op] + + +class CatThenSlice(torch.nn.Module): + """Pattern 2: cat followed by slices that extract exactly the inputs.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]: + combined = torch.cat([a, b], dim=1) + # Extract exactly a and b back out + part_a = combined[:, :3, :] + part_b = combined[:, 3:, :] + return part_a + 1, part_b + 1 + + data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1, slice_op: 2} + ops_after_pass: dict = {} + ops_not_after_pass = [cat_op, slice_op] + + +class SliceThenCat(torch.nn.Module): + """Pattern 3: contiguous slices of the same tensor concatenated back.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = x[:, :3, :] + b = x[:, 3:, :] + return torch.cat([a, b], dim=1) + + data = (torch.randn(1, 8, 4),) + ops_before_pass = {cat_op: 1, slice_op: 2} + ops_after_pass: dict = {} + ops_not_after_pass = [cat_op, slice_op] + + +class CatNotEliminated(torch.nn.Module): + """Negative test: cat of different tensors should NOT be eliminated.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.cat([a, b], dim=1) + + data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1} + ops_after_pass = {cat_op: 1} + + +class SliceThenCatPartial(torch.nn.Module): + """Negative test: non-contiguous slices should NOT be eliminated.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = x[:, :3, :] + b = x[:, 4:, :] # Gap at index 3 + return torch.cat([a, b], dim=1) + + data = (torch.randn(1, 8, 4),) + ops_before_pass = {cat_op: 1} + ops_after_pass = {cat_op: 1} + + +class CatThenSliceMismatch(torch.nn.Module): + """Negative test: slices that don't match original inputs.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + combined = torch.cat([a, b], dim=1) + return combined[:, 1:5, :] # Crosses the boundary + + data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1} + ops_after_pass = {cat_op: 1} + + +class CatThenSliceWithStep(torch.nn.Module): + """Negative test: slices with step != 1 should NOT be eliminated.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]: + combined = torch.cat([a, b], dim=1) + part_a = combined[:, :3:2, :] # step=2, output shape differs from a + part_b = combined[:, 3::1, :] + return part_a + 1, part_b + 1 + + data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4)) + ops_before_pass = {cat_op: 1} + ops_after_pass = {cat_op: 1} + + +class CatThenSubSlice(torch.nn.Module): + """Pattern 4: slice extracts a sub-range within one concat input.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + combined = torch.cat([a, b], dim=1) # a dim1=6, b dim1=4 + # Range [1,5) falls entirely within a's range [0,6) + return combined[:, 1:5, :] + 1 + + data = (torch.randn(1, 6, 4), torch.randn(1, 4, 4)) + ops_before_pass = {cat_op: 1, slice_op: 1} + ops_after_pass = {slice_op: 1} + ops_not_after_pass = [cat_op] + + +class CatThenSubSliceSecondInput(torch.nn.Module): + """Pattern 4: sub-slice within second concat input (tests offset adjust).""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + combined = torch.cat([a, b], dim=1) # a dim1=3, b dim1=8 + # Range [5,9) falls within b's range [3,11), adjusted to [2,6) on b + return combined[:, 5:9, :] + 1 + + data = (torch.randn(1, 3, 4), torch.randn(1, 8, 4)) + ops_before_pass = {cat_op: 1, slice_op: 1} + ops_after_pass = {slice_op: 1} + ops_not_after_pass = [cat_op] + + +class SliceThenCatPartialContiguous(torch.nn.Module): + """Pattern 5: contiguous slices covering a sub-range of the dimension.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = x[:, 2:5, :] + b = x[:, 5:8, :] + return torch.cat([a, b], dim=1) # Equivalent to x[:, 2:8, :] + + data = (torch.randn(1, 10, 4),) + ops_before_pass = {cat_op: 1, slice_op: 2} + ops_after_pass = {slice_op: 1} + ops_not_after_pass = [cat_op] + + +positive_tests = { + "single_input_cat": SingleInputCat(), + "cat_then_slice": CatThenSlice(), + "slice_then_cat": SliceThenCat(), + "cat_then_sub_slice": CatThenSubSlice(), + "cat_then_sub_slice_second_input": CatThenSubSliceSecondInput(), + "slice_then_cat_partial_contiguous": SliceThenCatPartialContiguous(), +} + +negative_tests = { + "cat_not_eliminated": CatNotEliminated(), + "slice_then_cat_partial": SliceThenCatPartial(), + "cat_then_slice_mismatch": CatThenSliceMismatch(), + "cat_then_slice_with_step": CatThenSliceWithStep(), +} + + +@common.parametrize("model", positive_tests) +def test_fuse_concat_eliminates(model): + pipeline = PassPipeline( + model, + model.data, + quantize=False, + ops_before_pass=model.ops_before_pass, + ops_after_pass=model.ops_after_pass, + ops_not_after_pass=getattr(model, "ops_not_after_pass", []), + pass_list=[FuseConcatPass], + ) + pipeline.run() + + +@common.parametrize("model", negative_tests) +def test_fuse_concat_preserves(model): + pipeline = PassPipeline( + model, + model.data, + quantize=False, + ops_before_pass=model.ops_before_pass, + ops_after_pass=model.ops_after_pass, + pass_list=[FuseConcatPass], + ) + pipeline.run()