Skip to content

Commit 9aeea5f

Browse files
authored
Merge branch 'main' into per-weight-constant-cache
2 parents 865f118 + 87e65ac commit 9aeea5f

File tree

20 files changed

+2836
-41
lines changed

20 files changed

+2836
-41
lines changed

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from .decompose_glu_pass import DecomposeGluPass # noqa
5454
from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa
5555
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
56+
from .decompose_gru_pass import DecomposeGruPass # noqa
5657
from .decompose_index_copy_pass import DecomposeIndexCopyPass # noqa
5758
from .decompose_index_select_to_gather_pass import ( # noqa
5859
DecomposeIndexSelectToGatherPass,
@@ -70,13 +71,15 @@
7071
from .decompose_linear_pass import DecomposeLinearPass # noqa
7172
from .decompose_log1p_pass import DecomposeLog1pPass # noqa
7273
from .decompose_logit_pass import DecomposeLogitPass # noqa
74+
from .decompose_lstm_pass import DecomposeLstmPass # noqa
7375
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
7476
from .decompose_matmul import DecomposeMatmulPass # noqa
7577
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
7678
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
7779
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
7880
from .decompose_quant_nodes import DecomposeQuantNodesPass # noqa
7981
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
82+
from .decompose_rnn_pass import DecomposeRnnPass # noqa
8083
from .decompose_round_pass import DecomposeRoundPass # noqa
8184
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
8285
from .decompose_select import DecomposeSelectPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
DecomposeGluPass,
6262
DecomposeGroupedConvPass,
6363
DecomposeGroupNormPass,
64+
DecomposeGruPass,
6465
DecomposeIndexCopyPass,
6566
DecomposeIndexSelectToGatherPass,
6667
DecomposeIndexTensorToGatherPass,
@@ -71,13 +72,15 @@
7172
DecomposeLinearPass,
7273
DecomposeLog1pPass,
7374
DecomposeLogitPass,
75+
DecomposeLstmPass,
7476
DecomposeMaskedFillPass,
7577
DecomposeMatmulPass,
7678
DecomposeMaxPool2dPass,
7779
DecomposeMeanDimPass,
7880
DecomposeNotEqualPass,
7981
DecomposeQuantNodesPass,
8082
DecomposeRemainderPass,
83+
DecomposeRnnPass,
8184
DecomposeRoundPass,
8285
DecomposeScaledDotProductAttentionPass,
8386
DecomposeSelectPass,
@@ -360,6 +363,9 @@ def _tosa_pipeline(
360363
ConvertToClampPass(),
361364
DecomposeTOSAUnsupportedClampPass(),
362365
DecomposeGroupNormPass(),
366+
DecomposeGruPass(),
367+
DecomposeLstmPass(),
368+
DecomposeRnnPass(),
363369
DecomposeLayerNormPass(),
364370
DecomposeVarPass(),
365371
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
@@ -578,6 +584,9 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
578584
self.add_passes(
579585
[
580586
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
587+
DecomposeGruPass(tfa_pass=True),
588+
DecomposeLstmPass(tfa_pass=True),
589+
DecomposeRnnPass(tfa_pass=True),
581590
DecomposeNotEqualPass(tfa_pass=True),
582591
DecomposeCosineSimilarityPass(tfa_pass=True),
583592
DecomposeGluPass(tfa_pass=True),

0 commit comments

Comments
 (0)