137 changes: 133 additions & 4 deletions src/relax/op/nn/convolution.cc
@@ -707,14 +707,62 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder&
  return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
}

// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv1d_transpose
// and unit test for mixed_precision
InferLayoutOutput InferLayoutConv1dTranspose(
    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
    const VarLayoutMap& var_layout_map) {
  const auto* attrs = call->attrs.as<Conv1DTransposeAttrs>();
  LayoutDecision data_layout, weight_layout, output_layout;
  ObjectPtr<Conv1DTransposeAttrs> new_attrs = ffi::make_object<Conv1DTransposeAttrs>(*attrs);

  auto it = desired_layouts.find("relax.nn.conv1d_transpose");
  if (it != desired_layouts.end()) {
    Layout desired_data_layout = (*it).second[0];
    Layout desired_weight_layout = (*it).second[1];
    Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
Comment on lines +719 to +721 (Contributor, severity: medium):
There is repeated code for accessing desired layouts. To improve readability and maintainability, you can store the desired layout strings in local variables and reuse them. This avoids repeating the (*it).second access and the ternary logic for the output layout. You can then use these variables on lines 730-732.
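A minimal sketch of how that suggestion could look inside the `if (it != desired_layouts.end())` branch (illustrative fragment only, not part of the merged diff; the *_str variable names are assumptions):

    // Read the desired layout strings once and reuse them below.
    ffi::String desired_data_str = (*it).second[0];
    ffi::String desired_weight_str = (*it).second[1];
    ffi::String desired_out_str = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
    Layout desired_data_layout = Layout(desired_data_str);
    Layout desired_weight_layout = Layout(desired_weight_str);
    Layout desired_output_layout = Layout(desired_out_str);
    // ... ICHECK and TransposeLike calls unchanged ...
    new_attrs->data_layout = desired_data_str;
    new_attrs->kernel_layout = desired_weight_str;
    new_attrs->out_layout = desired_out_str;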

    ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only";
    ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal())
        << "Axis swap only";
    ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal())
        << "Axis swap only";
    data_layout = TransposeLike(InitialLayout(3), attrs->data_layout, desired_data_layout);
    weight_layout = TransposeLike(InitialLayout(3), attrs->kernel_layout, desired_weight_layout);
    output_layout = TransposeLike(InitialLayout(3), attrs->out_layout, desired_output_layout);
    new_attrs->data_layout = (*it).second[0];
    new_attrs->kernel_layout = (*it).second[1];
    new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
  } else {
    data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
    weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
    output_layout = data_layout;
    new_attrs->data_layout =
        TransposeLike(attrs->data_layout, InitialLayout(3), data_layout->layout).name();
    new_attrs->kernel_layout =
        TransposeLike(attrs->kernel_layout, InitialLayout(3), weight_layout->layout).name();
    new_attrs->out_layout =
        TransposeLike(attrs->out_layout, InitialLayout(3), output_layout->layout).name();
  }
  return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
}

Call InferMixedPrecisionConv1dTranspose(const Call& call, const DataType& out_dtype) {
  const auto* conv1d_transpose_attrs = call->attrs.as<Conv1DTransposeAttrs>();
  return Downcast<Call>(
      conv1d_transpose(call->args[0], call->args[1], conv1d_transpose_attrs->strides,
                       conv1d_transpose_attrs->padding, conv1d_transpose_attrs->output_padding,
                       conv1d_transpose_attrs->dilation, conv1d_transpose_attrs->groups,
                       conv1d_transpose_attrs->data_layout, conv1d_transpose_attrs->kernel_layout,
                       conv1d_transpose_attrs->out_layout, out_dtype));
}

TVM_REGISTER_OP("relax.nn.conv1d_transpose")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_attrs_type<Conv1DTransposeAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv1dTranspose)
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv1dTranspose)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
.set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv1dTranspose)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.conv2d_transpose */
@@ -857,14 +905,95 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder&
  return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
}

// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv2d_transpose
// and unit test for mixed_precision
InferLayoutOutput InferLayoutConv2dTranspose(
    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
    const VarLayoutMap& var_layout_map) {
  const auto* attrs = call->attrs.as<Conv2DTransposeAttrs>();
  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
  LayoutDecision weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
  LayoutDecision output_layout;
  ObjectPtr<Conv2DTransposeAttrs> new_attrs = ffi::make_object<Conv2DTransposeAttrs>(*attrs);

  auto it = desired_layouts.find("relax.nn.conv2d_transpose");
  if (it != desired_layouts.end()) {
    Layout desired_data_layout = (*it).second[0];
    Layout desired_weight_layout = (*it).second[1];
    Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
Comment on lines +919 to +921 (Contributor, severity: medium):
There is repeated code for accessing desired layouts. To improve readability and maintainability, you can store the desired layout strings in local variables at the beginning of the if block and reuse them. This avoids repeating the (*it).second access and the ternary logic for the output layout. You can then use these variables to update new_attrs in lines 933-935 and 957-959.
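The same pattern as in the conv1d_transpose comment applies here; a short sketch (illustrative only, not part of the merged diff; the *_str names are assumptions) of hoisting the strings to the top of the if block and reusing them at both new_attrs update sites:

    ffi::String desired_data_str = (*it).second[0];
    ffi::String desired_weight_str = (*it).second[1];
    ffi::String desired_out_str = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
    // ... later, in both branches that currently repeat the ternary logic ...
    new_attrs->data_layout = desired_data_str;
    new_attrs->kernel_layout = desired_weight_str;
    new_attrs->out_layout = desired_out_str;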


    Layout input_layout = Layout(attrs->data_layout);
    Layout kernel_layout = Layout(attrs->kernel_layout);
    Layout out_layout = Layout(attrs->out_layout);

    if (desired_data_layout.ndim_primal() == input_layout.ndim() &&
        desired_weight_layout.ndim_primal() == kernel_layout.ndim() &&
        desired_output_layout.ndim_primal() == out_layout.ndim()) {
      data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout);
      weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout);
      output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout);
      new_attrs->data_layout = (*it).second[0];
      new_attrs->kernel_layout = (*it).second[1];
      new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
      return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
    } else {
      auto data_si = GetStructInfo(call->args[0]);
      auto kernel_si = GetStructInfo(call->args[1]);
      TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
      TensorStructInfo kernel_sinfo = kernel_si.as<TensorStructInfo>().value();
      ffi::Optional<ShapeExpr> data_shape =
          ffi::GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
      ffi::Optional<ShapeExpr> kernel_shape =
          ffi::GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>());

      bool can_data_proved =
          CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values);
      bool can_kernel_proved = CanProveLayoutTransform(kernel_layout, desired_weight_layout,
                                                       kernel_shape.value()->values);

      if (can_data_proved && can_kernel_proved) {
        data_layout = TransposeSubLayoutLike(InitialLayout(4), input_layout, desired_data_layout);
        weight_layout =
            TransposeSubLayoutLike(InitialLayout(4), kernel_layout, desired_weight_layout);
        output_layout =
            TransposeSubLayoutLike(InitialLayout(4), out_layout, desired_output_layout);
        new_attrs->data_layout = (*it).second[0];
        new_attrs->kernel_layout = (*it).second[1];
        new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
        return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
      } else {
        data_layout = LayoutDecision(InitialLayout(4));
        weight_layout = LayoutDecision(InitialLayout(4));
      }
    }
  }

  output_layout = data_layout;
  new_attrs->data_layout =
      TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name();
  new_attrs->kernel_layout =
      TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name();
  new_attrs->out_layout =
      TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name();
  return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
}

Call InferMixedPrecisionConv2dTranspose(const Call& call, const DataType& out_dtype) {
  const auto* conv2d_transpose_attrs = call->attrs.as<Conv2DTransposeAttrs>();
  return Downcast<Call>(
      conv2d_transpose(call->args[0], call->args[1], conv2d_transpose_attrs->strides,
                       conv2d_transpose_attrs->padding, conv2d_transpose_attrs->output_padding,
                       conv2d_transpose_attrs->dilation, conv2d_transpose_attrs->groups,
                       conv2d_transpose_attrs->data_layout, conv2d_transpose_attrs->kernel_layout,
                       conv2d_transpose_attrs->out_layout, out_dtype));
}

TVM_REGISTER_OP("relax.nn.conv2d_transpose")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_attrs_type<Conv2DTransposeAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2dTranspose)
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv2dTranspose)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
.set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv2dTranspose)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
38 changes: 38 additions & 0 deletions tests/python/relax/test_op_nn_convolution.py
@@ -782,6 +782,25 @@ def test_conv1d_transpose_infer_struct_info_wrong_input_type():
        bb.normalize(relax.op.nn.conv1d_transpose(x1, w0))


def test_conv1d_transpose_infer_struct_info_mixed_precision():
    bb = relax.BlockBuilder()
    x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16"))
    w0 = relax.Var("w", R.Tensor((3, 4, 3), "float16"))
    x1 = relax.Var("x", R.Tensor((2, 3, 28), "int8"))
    w1 = relax.Var("w", R.Tensor((3, 4, 3), "int8"))

    _check_inference(
        bb,
        relax.op.nn.conv1d_transpose(x0, w0, out_dtype="float32"),
        relax.TensorStructInfo((2, 4, 30), "float32"),
    )
    _check_inference(
        bb,
        relax.op.nn.conv1d_transpose(x1, w1, out_dtype="int32"),
        relax.TensorStructInfo((2, 4, 30), "int32"),
    )
Comment on lines +786 to +801 (Contributor, severity: medium):
For better readability, consider using more descriptive variable names that indicate the data type, like x_f16, w_f16, x_i8, and w_i8.

Suggested change:
    bb = relax.BlockBuilder()
    x_f16 = relax.Var("x", R.Tensor((2, 3, 28), "float16"))
    w_f16 = relax.Var("w", R.Tensor((3, 4, 3), "float16"))
    x_i8 = relax.Var("x", R.Tensor((2, 3, 28), "int8"))
    w_i8 = relax.Var("w", R.Tensor((3, 4, 3), "int8"))

    _check_inference(
        bb,
        relax.op.nn.conv1d_transpose(x_f16, w_f16, out_dtype="float32"),
        relax.TensorStructInfo((2, 4, 30), "float32"),
    )
    _check_inference(
        bb,
        relax.op.nn.conv1d_transpose(x_i8, w_i8, out_dtype="int32"),
        relax.TensorStructInfo((2, 4, 30), "int32"),
    )



def test_conv2d_infer_struct_info():
    bb = relax.BlockBuilder()
    vdev0 = VDevice("llvm")
@@ -1571,6 +1590,25 @@ def test_conv2d_transpose_infer_struct_info_wrong_input_type():
        bb.normalize(relax.op.nn.conv2d_transpose(x1, w0))


def test_conv2d_transpose_infer_struct_info_mixed_precision():
    bb = relax.BlockBuilder()
    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16"))
    w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float16"))
    x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8"))
    w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8"))

    _check_inference(
        bb,
        relax.op.nn.conv2d_transpose(x0, w0, out_dtype="float32"),
        relax.TensorStructInfo((2, 4, 30, 30), "float32"),
    )
    _check_inference(
        bb,
        relax.op.nn.conv2d_transpose(x1, w1, out_dtype="int32"),
        relax.TensorStructInfo((2, 4, 30, 30), "int32"),
    )
Comment on lines +1594 to +1609 (Contributor, severity: medium):
For better readability, consider using more descriptive variable names that indicate the data type, like x_f16, w_f16, x_i8, and w_i8.

Suggested change:
    bb = relax.BlockBuilder()
    x_f16 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16"))
    w_f16 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float16"))
    x_i8 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8"))
    w_i8 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8"))

    _check_inference(
        bb,
        relax.op.nn.conv2d_transpose(x_f16, w_f16, out_dtype="float32"),
        relax.TensorStructInfo((2, 4, 30, 30), "float32"),
    )
    _check_inference(
        bb,
        relax.op.nn.conv2d_transpose(x_i8, w_i8, out_dtype="int32"),
        relax.TensorStructInfo((2, 4, 30, 30), "int32"),
    )



def test_conv3d_infer_struct_info():
    bb = relax.BlockBuilder()
    vdev0 = VDevice("llvm")