diff --git a/backends/arm/ethosu/partitioner.py b/backends/arm/ethosu/partitioner.py index 7b2c561fae2..4ce5911ec1f 100644 --- a/backends/arm/ethosu/partitioner.py +++ b/backends/arm/ethosu/partitioner.py @@ -3,9 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - from typing import final, Optional, Sequence +import torch from executorch.backends.arm.ethosu import EthosUBackend, EthosUCompileSpec from executorch.backends.arm.tosa.partitioner import TOSAPartitioner from executorch.exir.backend.partitioner import DelegationSpec @@ -33,3 +33,4 @@ def __init__( ) self.additional_checks = additional_checks self.tosa_spec = compile_spec.tosa_spec + self._custom_partition_ops: set[torch._ops.OpOverload] = set() diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index c7f9da2ccd4..2e46f9c5f76 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -51,6 +51,7 @@ op_to_dim_order_copy, op_tosa_conv2d, op_tosa_conv3d, + op_tosa_custom, op_tosa_depthwise_conv2d, op_tosa_gather, op_tosa_matmul, diff --git a/backends/arm/operators/op_tosa_custom.py b/backends/arm/operators/op_tosa_custom.py new file mode 100644 index 00000000000..82e7c5cffd8 --- /dev/null +++ b/backends/arm/operators/op_tosa_custom.py @@ -0,0 +1,85 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
@register_node_visitor
class CustomVisitor(NodeVisitor):
    """Lower the TOSA CUSTOM op from the TOSA backend dialect.

    The node's ``operator_name``, ``domain_name`` and optional
    ``implementation_attrs`` keyword arguments are packed into a TOSA custom
    attribute, and the tensors listed in the first argument become the
    operands of a single-output ``CUSTOM`` operator.

    """

    target = "tosa.CUSTOM.default"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Serialize ``node`` as a TOSA ``CUSTOM`` operator.

        Args:
            node: FX node whose kwargs carry ``operator_name``,
                ``domain_name`` and, optionally, ``implementation_attrs``.
            tosa_graph: Graph object forwarded to ``_serialize_operator``.
            inputs: Lowered arguments; ``inputs[0].special`` must hold the
                list of input tensors.
            output: Lowered output argument.

        Raises:
            ValueError: On unexpected kwargs, a missing ``operator_name`` or
                ``domain_name``, a missing or empty tensor-list input, or
                more than one output.
            TypeError: If ``implementation_attrs`` is neither ``None`` nor a
                list.

        """
        allowed_kwargs = {"operator_name", "domain_name", "implementation_attrs"}
        unexpected = set(node.kwargs) - allowed_kwargs
        if unexpected:
            raise ValueError(
                f"tosa.CUSTOM received unexpected kwargs: {sorted(unexpected)}"
            )

        operator_name = node.kwargs.get("operator_name")
        domain_name = node.kwargs.get("domain_name")
        implementation_attrs = node.kwargs.get("implementation_attrs")

        if operator_name is None or domain_name is None:
            raise ValueError(
                "tosa.CUSTOM requires operator_name and domain_name in kwargs"
            )

        if implementation_attrs is None:
            impl_list = []
        elif isinstance(implementation_attrs, list):
            # NOTE: PyTorch schemas do not support a bytes type; we pass
            # implementation_attrs as int[] representing raw bytes.
            impl_list = [int(x) for x in implementation_attrs]
        else:
            raise TypeError(
                "implementation_attrs must be None or list[int]; "
                f"got {type(implementation_attrs)}"
            )

        attr = ts.TosaSerializerAttribute()
        attr.CustomAttribute(
            operator_name=operator_name,
            domain_name=domain_name,
            implementation_attrs=impl_list,
        )

        # Fail with a descriptive error instead of a bare IndexError/TypeError
        # when the tensor-list argument is absent, not a list, or empty. The
        # fake CUSTOM op likewise rejects an empty input list.
        tensor_items = getattr(inputs[0], "special", None) if inputs else None
        if not tensor_items:
            raise ValueError(
                "tosa.CUSTOM requires a non-empty list of input tensors as its "
                "first argument"
            )
        input_names = [
            TosaArg(item, self.tosa_spec).name for item in tensor_items
        ]

        output_names = getattr(output, "multiple_output_names", None) or [
            output.name
        ]
        if len(output_names) != 1:
            # TODO: Support multi-output CUSTOM ops with per-output meta/shape.
            raise ValueError(
                f"tosa.CUSTOM currently requires a single output, got {len(output_names)}"
            )

        self._serialize_operator(
            node,
            tosa_graph,
            ts.Op.CUSTOM,
            input_names,
            output_names,
            attr,
        )
executorch.backends.arm.tosa.dialect.ops import ( # noqa F401 conv2d, conv3d, + custom, depthwise_conv2d, gather, matmul, diff --git a/backends/arm/tosa/dialect/ops/custom.py b/backends/arm/tosa/dialect/ops/custom.py new file mode 100644 index 00000000000..6376124d6f2 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/custom.py @@ -0,0 +1,159 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Fake-op support for the generic TOSA ``CUSTOM`` dialect op. + +The serialized TOSA ``CUSTOM`` op is intentionally generic: it carries a +stable operator identity (for example ``myns.my_op``) plus an +opaque payload in ``implementation_attrs``. That is enough for serialization, +but not enough for FakeTensor propagation unless we also teach the compiler how +to model the output tensors of the specific wrapped op. + +This module provides a lightweight registration mechanism for those compiler +side fake implementations: + +1. A lowering pass rewrites an op to ``exir_ops.backend.tosa.CUSTOM.default``. +2. The wrapped custom op registers a thin adapter with + ``@register_fake_tosa("namespace::op")``. +3. The generic ``CUSTOM`` fake implementation looks up that adapter by the + ``operator_name`` argument and invokes it with the full custom-op calling + convention ``(inputs, operator_name, domain_name, implementation_attrs)``. + +The adapter should stay thin: it should only translate from the generic TOSA +CUSTOM signature back to the wrapped op's fake semantics. The real semantic +logic should continue to live in the original fake implementation where +possible. 
# Registry of compiler-side fake implementations, keyed by the normalized
# (dot-separated) operator name.
_TOSA_CUSTOM_FAKE_IMPLS: dict[str, Callable] = {}


def _normalize_tosa_custom_operator_name(operator_name: str) -> str:
    """Normalize operator names so ``ns::op`` and ``ns.op`` map identically."""
    return operator_name.replace("::", ".")


def validate_tosa_custom_fake_impl(fake_impl: object) -> Callable:
    """Validate the signature expected by ``register_fake_tosa``.

    Registered fake implementations must accept the generic TOSA CUSTOM fake
    calling convention:

    ``(inputs, operator_name, domain_name, implementation_attrs)``

    and return ``list[Tensor]``.

    Args:
        fake_impl: Candidate fake implementation.

    Returns:
        ``fake_impl`` unchanged, once validated.

    Raises:
        TypeError: If ``fake_impl`` is not callable, has no introspectable
            signature, or does not take exactly four positional parameters.

    """
    if not callable(fake_impl):
        raise TypeError(
            "Expected tosa.CUSTOM fake impl to be callable, " f"got {type(fake_impl)}"
        )

    try:
        signature = inspect.signature(fake_impl)
    except ValueError as err:
        # Some callables (e.g. certain builtins) expose no signature. Report
        # that through the same TypeError contract as every other rejection
        # instead of leaking inspect's ValueError.
        raise TypeError(
            "tosa.CUSTOM fake impl must have signature "
            "(inputs, operator_name, domain_name, implementation_attrs)"
        ) from err

    params = tuple(signature.parameters.values())
    positional_kinds = {
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    }
    if len(params) != 4 or any(param.kind not in positional_kinds for param in params):
        raise TypeError(
            "tosa.CUSTOM fake impl must have signature "
            "(inputs, operator_name, domain_name, implementation_attrs)"
        )
    return fake_impl


def register_fake_tosa(operator_name: str) -> Callable[[Callable], Callable]:
    """Register a fake implementation for a specific wrapped TOSA custom op.

    Args:
        operator_name: Stable custom operator identifier. Both ``ns::op`` and
            ``ns.op`` spellings are accepted.

    Returns:
        A decorator that registers a callable with signature
        ``(inputs, operator_name, domain_name, implementation_attrs)`` and
        returning ``list[Tensor]``. The decorator returns the callable
        unchanged; a later registration under the same name replaces the
        earlier one.

    Example:
        ``@register_fake_tosa("my_namespace::my_op")``

    """
    normalized_name = _normalize_tosa_custom_operator_name(operator_name)

    def decorator(fake_impl: Callable) -> Callable:
        validated = validate_tosa_custom_fake_impl(fake_impl)
        _TOSA_CUSTOM_FAKE_IMPLS[normalized_name] = validated
        return fake_impl

    return decorator


def has_fake_tosa_impl(operator_name: str) -> bool:
    """Return whether a wrapped custom op has a registered fake impl."""
    normalized_name = _normalize_tosa_custom_operator_name(operator_name)
    return normalized_name in _TOSA_CUSTOM_FAKE_IMPLS


def run_registered_fake_tosa_impl(
    inputs: list[torch.Tensor],
    operator_name: str,
    domain_name: str,
    implementation_attrs: list[int],
) -> list[torch.Tensor]:
    """Invoke the registered fake implementation for a wrapped custom op.

    Args:
        inputs: Input tensors forwarded to the fake implementation.
        operator_name: Custom operator identifier used for the lookup; it is
            forwarded to the fake implementation as given (not normalized).
        domain_name: Forwarded unchanged to the fake implementation.
        implementation_attrs: Forwarded unchanged to the fake implementation.

    Returns:
        The non-empty ``list[Tensor]`` produced by the fake implementation.

    Raises:
        RuntimeError: If no fake implementation is registered for the
            operator, or the implementation returns an empty list.
        TypeError: If the implementation does not return a list of tensors.

    """
    normalized_name = _normalize_tosa_custom_operator_name(operator_name)
    fake_impl = _TOSA_CUSTOM_FAKE_IMPLS.get(normalized_name)
    if fake_impl is None:
        raise RuntimeError(
            f"tosa.CUSTOM requires a registered fake impl for {normalized_name}"
        )
    outputs = fake_impl(inputs, operator_name, domain_name, implementation_attrs)
    if not isinstance(outputs, list):
        raise TypeError(
            "tosa.CUSTOM fake impl must return list[Tensor], " f"got {type(outputs)}"
        )
    if not outputs:
        raise RuntimeError("tosa.CUSTOM fake impl must return at least one output")
    if not all(isinstance(output, torch.Tensor) for output in outputs):
        raise TypeError("tosa.CUSTOM fake impl must return list[Tensor]")
    return outputs
def _is_custom_partition_op(
    custom_ops: set[torch._ops.OpOverload], target: object
) -> bool:
    """Return whether ``target`` is one of the registered custom ops.

    A target matches when it is itself in ``custom_ops`` or when the object
    stored in its ``_op`` attribute is.

    Args:
        custom_ops: Ops registered as supported.
        target: Call target of the node being checked.

    Returns:
        True if ``target`` or ``target._op`` is in ``custom_ops``.

    """
    if target in custom_ops:
        return True
    wrapped = getattr(target, "_op", None)
    if wrapped is None:
        return False
    try:
        return wrapped in custom_ops
    except TypeError:
        # An unhashable ``_op`` cannot be a member of the set; treat it as
        # "not registered" without masking unrelated errors.
        return False
def register_custom_partition_op(self, op: torch._ops.OpOverload) -> None:
    """Mark ``op`` as supported by this partitioner.

    The op is added to the partitioner's set of custom partition ops;
    registering the same op again has no further effect.

    Args:
        op: Operator overload to register.

    """
    registered_ops = self._custom_partition_ops
    registered_ops.add(op)
- from typing import final, Optional, Sequence +import torch from executorch.backends.arm.tosa.partitioner import TOSAPartitioner from executorch.backends.arm.vgf import VgfBackend, VgfCompileSpec from executorch.exir.backend.partitioner import DelegationSpec @@ -33,3 +33,4 @@ def __init__( ) self.additional_checks = additional_checks self.tosa_spec = compile_spec.tosa_spec + self._custom_partition_ops: set[torch._ops.OpOverload] = set()