From d71c6153df81fc608a7eb78515be148ed882fd10 Mon Sep 17 00:00:00 2001 From: boyuc Date: Sat, 24 Jan 2026 20:18:52 +0800 Subject: [PATCH 1/7] Add op_builder for reciprocal op - Add constant and enums for QNN ElementwiseUnary op - Add op_builder for reciprocal op - Add pass decompose reciprocal to div because HTP and GPU backend doesn't support reciprocal operation yet --- backends/qualcomm/builders/__init__.py | 2 + backends/qualcomm/builders/op_reciprocal.py | 66 ++++++++++++++++++++ backends/qualcomm/builders/qnn_constants.py | 24 +++++++ backends/qualcomm/tests/models.py | 8 +++ backends/qualcomm/tests/test_qnn_delegate.py | 5 ++ 5 files changed, 105 insertions(+) create mode 100644 backends/qualcomm/builders/op_reciprocal.py diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index e982985477d..6b8f95c211c 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -77,6 +77,7 @@ op_pow, op_prelu, op_quantize, + op_reciprocal, op_relu, op_repeat, op_reshape, @@ -185,6 +186,7 @@ op_pow, op_prelu, op_quantize, + op_reciprocal, op_relu, op_repeat, op_reshape, diff --git a/backends/qualcomm/builders/op_reciprocal.py b/backends/qualcomm/builders/op_reciprocal.py new file mode 100644 index 00000000000..44d8bf146a4 --- /dev/null +++ b/backends/qualcomm/builders/op_reciprocal.py @@ -0,0 +1,66 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnWrapper +from executorch.backends.qualcomm.utils.constants import QCOM_DATA + +import numpy as np +import torch + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpElementWiseUnary, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Reciprocal(NodeVisitor): + target = ["aten.reciprocal.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + reciprocal_inp_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + reciprocal_input_tensors = [reciprocal_inp_tensor_wrapper] + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + reciprocal_output_tensors = [output_tensor_wrapper] + + reciprocal_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpElementWiseUnary.op_name, + ) + reciprocal_op.AddInputTensors(reciprocal_input_tensors) + reciprocal_op.AddOutputTensors(reciprocal_output_tensors) + + reciprocal_op.AddScalarParam( + OpElementWiseUnary.param_operation, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(OpElementWiseUnary.Operation.RECIPROCAL)} + ) + + return reciprocal_op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index a27f793d923..78f1ef7dc5b 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -280,6 +280,30 @@ class OpElementWiseSubtract: op_name = "ElementWiseSubtract" +@dataclass(init=False, frozen=True) +class OpElementWiseUnary: + op_name: str = "ElementWiseUnary" + param_operation: str = "operation" + @unique + class Operation(IntEnum): + ABS = 0 + ASIN = 1 + ATAN = 2 + CEIL = 3 + COS = 4 + EXP = 5 + FLOOR = 6 + LOG = 7 + NEG = 8 + NOT = 9 + RECIPROCAL = 10 + ROUND = 11 + RSQRT = 12 + SIGN = 13 + SIN = 14 + SQRT = 15 + + @dataclass(init=False, frozen=True) class OpElementWiseXor: op_name: str = "ElementWiseXor" diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 2089fea26bf..7fec9f63688 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -1765,6 +1765,14 @@ def forward(self, x): return self.prelu(x) +class Reciprocal(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.reciprocal(x) + + class Relu(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index dd164593134..31df526007a 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1606,6 +1606,11 @@ def test_qnn_backend_prelu(self): index += 1 self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_reciprocal(self): + module = Reciprocal() # noqa: F405 + sample_input = (torch.randn([2, 2, 2, 2]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_relu(self): module = Relu() # noqa: F405 sample_input = (torch.randn([2, 5, 1, 3]),) From a3f262aa7d19912850e7d65a6e2b4c4dcc919c8e Mon Sep 17 00:00:00 2001 From: boyuc Date: Sat, 24 Jan 2026 23:34:50 +0800 Subject: [PATCH 2/7] Add decomposition for reciprocal op - Add reciprocal decomposition in export transform --- backends/qualcomm/_passes/__init__.py | 2 + .../qualcomm/_passes/decompose_reciprocal.py | 55 +++++++++++++++++++ backends/qualcomm/_passes/qnn_pass_manager.py | 7 +++ 3 files changed, 64 insertions(+) create mode 100644 backends/qualcomm/_passes/decompose_reciprocal.py diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 83e4e9bad37..1079e43d3b9 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -24,6 +24,7 @@ from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm from .decompose_maxpool3d import DecomposeMaxPool3d from .decompose_minmaxdim import DecomposeMinMaxDim +from .decompose_reciprocal import DecomposeReciprocal from .decompose_roll import DecomposeRoll from .decompose_silu import DecomposeSilu from .decompose_threshold import DecomposeThreshold @@ -72,6 +73,7 @@ DecomposeLinalgVectorNorm, DecomposeMaxPool3d, DecomposeMinMaxDim, + DecomposeReciprocal, DecomposeRoll, DecomposeSilu, DecomposeThreshold, diff --git a/backends/qualcomm/_passes/decompose_reciprocal.py b/backends/qualcomm/_passes/decompose_reciprocal.py new file mode 100644 index 00000000000..e5dd736fea9 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_reciprocal.py @@ -0,0 +1,55 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass, PassResult +from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix + +from .utils import copy_meta + +class DecomposeReciprocal(ExportPass): + def __init__(self): + super(DecomposeReciprocal, self).__init__() + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + for node in graph.nodes: + if node.op == "call_function" and node.target in { + torch.ops.aten.reciprocal.default, + }: + reciprocal_node = node + reciprocal_node_input = node.args[0] + + # Create tensor of ones with same shape and dtype as input + fake_val = reciprocal_node_input.meta["val"] + ones_tensor = torch.ones(*fake_val.size(), dtype=fake_val.dtype) + + # Generate unique name and register buffer + buffer_name = get_new_attr_name_with_prefix("_ones_constant_")( + graph_module + ) + graph_module.register_buffer(buffer_name, ones_tensor) + + with graph_module.graph.inserting_after(reciprocal_node_input): + # Create get_attr node for the ones tensor + ones_node = graph.get_attr(buffer_name) + ones_node.meta = copy_meta(reciprocal_node.meta) + + with graph_module.graph.inserting_after(ones_node): + # Create division node: ones / input + div_node = graph.call_function( + torch.ops.aten.div.Tensor, + (ones_node, reciprocal_node_input), + ) + div_node.meta = copy_meta(reciprocal_node.meta) + + # Replace all uses of reciprocal with division + for user in reciprocal_node.users.copy(): + user.replace_input_with(reciprocal_node, div_node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 5f4168c1770..700a7b8ef36 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -29,6 +29,7 @@ DecomposeLinalgVectorNorm, DecomposeMaxPool3d, DecomposeMinMaxDim, + DecomposeReciprocal, DecomposeRoll, DecomposeSilu, DecomposeThreshold, @@ -238,6 +239,12 @@ def transform_for_export_pipeline( # This pass is needed before to_edge pipeline to avoid mixed type for div operator with RemoveMixedTypeOperators pass. self.add_pass(DecomposeFloorDivide()) self.add_pass(DecomposeWrapWithAutocast()) + + # HTP and GPU doesn't support ElementWiseUnary with operation=reciprocal + # Decompose Reciprocal into Div for these 2 backend + # TODO: Skip this pass for CPU backend (Dependency: Backend-aware passes manager) + self.add_pass(DecomposeReciprocal()) + # this pass will rewrite state_dict, it needs to be accomplished before # to_edge_transform_and_lower self.add_pass(CanonicalizeConv(exported_program)) From 5980e51803ed1b8326d57e867b40ae411478dceb Mon Sep 17 00:00:00 2001 From: boyuc Date: Sun, 25 Jan 2026 13:30:39 +0800 Subject: [PATCH 3/7] Run linter --- backends/qualcomm/_passes/decompose_reciprocal.py | 1 + backends/qualcomm/builders/op_reciprocal.py | 4 ++-- backends/qualcomm/builders/qnn_constants.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/backends/qualcomm/_passes/decompose_reciprocal.py b/backends/qualcomm/_passes/decompose_reciprocal.py index e5dd736fea9..4194677c079 100644 --- a/backends/qualcomm/_passes/decompose_reciprocal.py +++ b/backends/qualcomm/_passes/decompose_reciprocal.py @@ -10,6 +10,7 @@ from .utils import copy_meta + class DecomposeReciprocal(ExportPass): def __init__(self): super(DecomposeReciprocal, self).__init__() diff --git a/backends/qualcomm/builders/op_reciprocal.py b/backends/qualcomm/builders/op_reciprocal.py index 44d8bf146a4..909297a4272 100644 --- a/backends/qualcomm/builders/op_reciprocal.py +++ b/backends/qualcomm/builders/op_reciprocal.py @@ -6,10 +6,10 @@ from typing import Dict import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnWrapper -from executorch.backends.qualcomm.utils.constants import QCOM_DATA import numpy as np import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor from .node_visitor_manager import register_node_visitor @@ -60,7 +60,7 @@ def define_node( reciprocal_op.AddScalarParam( OpElementWiseUnary.param_operation, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - {QCOM_DATA: np.uint32(OpElementWiseUnary.Operation.RECIPROCAL)} + {QCOM_DATA: np.uint32(OpElementWiseUnary.Operation.RECIPROCAL)}, ) return reciprocal_op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 78f1ef7dc5b..420feb202d9 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -284,6 +284,7 @@ class OpElementWiseSubtract: class OpElementWiseUnary: op_name: str = "ElementWiseUnary" param_operation: str = "operation" + @unique class Operation(IntEnum): ABS = 0 From d00f795b6903e369f86286bd5f80924d21be9e9c Mon Sep 17 00:00:00 2001 From: boyuc Date: Sun, 25 Jan 2026 14:09:20 +0800 Subject: [PATCH 4/7] Add reciprocal decomposition for Quantized Model --- backends/qualcomm/_passes/decompose_reciprocal.py | 1 + backends/qualcomm/_passes/qnn_pass_manager.py | 6 ++++-- backends/qualcomm/tests/test_qnn_delegate.py | 6 ++++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/backends/qualcomm/_passes/decompose_reciprocal.py b/backends/qualcomm/_passes/decompose_reciprocal.py index 4194677c079..b18ac167950 100644 --- a/backends/qualcomm/_passes/decompose_reciprocal.py +++ b/backends/qualcomm/_passes/decompose_reciprocal.py @@ -38,6 +38,7 @@ def call(self, graph_module: torch.fx.GraphModule): # Create get_attr node for the ones tensor ones_node = graph.get_attr(buffer_name) ones_node.meta = copy_meta(reciprocal_node.meta) + ones_node.meta["val"] = reciprocal_node_input.meta["val"].clone() with graph_module.graph.inserting_after(ones_node): # Create division node: ones / input diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 700a7b8ef36..921cff57b59 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -216,6 +216,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeEinsum()) self.add_pass(DecomposeExpM1()) self.add_pass(DecomposeGlu()) + # HTP and GPU doesn't support ElementWiseUnary with operation=reciprocal + # Decompose Reciprocal into Div for these 2 backend + # TODO: Skip this pass for CPU backend (Dependency: Backend-aware passes manager) + self.add_pass(DecomposeReciprocal()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(ReplaceInfValues()) self.add_pass(LiftConstantScalarOperands()) @@ -239,12 +243,10 @@ def transform_for_export_pipeline( # This pass is needed before to_edge pipeline to avoid mixed type for div operator with RemoveMixedTypeOperators pass. self.add_pass(DecomposeFloorDivide()) self.add_pass(DecomposeWrapWithAutocast()) - # HTP and GPU doesn't support ElementWiseUnary with operation=reciprocal # Decompose Reciprocal into Div for these 2 backend # TODO: Skip this pass for CPU backend (Dependency: Backend-aware passes manager) self.add_pass(DecomposeReciprocal()) - # this pass will rewrite state_dict, it needs to be accomplished before # to_edge_transform_and_lower self.add_pass(CanonicalizeConv(exported_program)) diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 31df526007a..b738f784d1e 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -3935,6 +3935,12 @@ def test_qnn_backend_prelu(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_reciprocal(self): + module = Reciprocal() # noqa: F405 + sample_input = (torch.randn([2, 5, 1, 3]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_relu(self): module = Relu() # noqa: F405 sample_input = (torch.randn([2, 5, 1, 3]),) From 1abc8d1b789ee680714f200351756d7eb186af77 Mon Sep 17 00:00:00 2001 From: boyuc Date: Sun, 25 Jan 2026 14:26:16 +0800 Subject: [PATCH 5/7] Update README.md QNN OP table --- backends/qualcomm/builders/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index a0643bd4f1d..bca15cba0fa 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -422,7 +422,7 @@ Please help update following table if you are contributing new operators: | ElementWiseSquaredDifference | ✗ | | ElementWiseSquareRoot | ✓ | | ElementWiseSubtract | ✓ | -| ElementWiseUnary | ✗ | +| ElementWiseUnary | ✓ | | ElementWiseXor | ✓ | | Elu | ✓ | | ExpandDims | ✓ | From 8912f1db15a14e2fbfd751d89df9da26624a0fe5 Mon Sep 17 00:00:00 2001 From: boyuc Date: Tue, 27 Jan 2026 16:49:25 +0800 Subject: [PATCH 6/7] Address review comments - Remove logic for inserting constant tensor (leverage lift constant pass) - Rename Reciprocal Builder to Unary - Remove redundant check for node.op=="call_function" --- .../qualcomm/_passes/decompose_reciprocal.py | 39 ++++++------------- backends/qualcomm/builders/__init__.py | 4 +- .../{op_reciprocal.py => op_unary.py} | 2 +- 3 files changed, 14 insertions(+), 31 deletions(-) rename backends/qualcomm/builders/{op_reciprocal.py => op_unary.py} (98%) diff --git a/backends/qualcomm/_passes/decompose_reciprocal.py b/backends/qualcomm/_passes/decompose_reciprocal.py index b18ac167950..bfa24dd6b60 100644 --- a/backends/qualcomm/_passes/decompose_reciprocal.py +++ b/backends/qualcomm/_passes/decompose_reciprocal.py @@ -18,39 +18,22 @@ def __init__(self): def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph for node in graph.nodes: - if node.op == "call_function" and node.target in { + if node.target in { torch.ops.aten.reciprocal.default, }: reciprocal_node = node reciprocal_node_input = node.args[0] - - # Create tensor of ones with same shape and dtype as input - fake_val = reciprocal_node_input.meta["val"] - ones_tensor = torch.ones(*fake_val.size(), dtype=fake_val.dtype) - - # Generate unique name and register buffer - buffer_name = get_new_attr_name_with_prefix("_ones_constant_")( - graph_module - ) - graph_module.register_buffer(buffer_name, ones_tensor) - with graph_module.graph.inserting_after(reciprocal_node_input): - # Create get_attr node for the ones tensor - ones_node = graph.get_attr(buffer_name) - ones_node.meta = copy_meta(reciprocal_node.meta) - ones_node.meta["val"] = reciprocal_node_input.meta["val"].clone() - - with graph_module.graph.inserting_after(ones_node): - # Create division node: ones / input - div_node = graph.call_function( - torch.ops.aten.div.Tensor, - (ones_node, reciprocal_node_input), - ) - div_node.meta = copy_meta(reciprocal_node.meta) - - # Replace all uses of reciprocal with division - for user in reciprocal_node.users.copy(): - user.replace_input_with(reciprocal_node, div_node) + # Create division node + div_node = graph.call_function( + torch.ops.aten.div.Tensor, + (1, reciprocal_node_input), + ) + div_node.meta = copy_meta(reciprocal_node.meta) + + # Replace all uses of reciprocal with division + for user in reciprocal_node.users.copy(): + user.replace_input_with(reciprocal_node, div_node) graph.eliminate_dead_code() graph_module.recompile() diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 6b8f95c211c..a5b740f931a 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -77,7 +77,7 @@ op_pow, op_prelu, op_quantize, - op_reciprocal, + op_unary, op_relu, op_repeat, op_reshape, @@ -186,7 +186,7 @@ op_pow, op_prelu, op_quantize, - op_reciprocal, + op_unary, op_relu, op_repeat, op_reshape, diff --git a/backends/qualcomm/builders/op_reciprocal.py b/backends/qualcomm/builders/op_unary.py similarity index 98% rename from backends/qualcomm/builders/op_reciprocal.py rename to backends/qualcomm/builders/op_unary.py index 909297a4272..b2814933679 100644 --- a/backends/qualcomm/builders/op_reciprocal.py +++ b/backends/qualcomm/builders/op_unary.py @@ -17,7 +17,7 @@ @register_node_visitor -class Reciprocal(NodeVisitor): +class Unary(NodeVisitor): target = ["aten.reciprocal.default"] def __init__(self, *args) -> None: From 3af1b5e8dd66fd4c618fb82b87bf2ae026b1e4b7 Mon Sep 17 00:00:00 2001 From: boyuc Date: Tue, 27 Jan 2026 17:09:57 +0800 Subject: [PATCH 7/7] Run linter --- backends/qualcomm/_passes/decompose_reciprocal.py | 1 - backends/qualcomm/builders/__init__.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/backends/qualcomm/_passes/decompose_reciprocal.py b/backends/qualcomm/_passes/decompose_reciprocal.py index bfa24dd6b60..405d99d1171 100644 --- a/backends/qualcomm/_passes/decompose_reciprocal.py +++ b/backends/qualcomm/_passes/decompose_reciprocal.py @@ -6,7 +6,6 @@ import torch from executorch.exir.pass_base import ExportPass, PassResult -from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix from .utils import copy_meta diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index a5b740f931a..efe9434ff0b 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -77,7 +77,6 @@ op_pow, op_prelu, op_quantize, - op_unary, op_relu, op_repeat, op_reshape, @@ -105,6 +104,7 @@ op_to, op_topk, op_transpose, + op_unary, op_unbind, op_unsqueeze, op_upsample_bilinear2d,