diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 83e4e9bad37..1079e43d3b9 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -24,6 +24,7 @@
 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
 from .decompose_maxpool3d import DecomposeMaxPool3d
 from .decompose_minmaxdim import DecomposeMinMaxDim
+from .decompose_reciprocal import DecomposeReciprocal
 from .decompose_roll import DecomposeRoll
 from .decompose_silu import DecomposeSilu
 from .decompose_threshold import DecomposeThreshold
@@ -72,6 +73,7 @@
     DecomposeLinalgVectorNorm,
     DecomposeMaxPool3d,
     DecomposeMinMaxDim,
+    DecomposeReciprocal,
     DecomposeRoll,
     DecomposeSilu,
     DecomposeThreshold,
diff --git a/backends/qualcomm/_passes/decompose_reciprocal.py b/backends/qualcomm/_passes/decompose_reciprocal.py
new file mode 100644
index 00000000000..405d99d1171
--- /dev/null
+++ b/backends/qualcomm/_passes/decompose_reciprocal.py
@@ -0,0 +1,41 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from .utils import copy_meta
+
+
+class DecomposeReciprocal(ExportPass):
+    def __init__(self):
+        super().__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if node.target in {
+                torch.ops.aten.reciprocal.default,
+            }:
+                reciprocal_node = node
+                reciprocal_node_input = node.args[0]
+                # Insert before the reciprocal node (not after its input) so the
+                # replacement can never land inside the graph's placeholder region.
+                with graph.inserting_before(reciprocal_node):
+                    # Create division node: reciprocal(x) == 1 / x
+                    div_node = graph.call_function(
+                        torch.ops.aten.div.Tensor,
+                        (1, reciprocal_node_input),
+                    )
+                    div_node.meta = copy_meta(reciprocal_node.meta)
+
+                # Replace all uses of reciprocal with division
+                for user in reciprocal_node.users.copy():
+                    user.replace_input_with(reciprocal_node, div_node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index 5f4168c1770..921cff57b59 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -29,6 +29,7 @@
     DecomposeLinalgVectorNorm,
     DecomposeMaxPool3d,
     DecomposeMinMaxDim,
+    DecomposeReciprocal,
     DecomposeRoll,
     DecomposeSilu,
     DecomposeThreshold,
@@ -215,6 +216,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeEinsum())
         self.add_pass(DecomposeExpM1())
         self.add_pass(DecomposeGlu())
+        # HTP and GPU don't support ElementWiseUnary with operation=reciprocal.
+        # Decompose reciprocal into div for these two backends.
+        # TODO: Skip this pass for CPU backend (dependency: backend-aware pass manager)
+        self.add_pass(DecomposeReciprocal())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(ReplaceInfValues())
         self.add_pass(LiftConstantScalarOperands())
@@ -238,6 +243,10 @@ def transform_for_export_pipeline(
         # This pass is needed before to_edge pipeline to avoid mixed type for div operator with RemoveMixedTypeOperators pass.
         self.add_pass(DecomposeFloorDivide())
         self.add_pass(DecomposeWrapWithAutocast())
+        # HTP and GPU don't support ElementWiseUnary with operation=reciprocal.
+        # Decompose reciprocal into div for these two backends.
+        # TODO: Skip this pass for CPU backend (dependency: backend-aware pass manager)
+        self.add_pass(DecomposeReciprocal())
         # this pass will rewrite state_dict, it needs to be accomplished before
         # to_edge_transform_and_lower
         self.add_pass(CanonicalizeConv(exported_program))
diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md
index a0643bd4f1d..bca15cba0fa 100644
--- a/backends/qualcomm/builders/README.md
+++ b/backends/qualcomm/builders/README.md
@@ -422,7 +422,7 @@ Please help update following table if you are contributing new operators:
 | ElementWiseSquaredDifference | ✗ |
 | ElementWiseSquareRoot | ✓ |
 | ElementWiseSubtract | ✓ |
-| ElementWiseUnary | ✗ |
+| ElementWiseUnary | ✓ |
 | ElementWiseXor | ✓ |
 | Elu | ✓ |
 | ExpandDims | ✓ |
diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index e982985477d..efe9434ff0b 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -104,6 +104,7 @@
     op_to,
     op_topk,
     op_transpose,
+    op_unary,
     op_unbind,
     op_unsqueeze,
     op_upsample_bilinear2d,
@@ -185,6 +186,7 @@
     op_pow,
     op_prelu,
     op_quantize,
+    op_unary,
     op_relu,
     op_repeat,
     op_reshape,
diff --git a/backends/qualcomm/builders/op_unary.py b/backends/qualcomm/builders/op_unary.py
new file mode 100644
index 00000000000..b2814933679
--- /dev/null
+++ b/backends/qualcomm/builders/op_unary.py
@@ -0,0 +1,66 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+
+from .node_visitor import NodeVisitor
+from .node_visitor_manager import register_node_visitor
+from .qnn_constants import OpElementWiseUnary, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Unary(NodeVisitor):
+    target = ["aten.reciprocal.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        reciprocal_inp_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        reciprocal_input_tensors = [reciprocal_inp_tensor_wrapper]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        reciprocal_output_tensors = [output_tensor_wrapper]
+
+        reciprocal_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseUnary.op_name,
+        )
+        reciprocal_op.AddInputTensors(reciprocal_input_tensors)
+        reciprocal_op.AddOutputTensors(reciprocal_output_tensors)
+
+        reciprocal_op.AddScalarParam(
+            OpElementWiseUnary.param_operation,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(OpElementWiseUnary.Operation.RECIPROCAL)},
+        )
+
+        return reciprocal_op
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index a27f793d923..420feb202d9 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -280,6 +280,31 @@ class OpElementWiseSubtract:
     op_name = "ElementWiseSubtract"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseUnary:
+    op_name: str = "ElementWiseUnary"
+    param_operation: str = "operation"
+
+    @unique
+    class Operation(IntEnum):
+        ABS = 0
+        ASIN = 1
+        ATAN = 2
+        CEIL = 3
+        COS = 4
+        EXP = 5
+        FLOOR = 6
+        LOG = 7
+        NEG = 8
+        NOT = 9
+        RECIPROCAL = 10
+        ROUND = 11
+        RSQRT = 12
+        SIGN = 13
+        SIN = 14
+        SQRT = 15
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseXor:
     op_name: str = "ElementWiseXor"
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 2089fea26bf..7fec9f63688 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1765,6 +1765,14 @@ def forward(self, x):
         return self.prelu(x)
 
 
+class Reciprocal(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.reciprocal(x)
+
+
 class Relu(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index dd164593134..b738f784d1e 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1606,6 +1606,11 @@ def test_qnn_backend_prelu(self):
             index += 1
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_reciprocal(self):
+        module = Reciprocal()  # noqa: F405
+        sample_input = (torch.randn([2, 2, 2, 2]),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_relu(self):
         module = Relu()  # noqa: F405
         sample_input = (torch.randn([2, 5, 1, 3]),)
@@ -3930,6 +3935,12 @@ def test_qnn_backend_prelu(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_reciprocal(self):
+        module = Reciprocal()  # noqa: F405
+        sample_input = (torch.randn([2, 5, 1, 3]),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_relu(self):
         module = Relu()  # noqa: F405
         sample_input = (torch.randn([2, 5, 1, 3]),)