Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ def call(self, graph_module: torch.fx.GraphModule):
Entry point for the pass: annotate spatial ranks, compute dim orders,
insert bridging transposes, and forward to child passes.
"""
graph_module.graph.eliminate_dead_code()
nodes = list(graph_module.graph.nodes)
for node in nodes:
if not self._is_ok_for_annotation(node):
Expand Down
86 changes: 86 additions & 0 deletions backends/arm/test/misc/test_const_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,22 @@

from typing import Set, Type

import executorch.backends.arm.tosa.dialect # noqa: F401
import pytest
import torch
import tosa_serializer as ts
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.to_tosa_memory_format_pass import (
ToTosaMemoryFormatPass,
)
from executorch.backends.arm.operators.node_visitor import get_node_visitors
from executorch.backends.arm.process_node import process_call_function
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.backends.test.graph_builder import GraphBuilder
from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand Down Expand Up @@ -54,3 +67,76 @@ def forward(self, x):
assert const_shape_nodes
for n in const_shape_nodes:
assert n.meta[TosaSpecialDtype.meta_key()] == TosaSpecialDtype.SHAPE


def _graph_module_with_unused_const_shape():
with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")):
builder = GraphBuilder()
builder.call_operator(exir_ops.backend.tosa.CONST_SHAPE.default, ([1],))
live_const = builder.call_operator(
exir_ops.backend.tosa.CONST_SHAPE.default, ([3],)
)
builder.output([live_const])
graph_module = ExportPass().call(builder.get_graph_module()).graph_module
for node in graph_module.graph.nodes:
if node.op == "call_function":
node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
return graph_module


def _propagate_shape_dim_orders_from_users(graph_module: torch.fx.GraphModule) -> None:
output_node = next(node for node in graph_module.graph.nodes if node.op == "output")
output_node.meta["tosa_dim_order"] = (0,)
dummy_exported = torch.export.export(torch.nn.Identity(), (torch.randn(1),))
tosa_memory_format_pass = ToTosaMemoryFormatPass(dummy_exported)
tosa_memory_format_pass._propagate_dim_order_to_shape_args(output_node)


def _serialize_graph_module_to_tosa(graph_module: torch.fx.GraphModule):
tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
node_visitors = get_node_visitors(None, tosa_spec)
tosa_graph = ts.TosaSerializer(
"",
targetMajor=tosa_spec.version.major,
targetMinor=tosa_spec.version.minor,
targetPatch=tosa_spec.version.micro,
targetDraft=True,
)

for node in graph_module.graph.nodes:
if node.op == "call_function":
process_call_function(node, tosa_graph, node_visitors, tosa_spec)

return tosa_graph


def test_unused_shape_ops_miss_tosa_dim_order_and_must_be_removed_before_tosa_serialization():
graph_module = _graph_module_with_unused_const_shape()
_propagate_shape_dim_orders_from_users(graph_module)

const_shape_nodes = [
node
for node in graph_module.graph.nodes
if node.op == "call_function"
and node.target == exir_ops.backend.tosa.CONST_SHAPE.default
]
dead_const_shape, live_const_shape = const_shape_nodes

assert dead_const_shape.users == {}
assert "tosa_dim_order" not in dead_const_shape.meta
assert live_const_shape.meta["tosa_dim_order"] == (0,)

with pytest.raises(KeyError, match="tosa_dim_order"):
_serialize_graph_module_to_tosa(graph_module)

graph_module.graph.eliminate_dead_code()
graph_module.recompile()

remaining_const_shape = next(
node
for node in graph_module.graph.nodes
if node.op == "call_function"
and node.target == exir_ops.backend.tosa.CONST_SHAPE.default
)
assert remaining_const_shape.meta["tosa_dim_order"] == (0,)
assert _serialize_graph_module_to_tosa(graph_module)
Loading