Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 196 additions & 65 deletions backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,12 @@ void xa_opt_quantized_conv2d_nhwc(
bool conv1d = input.dim() == 3;
constexpr int kNnlibMaxDim = 4;

if (input.scalar_type() == ScalarType::Char) {
// Combined path for int8 (Char) and uint8 (Byte)
if (input.scalar_type() == ScalarType::Char ||
input.scalar_type() == ScalarType::Byte) {
bool is_uint8 = input.scalar_type() == ScalarType::Byte;

// Use WORD8* for both int8 and uint8 (with casts for uint8)
WORD8* __restrict__ p_out =
(WORD8* __restrict__)out.mutable_data_ptr<int8_t>();
WORD8* __restrict__ p_inp =
Expand Down Expand Up @@ -213,9 +218,6 @@ void xa_opt_quantized_conv2d_nhwc(
WORD32 dilation_width = dilation[1];
WORD32 dilation_height = dilation[0];

// WORD32* kernel_bias_ptr =
// (WORD32*)weight_zero_point.const_data_ptr<int32_t>();

WORD32 input_zero_bias = -in_zero_point;
WORD32 kernel_zero_bias = -weight_zero_point;

Expand All @@ -237,8 +239,11 @@ void xa_opt_quantized_conv2d_nhwc(

WORD32 scratch_size = 0;

// Standard conv2d (groups == 1)
// int8 uses xa_nn_conv2d_per_chan_sym8sxasym8s
// uint8 uses xa_nn_conv2d_std_asym8uxasym8u (matching NCHW)
if (groups == 1) {
WORD32 out_data_format = 1;
WORD32 out_data_format = 0; // 0 = NHWC output format

scratch_size = xa_nn_conv2d_getsize(
input_height,
Expand Down Expand Up @@ -266,44 +271,129 @@ void xa_opt_quantized_conv2d_nhwc(

p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8);

for (int _n = 0; _n < batches; _n++) {
WORD8* in_batch =
p_inp + _n * input_channels * input_height * input_width;
WORD8* out_batch = p_out + _n * out_channels * out_height * out_width;

xa_nn_conv2d_per_chan_sym8sxasym8s(
out_batch,
in_batch,
p_kernel,
p_bias,
input_height,
input_width,
input_channels,
kernel_height,
kernel_width,
kernel_channels,
dilation_height,
dilation_width,
out_channels,
x_stride,
y_stride,
x_padding,
y_padding,
out_height,
out_width,
input_zero_bias,
out_multiplier32,
out_shift32,
out_zero_bias,
out_data_format,
p_scratch);
if (is_uint8) {
// uint8 standard conv2d uses xa_nn_conv2d_std_asym8uxasym8u
WORD32 out_multiplier = out_multiplier32[0];
WORD32 out_shift = out_shift32[0];

for (int _n = 0; _n < batches; _n++) {
UWORD8* in_batch =
(UWORD8*)p_inp + _n * input_channels * input_height * input_width;
UWORD8* out_batch =
(UWORD8*)p_out + _n * out_channels * out_height * out_width;

xa_nn_conv2d_std_asym8uxasym8u(
out_batch,
in_batch,
(UWORD8*)p_kernel,
p_bias,
input_height,
input_width,
input_channels,
kernel_height,
kernel_width,
out_channels,
x_stride,
y_stride,
x_padding,
y_padding,
out_height,
out_width,
input_zero_bias,
kernel_zero_bias,
out_multiplier,
out_shift,
out_zero_bias,
out_data_format,
p_scratch);
}
} else {
// int8 standard conv2d uses xa_nn_conv2d_per_chan_sym8sxasym8s
for (int _n = 0; _n < batches; _n++) {
WORD8* in_batch =
p_inp + _n * input_channels * input_height * input_width;
WORD8* out_batch =
p_out + _n * out_channels * out_height * out_width;

xa_nn_conv2d_per_chan_sym8sxasym8s(
out_batch,
in_batch,
p_kernel,
p_bias,
input_height,
input_width,
input_channels,
kernel_height,
kernel_width,
kernel_channels,
dilation_height,
dilation_width,
out_channels,
x_stride,
y_stride,
x_padding,
y_padding,
out_height,
out_width,
input_zero_bias,
out_multiplier32,
out_shift32,
out_zero_bias,
out_data_format,
p_scratch);
}
}
return;
}

// Depthwise conv2d (groups == input_channels)
// int8 uses xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s
// uint8 uses xa_nn_conv2d_depthwise_asym8uxasym8u
if (groups == input_channels) {
WORD32 channels_multiplier = out_channels / input_channels;

// NHWC weight comes as [OC, KH, KW, IC] (4D) or [OC, KW, IC] (3D for conv1d)
// where IC=1 for depthwise. nnlib expects weight as [KH, KW, OC]

// Allocate buffer for transposed weight
WORD8* ptr_kernel = (WORD8*)kernels::allocate_temp_memory(
ctx,
(out_channels * kernel_height * kernel_width + 8) * sizeof(WORD8));
WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr_kernel, 8);

// Handle both conv1d (3D weight) and conv2d (4D weight) cases
if (conv1d) {
// Conv1d: transpose from [OC, KW, 1] to [1, KW, OC]
WORD32 p_kernel_inp_shape[kNnlibMaxDim] = {out_channels, kernel_width, 1, 1};
WORD32 p_kernel_out_shape[kNnlibMaxDim] = {1, kernel_width, out_channels, 1};
WORD32 p_kernel_permute[kNnlibMaxDim] = {2, 1, 0, 3};

xa_nn_transpose_8_8(
pkernel,
p_kernel_out_shape,
p_kernel,
p_kernel_inp_shape,
p_kernel_permute,
kNnlibMaxDim,
kNnlibMaxDim);
} else {
// Conv2d: transpose from [OC, KH, KW, 1] to [KH, KW, OC]
WORD32 p_kernel_inp_shape[kNnlibMaxDim] = {
out_channels, kernel_height, kernel_width, 1};
WORD32 p_kernel_out_shape[kNnlibMaxDim] = {
kernel_height, kernel_width, out_channels, 1};
WORD32 p_kernel_permute[kNnlibMaxDim] = {1, 2, 0, 3};

xa_nn_transpose_8_8(
pkernel,
p_kernel_out_shape,
p_kernel,
p_kernel_inp_shape,
p_kernel_permute,
kNnlibMaxDim,
kNnlibMaxDim);
}

scratch_size = xa_nn_conv2d_depthwise_getsize(
input_height,
input_width,
Expand All @@ -326,35 +416,76 @@ void xa_opt_quantized_conv2d_nhwc(

p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8);

for (int _n = 0; _n < batches; _n++) {
WORD8* in_batch =
p_inp + _n * input_channels * input_height * input_width;
WORD8* out_batch = p_out + _n * out_channels * out_height * out_width;

xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s(
out_batch,
p_kernel,
in_batch,
p_bias,
input_height,
input_width,
input_channels,
kernel_height,
kernel_width,
channels_multiplier,
x_stride,
y_stride,
x_padding,
y_padding,
out_height,
out_width,
input_zero_bias,
out_multiplier32,
out_shift32,
out_zero_bias,
0, // NHWC
0, // NHWC
p_scratch);
if (is_uint8) {
// uint8 depthwise uses xa_nn_conv2d_depthwise_asym8uxasym8u
WORD32 out_multiplier = out_multiplier32[0];
WORD32 out_shift = out_shift32[0];

for (int _n = 0; _n < batches; _n++) {
UWORD8* in_batch =
(UWORD8*)p_inp + _n * input_channels * input_height * input_width;
UWORD8* out_batch =
(UWORD8*)p_out + _n * out_channels * out_height * out_width;

xa_nn_conv2d_depthwise_asym8uxasym8u(
out_batch,
(UWORD8*)pkernel,
in_batch,
p_bias,
input_height,
input_width,
input_channels,
kernel_height,
kernel_width,
channels_multiplier,
x_stride,
y_stride,
x_padding,
y_padding,
out_height,
out_width,
input_zero_bias,
kernel_zero_bias,
out_multiplier,
out_shift,
out_zero_bias,
0, // NHWC out
0, // NHWC inp
p_scratch);
}
} else {
// int8 depthwise uses xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s
for (int _n = 0; _n < batches; _n++) {
WORD8* in_batch =
p_inp + _n * input_channels * input_height * input_width;
WORD8* out_batch =
p_out + _n * out_channels * out_height * out_width;

xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s(
out_batch,
pkernel,
in_batch,
p_bias,
input_height,
input_width,
input_channels,
kernel_height,
kernel_width,
channels_multiplier,
x_stride,
y_stride,
x_padding,
y_padding,
out_height,
out_width,
input_zero_bias,
out_multiplier32,
out_shift32,
out_zero_bias,
0, // NHWC out
0, // NHWC inp
p_scratch);
}
}

return;
Expand Down
Loading