diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 80de190fedf..1c4dd3e06f3 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -309,10 +309,15 @@ - arg_meta: null kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out -- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::generic::quantized_max_pool2d_out + kernel_name: impl::generic::quantized_max_pool2d_nchw_out + +- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_max_pool2d_nhwc_out - func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 601d54fe49b..060702becec 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -214,10 +214,16 @@ def register_fake( ) lib.define( - "quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor" + "quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor" ) lib.define( - "quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_max_pool2d_nhwc(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor" +) +lib.define( + "quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( @@ -2277,8 +2283,8 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta( return input.new_empty(input.size(), dtype=input.dtype) -@register_fake("cadence::quantized_max_pool2d") -def quantized_max_pool2d_meta( +@register_fake("cadence::quantized_max_pool2d_nchw") +def quantized_max_pool2d_nchw_meta( input: torch.Tensor, kernel_size: list[int], stride: list[int], @@ -2318,6 +2324,47 @@ def quantized_max_pool2d_meta( return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype) +@register_fake("cadence::quantized_max_pool2d_nhwc") +def quantized_max_pool2d_nhwc_meta( + input: torch.Tensor, + kernel_size: list[int], + stride: list[int], + padding: list[int], + dilation: list[int], + ceil_mode: bool, +) -> torch.Tensor: + assert ( + len(kernel_size) == 2 + ), f"kernel_size must have 2 elements, got {len(kernel_size)}" + assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}" + assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}" + assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}" + assert ( + len(input.size()) == 4 + ), f"input must be 4D (N, H, W, C), got {len(input.size())}D" + + batch = input.size(0) + height_in = input.size(1) + width_in = input.size(2) + channels = input.size(3) + + height_out_raw = ( + height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1 + ) / stride[0] + 1 + width_out_raw = ( + width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1 + ) / stride[1] + 1 + + if ceil_mode: + height_out = ceil(height_out_raw) + width_out = ceil(width_out_raw) + else: + height_out = int(height_out_raw) + width_out = int(width_out_raw) + + return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype) + + @register_fake("cadence::fully_connected") def fully_connected_meta( src: torch.Tensor, diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 0d52c004dea..204f066ebf4 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -459,7 +459,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_max_pool2d.default + return torch.ops.cadence.quantized_max_pool2d_nchw.default class MaxPool2dWithoutIndicesPattern(QuantizationPattern): @@ -498,7 +498,10 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_max_pool2d.default + return torch.ops.cadence.quantized_max_pool2d_nchw.default + + +# This is a base class for ReLU # This is a base class for ReLU, since it can be used with two different aten ops diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index ed8b3ca60ae..f985718c150 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1868,8 +1868,8 @@ def rms_norm( return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X) -@impl_tracked(m, "quantized_max_pool2d") -def quantized_max_pool2d( +@impl_tracked(m, "quantized_max_pool2d_nchw") +def quantized_max_pool2d_nchw( input: torch.Tensor, kernel_size: list[int], stride: list[int], @@ -1897,6 +1897,37 @@ def quantized_max_pool2d( ) +@impl_tracked(m, "quantized_max_pool2d_nhwc") +def quantized_max_pool2d_nhwc( + input: torch.Tensor, + kernel_size: list[int], + stride: list[int], + padding: list[int], + dilation: list[int], + ceil_mode: bool, +) -> torch.Tensor: + """ + Quantized max pooling in NHWC layout. + + Converts NHWC→NCHW, performs max pooling, then converts back NCHW→NHWC. + """ + # Convert NHWC [N, H, W, C] to NCHW [N, C, H, W] + input_nchw = input.permute(0, 3, 1, 2).contiguous() + + # Call the NCHW version + output_nchw = quantized_max_pool2d_nchw( + input_nchw, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + ) + + # Convert NCHW [N, C, H_out, W_out] back to NHWC [N, H_out, W_out, C] + return output_nchw.permute(0, 2, 3, 1).contiguous() + + @impl_tracked(m, "where_Scalar") def where_Scalar( condition: torch.Tensor, diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 14a35c01baf..6e6e98af267 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -1182,6 +1182,67 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True +@register_cadence_pass(CadencePassAttribute(opt_level=3)) +class ReplaceMaxPool2dWithChannelLastMaxPool2dPass(RemoveOrReplacePassInterface): + """ + Replace NCHW max pooling with NHWC (channel-last) max pooling by adding + permute operations before and after the max pooling. + """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.cadence.quantized_max_pool2d_nchw.default, + ] + + def _change_nchw_to_nhwc( + self, graph: torch.fx.Graph, node: torch.fx.Node + ) -> torch.fx.Node: + """Convert NCHW format to NHWC format.""" + permute_node = graph.call_function( + exir_ops.edge.aten.permute_copy.default, (node, [0, 2, 3, 1]), {} + ) + permute_node.meta = node.meta + return permute_node + + def _change_nhwc_to_nchw( + self, graph: torch.fx.Graph, node: torch.fx.Node + ) -> torch.fx.Node: + """Convert NHWC format to NCHW format.""" + permute_node = graph.call_function( + exir_ops.edge.aten.permute_copy.default, (node, [0, 3, 1, 2]), {} + ) + permute_node.meta = node.meta + return permute_node + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + graph = node.graph + + # Get input node + input_node = cast(torch.fx.Node, node.args[0]) + + with graph.inserting_before(node): + # Convert input from NCHW to NHWC + input_nhwc = self._change_nchw_to_nhwc(graph, input_node) + + # Create the NHWC max pooling with the same args (kernel_size, stride, padding, dilation, ceil_mode) + new_args = (input_nhwc,) + tuple(node.args[1:]) + + new_pool = graph.call_function( + exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default, + new_args, + node.kwargs, + ) + new_pool.meta = node.meta + + # Convert output back from NHWC to NCHW + nchw_output = self._change_nhwc_to_nchw(graph, new_pool) + + # Replace all uses with the final output + node.replace_all_uses_with(nchw_output) + return True + + @register_cadence_pass(CadencePassAttribute(opt_level=3)) class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface): """ @@ -2561,6 +2622,7 @@ class CadenceReplaceOpsInGraph: ReplacePadWithCatPass, ReplaceConstantPadNdWithSlicePass, ReplaceConvWithChannelLastConvPass, + ReplaceMaxPool2dWithChannelLastMaxPool2dPass, ReplaceTrivialConvWithLinear, ReplaceConvWithIm2RowAndLinear, ReplaceTransposedConvWithLinearPass, diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 95d470644a0..5d9f8c0784b 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -36,6 +36,7 @@ ReplaceLinearWithFullyConnectedOpPass, ReplaceLogicalNotBooleanWhereWithWherePass, ReplaceMatmulWithTransposedMatmulPass, + ReplaceMaxPool2dWithChannelLastMaxPool2dPass, ReplaceMMWithAddMMPass, ReplaceMulTensorWithMulAndFullOpsPass, ReplaceNopTransposeOrPermuteWithViewPass, @@ -2586,6 +2587,59 @@ def test_cat_insert_transpose(self) -> None: ) +class TestReplaceMaxPool2dWithChannelLastMaxPool2dPass(unittest.TestCase): + def test_replace_max_pool2d_nchw_with_nhwc(self) -> None: + # Create a graph with a single quantized_max_pool2d_nchw node. + x = torch.randint(0, 100, (1, 3, 8, 8), dtype=torch.int8) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.cadence.quantized_max_pool2d_nchw.default, + args=(x, [2, 2], [2, 2], [0, 0], [1, 1], False), + ) + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_max_pool2d_nchw.default), 1 + ) + self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) + + # Deepcopy before the pass + original = copy.deepcopy(gm) + + # Apply replacement pass. + p = ReplaceMaxPool2dWithChannelLastMaxPool2dPass() + result = p.call(gm) + self.assertTrue(result.modified) + gm_after_replacement = result.graph_module + + # Check that replacement was made. + self.assertEqual( + count_node( + gm_after_replacement, + exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default, + ), + 1, + ) + self.assertEqual( + count_node( + gm_after_replacement, + exir_ops.edge.cadence.quantized_max_pool2d_nchw.default, + ), + 0, + ) + # Two permutes: one for input NCHW->NHWC, one for output NHWC->NCHW + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), + 2, + ) + + # Validate numerical accuracy + validate( + original, + gm_after_replacement, + (x,), + "ReplaceMaxPool2dWithChannelLastMaxPool2dPass", + ) + + class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase): def _get_slice_empty_gm(self) -> tuple[torch.fx.GraphModule, torch.Tensor]: builder = GraphBuilder() diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp b/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp index b241b0851a9..f843ad84080 100644 --- a/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d.cpp @@ -27,7 +27,7 @@ using ::executorch::runtime::KernelRuntimeContext; namespace { template -void quantized_max_pool2d_impl( +void quantized_max_pool2d_nchw_impl( const Tensor& input, IntArrayRef kernel_size, IntArrayRef stride, @@ -98,7 +98,7 @@ void quantized_max_pool2d_impl( } // namespace -Tensor& quantized_max_pool2d_out( +Tensor& quantized_max_pool2d_nchw_out( ET_UNUSED KernelRuntimeContext& ctx, const Tensor& input, IntArrayRef kernel_size, @@ -107,9 +107,9 @@ Tensor& quantized_max_pool2d_out( IntArrayRef dilation, bool ceil_mode, Tensor& output) { -#define typed_quantized_max_pool2d(ctype, dtype) \ +#define typed_quantized_max_pool2d_nchw(ctype, dtype) \ case ScalarType::dtype: { \ - quantized_max_pool2d_impl( \ + quantized_max_pool2d_nchw_impl( \ input, kernel_size, stride, padding, dilation, ceil_mode, output); \ break; \ } @@ -117,14 +117,14 @@ Tensor& quantized_max_pool2d_out( ScalarType dtype = input.scalar_type(); // NOLINTBEGIN(clang-diagnostic-switch-enum) switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d) + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d_nchw) default: ET_DCHECK_MSG( false, "Unhandled dtype %s", torch::executor::toString(dtype)); } // NOLINTEND(clang-diagnostic-switch-enum) -#undef typed_quantized_max_pool2d +#undef typed_quantized_max_pool2d_nchw return output; } diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d.h b/backends/cadence/generic/operators/op_quantized_max_pool2d.h index 07f406a37a7..453dd5a2582 100644 --- a/backends/cadence/generic/operators/op_quantized_max_pool2d.h +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d.h @@ -15,7 +15,7 @@ namespace impl { namespace generic { namespace native { -::executorch::aten::Tensor& quantized_max_pool2d_out( +::executorch::aten::Tensor& quantized_max_pool2d_nchw_out( ::executorch::runtime::KernelRuntimeContext& ctx, const ::executorch::aten::Tensor& input, ::executorch::aten::IntArrayRef kernel_size, diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp new file mode 100644 index 00000000000..d8f0d9e068b --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.cpp @@ -0,0 +1,136 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +namespace { + +template +void quantized_max_pool2d_nhwc_impl( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + ET_UNUSED bool ceil_mode, + Tensor& output) { + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = output.mutable_data_ptr(); + + // Input dimensions: [N, H, W, C] + const int64_t batch_size = input.size(0); + const int64_t in_height = input.size(1); + const int64_t in_width = input.size(2); + const int64_t channels = input.size(3); + + // Output dimensions: [N, H_out, W_out, C] + const int64_t out_height = output.size(1); + const int64_t out_width = output.size(2); + + // Pooling parameters + const int64_t kernel_h = kernel_size[0]; + const int64_t kernel_w = kernel_size[1]; + const int64_t stride_h = stride[0]; + const int64_t stride_w = stride[1]; + const int64_t pad_h = padding[0]; + const int64_t pad_w = padding[1]; + const int64_t dilation_h = dilation[0]; + const int64_t dilation_w = dilation[1]; + + for (int64_t n = 0; n < batch_size; ++n) { + for (int64_t oh = 0; oh < out_height; ++oh) { + for (int64_t ow = 0; ow < out_width; ++ow) { + const int64_t ih_start = oh * stride_h - pad_h; + const int64_t iw_start = ow * stride_w - pad_w; + + T* __restrict__ out_ptr = + out_data + ((n * out_height + oh) * out_width + ow) * channels; + + // Initialize all channels to the minimum value. + for (int64_t c = 0; c < channels; ++c) { + out_ptr[c] = std::numeric_limits::lowest(); + } + + // For each kernel position, compute element-wise max across all + // channels. The inner loop over channels is a stride-1 contiguous + // access in NHWC layout, enabling SIMD auto-vectorization. + for (int64_t kh = 0; kh < kernel_h; ++kh) { + const int64_t ih = ih_start + kh * dilation_h; + if (ih < 0 || ih >= in_height) { + continue; + } + for (int64_t kw = 0; kw < kernel_w; ++kw) { + const int64_t iw = iw_start + kw * dilation_w; + if (iw < 0 || iw >= in_width) { + continue; + } + + const T* __restrict__ in_ptr = + in_data + ((n * in_height + ih) * in_width + iw) * channels; + + for (int64_t c = 0; c < channels; ++c) { + out_ptr[c] = std::max(out_ptr[c], in_ptr[c]); + } + } + } + } + } + } +} + +} // namespace + +Tensor& quantized_max_pool2d_nhwc_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + Tensor& output) { +#define typed_quantized_max_pool2d_nhwc(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_max_pool2d_nhwc_impl( \ + input, kernel_size, stride, padding, dilation, ceil_mode, output); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + // NOLINTBEGIN(clang-diagnostic-switch-enum) + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d_nhwc) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + // NOLINTEND(clang-diagnostic-switch-enum) + +#undef typed_quantized_max_pool2d_nhwc + return output; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h new file mode 100644 index 00000000000..2b0c02e4bb7 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_max_pool2d_nhwc.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_max_pool2d_nhwc_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + ::executorch::aten::IntArrayRef kernel_size, + ::executorch::aten::IntArrayRef stride, + ::executorch::aten::IntArrayRef padding, + ::executorch::aten::IntArrayRef dilation, + bool ceil_mode, + ::executorch::aten::Tensor& output); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl index bf1de9e009a..fa6708a188e 100644 --- a/backends/cadence/generic/operators/targets.bzl +++ b/backends/cadence/generic/operators/targets.bzl @@ -225,6 +225,18 @@ def define_common_targets(): visibility = ["PUBLIC"], ) + runtime.cxx_library( + name = "op_quantized_max_pool2d_nhwc", + srcs = ["op_quantized_max_pool2d_nhwc.cpp"], + exported_headers = ["op_quantized_max_pool2d_nhwc.h"], + platforms = CXX, + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ":cadence_type_util", + ], + visibility = ["PUBLIC"], + ) + runtime.cxx_library( name = "op_quantized_matmul", srcs = ["op_quantized_matmul.cpp"],