9 changes: 7 additions & 2 deletions backends/cadence/aot/functions.yaml
@@ -309,10 +309,15 @@
     - arg_meta: null
       kernel_name: impl::generic::quantized_relu_asym8u_asym8u_per_tensor_out
 
-- func: cadence::quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
+- func: cadence::quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
-      kernel_name: impl::generic::quantized_max_pool2d_out
+      kernel_name: impl::generic::quantized_max_pool2d_nchw_out
 
+- func: cadence::quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::generic::quantized_max_pool2d_nhwc_out
+
 - func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
55 changes: 51 additions & 4 deletions backends/cadence/aot/ops_registrations.py
@@ -214,10 +214,16 @@ def register_fake(
 )
 
 lib.define(
-    "quantized_max_pool2d(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
+    "quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
 )
 lib.define(
-    "quantized_max_pool2d.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_max_pool2d_nhwc(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
+)
+lib.define(
+    "quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
+)
 
 lib.define(
@@ -2277,8 +2283,8 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta(
     return input.new_empty(input.size(), dtype=input.dtype)
 
 
-@register_fake("cadence::quantized_max_pool2d")
-def quantized_max_pool2d_meta(
+@register_fake("cadence::quantized_max_pool2d_nchw")
+def quantized_max_pool2d_nchw_meta(
     input: torch.Tensor,
     kernel_size: list[int],
     stride: list[int],
@@ -2318,6 +2324,47 @@ def quantized_max_pool2d_meta(
     return input.new_empty([batch, channels, height_out, width_out], dtype=input.dtype)
 
 
+@register_fake("cadence::quantized_max_pool2d_nhwc")
+def quantized_max_pool2d_nhwc_meta(
+    input: torch.Tensor,
+    kernel_size: list[int],
+    stride: list[int],
+    padding: list[int],
+    dilation: list[int],
+    ceil_mode: bool,
+) -> torch.Tensor:
+    assert (
+        len(kernel_size) == 2
+    ), f"kernel_size must have 2 elements, got {len(kernel_size)}"
+    assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
+    assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
+    assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
+    assert (
+        len(input.size()) == 4
+    ), f"input must be 4D (N, H, W, C), got {len(input.size())}D"
+
+    batch = input.size(0)
+    height_in = input.size(1)
+    width_in = input.size(2)
+    channels = input.size(3)
+
+    height_out_raw = (
+        height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
+    ) / stride[0] + 1
+    width_out_raw = (
+        width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
+    ) / stride[1] + 1
+
+    if ceil_mode:
+        height_out = ceil(height_out_raw)
+        width_out = ceil(width_out_raw)
+    else:
+        height_out = int(height_out_raw)
+        width_out = int(width_out_raw)
+
+    return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype)
+
+
 @register_fake("cadence::fully_connected")
 def fully_connected_meta(
     src: torch.Tensor,
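A quick numeric check of the output-shape formula used by these fake kernels (a hedged sketch in plain Python, reusing the arithmetic of quantized_max_pool2d_nhwc_meta; the 8x8 input with a 2x2 kernel and stride 2 is just a sample configuration, matching the unit test further below):

from math import ceil

# Sample pooling configuration: 8x8 spatial input, 2x2 kernel, stride 2,
# no padding, dilation 1, floor mode.
height_in, width_in = 8, 8
kernel_size, stride, padding, dilation = [2, 2], [2, 2], [0, 0], [1, 1]
ceil_mode = False

# Same formula as the meta function above.
height_out_raw = (
    height_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
) / stride[0] + 1
width_out_raw = (
    width_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
) / stride[1] + 1

height_out = ceil(height_out_raw) if ceil_mode else int(height_out_raw)
width_out = ceil(width_out_raw) if ceil_mode else int(width_out_raw)
print(height_out, width_out)  # 4 4 -> NHWC output shape [N, 4, 4, C]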
7 changes: 5 additions & 2 deletions backends/cadence/aot/quantizer/patterns.py
@@ -459,7 +459,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_max_pool2d.default
+        return torch.ops.cadence.quantized_max_pool2d_nchw.default
 
 
 class MaxPool2dWithoutIndicesPattern(QuantizationPattern):
@@ -498,7 +498,10 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_max_pool2d.default
+        return torch.ops.cadence.quantized_max_pool2d_nchw.default
+
+
+# This is a base class for ReLU
 
 
 # This is a base class for ReLU, since it can be used with two different aten ops
35 changes: 33 additions & 2 deletions backends/cadence/aot/ref_implementations.py
@@ -1868,8 +1868,8 @@ def rms_norm(
     return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
 
 
-@impl_tracked(m, "quantized_max_pool2d")
-def quantized_max_pool2d(
+@impl_tracked(m, "quantized_max_pool2d_nchw")
+def quantized_max_pool2d_nchw(
     input: torch.Tensor,
     kernel_size: list[int],
     stride: list[int],
@@ -1897,6 +1897,37 @@ def quantized_max_pool2d(
     )
 
 
+@impl_tracked(m, "quantized_max_pool2d_nhwc")
+def quantized_max_pool2d_nhwc(
+    input: torch.Tensor,
+    kernel_size: list[int],
+    stride: list[int],
+    padding: list[int],
+    dilation: list[int],
+    ceil_mode: bool,
+) -> torch.Tensor:
+    """
+    Quantized max pooling in NHWC layout.
+    Converts NHWC→NCHW, performs max pooling, then converts back NCHW→NHWC.
+    """
+    # Convert NHWC [N, H, W, C] to NCHW [N, C, H, W]
+    input_nchw = input.permute(0, 3, 1, 2).contiguous()
+
+    # Call the NCHW version
+    output_nchw = quantized_max_pool2d_nchw(
+        input_nchw,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        ceil_mode=ceil_mode,
+    )
+
+    # Convert NCHW [N, C, H_out, W_out] back to NHWC [N, H_out, W_out, C]
+    return output_nchw.permute(0, 2, 3, 1).contiguous()
+
+
 @impl_tracked(m, "where_Scalar")
 def where_Scalar(
     condition: torch.Tensor,
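The layout-wrapper trick above can be sanity-checked in isolation. A minimal sketch, with torch.nn.functional.max_pool2d standing in for the quantized NCHW kernel (an assumption for illustration; the delegation works because max pooling only compares raw quantized values, so no scale or zero-point math is involved):

import torch
import torch.nn.functional as F

def max_pool2d_nhwc(x_nhwc: torch.Tensor) -> torch.Tensor:
    # NHWC [N, H, W, C] -> NCHW [N, C, H, W]
    x_nchw = x_nhwc.permute(0, 3, 1, 2).contiguous()
    # Stand-in for the NCHW kernel; the float cast is only because the
    # ATen op rejects int8 inputs, while the real kernel pools int8 directly.
    y_nchw = F.max_pool2d(x_nchw.float(), kernel_size=2, stride=2)
    # NCHW [N, C, H_out, W_out] -> NHWC [N, H_out, W_out, C]
    return y_nchw.permute(0, 2, 3, 1).contiguous().to(x_nhwc.dtype)

x = torch.randint(0, 100, (1, 8, 8, 3), dtype=torch.int8)
print(max_pool2d_nhwc(x).shape)  # torch.Size([1, 4, 4, 3])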
62 changes: 62 additions & 0 deletions backends/cadence/aot/replace_ops.py
@@ -1182,6 +1182,67 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
     return True
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=3))
+class ReplaceMaxPool2dWithChannelLastMaxPool2dPass(RemoveOrReplacePassInterface):
+    """
+    Replace NCHW max pooling with NHWC (channel-last) max pooling by adding
+    permute operations before and after the max pooling.
+    """
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            exir_ops.edge.cadence.quantized_max_pool2d_nchw.default,
+        ]
+
+    def _change_nchw_to_nhwc(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NCHW format to NHWC format."""
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, [0, 2, 3, 1]), {}
+        )
+        permute_node.meta = node.meta
+        return permute_node
+
+    def _change_nhwc_to_nchw(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NHWC format to NCHW format."""
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, [0, 3, 1, 2]), {}
+        )
+        permute_node.meta = node.meta
+        return permute_node
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        graph = node.graph
+
+        # Get input node
+        input_node = cast(torch.fx.Node, node.args[0])
+
+        with graph.inserting_before(node):
+            # Convert input from NCHW to NHWC
+            input_nhwc = self._change_nchw_to_nhwc(graph, input_node)
+
+            # Create the NHWC max pooling with the same args (kernel_size, stride, padding, dilation, ceil_mode)
+            new_args = (input_nhwc,) + tuple(node.args[1:])
+
+            new_pool = graph.call_function(
+                exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default,
+                new_args,
+                node.kwargs,
+            )
+            new_pool.meta = node.meta
+
+            # Convert output back from NHWC to NCHW
+            nchw_output = self._change_nhwc_to_nchw(graph, new_pool)
+
+        # Replace all uses with the final output
+        node.replace_all_uses_with(nchw_output)
+        return True
+
+
 @register_cadence_pass(CadencePassAttribute(opt_level=3))
 class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface):
     """
@@ -2561,6 +2622,7 @@ class CadenceReplaceOpsInGraph:
         ReplacePadWithCatPass,
         ReplaceConstantPadNdWithSlicePass,
         ReplaceConvWithChannelLastConvPass,
+        ReplaceMaxPool2dWithChannelLastMaxPool2dPass,
        ReplaceTrivialConvWithLinear,
        ReplaceConvWithIm2RowAndLinear,
        ReplaceTransposedConvWithLinearPass,
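The pass above targets the Cadence edge dialect, but the underlying permute-insertion rewrite is plain torch.fx surgery. A toy sketch under assumed stand-ins (pool_nchw and pool_nhwc are hypothetical helpers, not the Cadence ops), showing the same insert-permute, swap-op, and rewire steps:

import torch
import torch.fx as fx
import torch.nn.functional as F

def pool_nchw(x):  # hypothetical stand-in for the NCHW op
    return F.max_pool2d(x, 2)

def pool_nhwc(x):  # hypothetical stand-in for the NHWC op
    return pool_nchw(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

fx.wrap("pool_nchw")  # keep pool_nchw as a single graph node when tracing

class M(torch.nn.Module):
    def forward(self, x):  # x arrives in NCHW layout
        return pool_nchw(x)

gm = fx.symbolic_trace(M())
graph = gm.graph
for node in list(graph.nodes):
    if node.op == "call_function" and node.target is pool_nchw:
        with graph.inserting_before(node):
            # NCHW -> NHWC permute feeding the channel-last op
            nhwc_in = graph.call_function(
                torch.permute, (node.args[0], [0, 2, 3, 1])
            )
            nhwc_pool = graph.call_function(pool_nhwc, (nhwc_in,))
            # NHWC -> NCHW permute restoring the original layout
            nchw_out = graph.call_function(
                torch.permute, (nhwc_pool, [0, 3, 1, 2])
            )
        node.replace_all_uses_with(nchw_out)
        graph.erase_node(node)
gm.recompile()

x = torch.randn(1, 3, 8, 8)
assert torch.equal(gm(x), pool_nchw(x))  # the rewrite preserves results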
54 changes: 54 additions & 0 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -36,6 +36,7 @@
     ReplaceLinearWithFullyConnectedOpPass,
     ReplaceLogicalNotBooleanWhereWithWherePass,
     ReplaceMatmulWithTransposedMatmulPass,
+    ReplaceMaxPool2dWithChannelLastMaxPool2dPass,
     ReplaceMMWithAddMMPass,
     ReplaceMulTensorWithMulAndFullOpsPass,
     ReplaceNopTransposeOrPermuteWithViewPass,
@@ -2586,6 +2587,59 @@ def test_cat_insert_transpose(self) -> None:
         )
 
 
+class TestReplaceMaxPool2dWithChannelLastMaxPool2dPass(unittest.TestCase):
+    def test_replace_max_pool2d_nchw_with_nhwc(self) -> None:
+        # Create a graph with a single quantized_max_pool2d_nchw node.
+        x = torch.randint(0, 100, (1, 3, 8, 8), dtype=torch.int8)
+        gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.cadence.quantized_max_pool2d_nchw.default,
+            args=(x, [2, 2], [2, 2], [0, 0], [1, 1], False),
+        )
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_max_pool2d_nchw.default), 1
+        )
+        self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
+
+        # Deepcopy before the pass
+        original = copy.deepcopy(gm)
+
+        # Apply replacement pass.
+        p = ReplaceMaxPool2dWithChannelLastMaxPool2dPass()
+        result = p.call(gm)
+        self.assertTrue(result.modified)
+        gm_after_replacement = result.graph_module
+
+        # Check that replacement was made.
+        self.assertEqual(
+            count_node(
+                gm_after_replacement,
+                exir_ops.edge.cadence.quantized_max_pool2d_nhwc.default,
+            ),
+            1,
+        )
+        self.assertEqual(
+            count_node(
+                gm_after_replacement,
+                exir_ops.edge.cadence.quantized_max_pool2d_nchw.default,
+            ),
+            0,
+        )
+        # Two permutes: one for input NCHW->NHWC, one for output NHWC->NCHW
+        self.assertEqual(
+            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
+            2,
+        )
+
+        # Validate numerical accuracy
+        validate(
+            original,
+            gm_after_replacement,
+            (x,),
+            "ReplaceMaxPool2dWithChannelLastMaxPool2dPass",
+        )
+
+
 class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase):
     def _get_slice_empty_gm(self) -> tuple[torch.fx.GraphModule, torch.Tensor]:
         builder = GraphBuilder()
12 changes: 6 additions & 6 deletions backends/cadence/generic/operators/op_quantized_max_pool2d.cpp
@@ -27,7 +27,7 @@ using ::executorch::runtime::KernelRuntimeContext;
 namespace {
 
 template <typename T>
-void quantized_max_pool2d_impl(
+void quantized_max_pool2d_nchw_impl(
     const Tensor& input,
     IntArrayRef kernel_size,
     IntArrayRef stride,
@@ -98,7 +98,7 @@
 
 } // namespace
 
-Tensor& quantized_max_pool2d_out(
+Tensor& quantized_max_pool2d_nchw_out(
     ET_UNUSED KernelRuntimeContext& ctx,
     const Tensor& input,
     IntArrayRef kernel_size,
@@ -107,24 +107,24 @@
     IntArrayRef dilation,
     bool ceil_mode,
     Tensor& output) {
-#define typed_quantized_max_pool2d(ctype, dtype) \
+#define typed_quantized_max_pool2d_nchw(ctype, dtype) \
   case ScalarType::dtype: { \
-    quantized_max_pool2d_impl<ctype>( \
+    quantized_max_pool2d_nchw_impl<ctype>( \
        input, kernel_size, stride, padding, dilation, ceil_mode, output); \
     break; \
   }
 
   ScalarType dtype = input.scalar_type();
   // NOLINTBEGIN(clang-diagnostic-switch-enum)
   switch (dtype) {
-    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d)
+    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_max_pool2d_nchw)
     default:
       ET_DCHECK_MSG(
           false, "Unhandled dtype %s", torch::executor::toString(dtype));
   }
   // NOLINTEND(clang-diagnostic-switch-enum)
 
-#undef typed_quantized_max_pool2d
+#undef typed_quantized_max_pool2d_nchw
   return output;
 }
@@ -15,7 +15,7 @@ namespace impl {
 namespace generic {
 namespace native {
 
-::executorch::aten::Tensor& quantized_max_pool2d_out(
+::executorch::aten::Tensor& quantized_max_pool2d_nchw_out(
     ::executorch::runtime::KernelRuntimeContext& ctx,
     const ::executorch::aten::Tensor& input,
     ::executorch::aten::IntArrayRef kernel_size,