-
Notifications
You must be signed in to change notification settings - Fork 884
Qualcomm/op reciprocal #18220
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Qualcomm/op reciprocal #18220
Changes from all commits
d71c615
a3f262a
5980e51
d00f795
1abc8d1
8912f1d
3af1b5e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| # Copyright (c) Qualcomm Innovation Center, Inc. | ||
| # All rights reserved | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import torch | ||
| from executorch.exir.pass_base import ExportPass, PassResult | ||
|
|
||
| from .utils import copy_meta | ||
|
|
||
|
|
||
class DecomposeReciprocal(ExportPass):
    """Decompose ``aten.reciprocal`` into an equivalent ``aten.div`` (1 / x).

    The QNN backend lowers the division form directly, so this pass rewrites
    every ``torch.ops.aten.reciprocal.default`` node in the graph into
    ``torch.ops.aten.div.Tensor(1, x)`` and reroutes all users to the new node.
    """

    def __init__(self):
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite reciprocal nodes and return an accurate PassResult.

        Args:
            graph_module: The FX graph module to transform in place.

        Returns:
            PassResult whose ``modified`` flag is True only when at least one
            reciprocal node was actually decomposed.
        """
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.target in {
                torch.ops.aten.reciprocal.default,
            }:
                reciprocal_input = node.args[0]
                # Insert immediately before the reciprocal node rather than
                # after its input: if the input is a placeholder, inserting
                # after it would place a call_function node among the graph's
                # placeholders, which FX requires to be grouped first.
                with graph.inserting_before(node):
                    div_node = graph.call_function(
                        torch.ops.aten.div.Tensor,
                        (1, reciprocal_input),
                    )
                    # Preserve metadata (dtype/shape/val) from the original node.
                    div_node.meta = copy_meta(node.meta)

                # Replace all uses of reciprocal with the division node; the
                # orphaned reciprocal is removed by dead-code elimination below.
                for user in node.users.copy():
                    user.replace_input_with(node, div_node)
                modified = True

        if modified:
            graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, modified)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| # Copyright (c) Qualcomm Innovation Center, Inc. | ||
| # All rights reserved | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| from typing import Dict | ||
|
|
||
| import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnWrapper | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from executorch.backends.qualcomm.utils.constants import QCOM_DATA | ||
|
|
||
| from .node_visitor import NodeVisitor | ||
| from .node_visitor_manager import register_node_visitor | ||
| from .qnn_constants import OpElementWiseUnary, QNN_OP_PACKAGE_NAME_QTI_AISW | ||
|
|
||
|
|
||
@register_node_visitor
class Unary(NodeVisitor):
    """Node visitor that lowers unary aten ops to QNN ElementWiseUnary.

    Currently handles ``aten.reciprocal.default`` by emitting an
    ElementWiseUnary op configured with the RECIPROCAL operation code.
    """

    target = ["aten.reciprocal.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        # Wrap the single input tensor of the unary op.
        src_node = self.get_node(node.args[0])
        src_tensor = self.get_tensor(src_node, node)
        input_wrappers = [
            self.define_tensor(
                src_node,
                node,
                src_tensor,
                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
                nodes_to_wrappers,
            )
        ]

        # Wrap the output tensor produced by this node.
        out_tensor = self.get_tensor(node, node)
        output_wrappers = [
            self.define_tensor(
                node,
                node,
                out_tensor,
                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
                nodes_to_wrappers,
            )
        ]

        # Build the QNN ElementWiseUnary op and attach its tensors.
        unary_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpElementWiseUnary.op_name,
        )
        unary_op.AddInputTensors(input_wrappers)
        unary_op.AddOutputTensors(output_wrappers)

        # Select the RECIPROCAL operation via the op's scalar parameter.
        unary_op.AddScalarParam(
            OpElementWiseUnary.param_operation,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(OpElementWiseUnary.Operation.RECIPROCAL)},
        )

        return unary_op
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wondering if we intend to move all the unary ops under this class in future?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we will move all unary ops into this NodeVisitor in the future. (Will also do similar for elementwise-binary ops)