diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 49e92719ba15..ca09c0f1cbf5 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -707,14 +707,62 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder&
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
 }
 
-// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv1d_transpose
-// and unit test for mixed_precision
+InferLayoutOutput InferLayoutConv1dTranspose(
+    const Call& call, const ffi::Map<String, Array<String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  const auto* attrs = call->attrs.as<Conv1DTransposeAttrs>();
+  LayoutDecision data_layout, weight_layout, output_layout;
+  ObjectPtr<Conv1DTransposeAttrs> new_attrs = ffi::make_object<Conv1DTransposeAttrs>(*attrs);
+
+  auto it = desired_layouts.find("relax.nn.conv1d_transpose");
+  if (it != desired_layouts.end()) {
+    Layout desired_data_layout = (*it).second[0];
+    Layout desired_weight_layout = (*it).second[1];
+    Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+    ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only";
+    ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal())
+        << "Axis swap only";
+    ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal())
+        << "Axis swap only";
+    data_layout = TransposeLike(InitialLayout(3), attrs->data_layout, desired_data_layout);
+    weight_layout = TransposeLike(InitialLayout(3), attrs->kernel_layout, desired_weight_layout);
+    output_layout = TransposeLike(InitialLayout(3), attrs->out_layout, desired_output_layout);
+    new_attrs->data_layout = (*it).second[0];
+    new_attrs->kernel_layout = (*it).second[1];
+    new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+  } else {
+    data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+    weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+    output_layout = data_layout;
+    new_attrs->data_layout =
+        TransposeLike(attrs->data_layout, InitialLayout(3), data_layout->layout).name();
+    new_attrs->kernel_layout =
+        TransposeLike(attrs->kernel_layout, InitialLayout(3), weight_layout->layout).name();
+    new_attrs->out_layout =
+        TransposeLike(attrs->out_layout, InitialLayout(3), output_layout->layout).name();
+  }
+  return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+}
+
+Call InferMixedPrecisionConv1dTranspose(const Call& call, const DataType& out_dtype) {
+  const auto* conv1d_transpose_attrs = call->attrs.as<Conv1DTransposeAttrs>();
+  return Downcast<Call>(
+      conv1d_transpose(call->args[0], call->args[1], conv1d_transpose_attrs->strides,
+                       conv1d_transpose_attrs->padding, conv1d_transpose_attrs->output_padding,
+                       conv1d_transpose_attrs->dilation, conv1d_transpose_attrs->groups,
+                       conv1d_transpose_attrs->data_layout, conv1d_transpose_attrs->kernel_layout,
+                       conv1d_transpose_attrs->out_layout, out_dtype));
+}
+
 TVM_REGISTER_OP("relax.nn.conv1d_transpose")
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("weight", "Tensor", "The weight tensor.")
     .set_attrs_type<Conv1DTransposeAttrs>()
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv1dTranspose)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv1dTranspose)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+    .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv1dTranspose)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.nn.conv2d_transpose */
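With FRelaxInferLayout registered, ConvertLayout can now rewrite conv1d_transpose calls instead of skipping them. Below is a minimal sketch of how the hook gets exercised; the module and the "NWC"/"WOI" desired layouts are illustrative choices, not part of this patch:

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3, 28), "float32"), w: R.Tensor((3, 4, 3), "float32")
    ) -> R.Tensor((2, 4, 30), "float32"):
        with R.dataflow():
            # Defaults: data_layout="NCW", kernel_layout="IOW".
            gv = R.nn.conv1d_transpose(x, w)
            R.output(gv)
        return gv

# Request NWC data / WOI kernel layouts: InferLayoutConv1dTranspose updates the
# attrs and ConvertLayout inserts layout_transform ops around the call as needed.
mod = relax.transform.ConvertLayout({"relax.nn.conv1d_transpose": ["NWC", "WOI"]})(Module)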
@@ -857,14 +905,95 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder&
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
 }
 
-// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv2d_transpose
-// and unit test for mixed_precision
+InferLayoutOutput InferLayoutConv2dTranspose(
+    const Call& call, const ffi::Map<String, Array<String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  const auto* attrs = call->attrs.as<Conv2DTransposeAttrs>();
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  LayoutDecision output_layout;
+  ObjectPtr<Conv2DTransposeAttrs> new_attrs = ffi::make_object<Conv2DTransposeAttrs>(*attrs);
+
+  auto it = desired_layouts.find("relax.nn.conv2d_transpose");
+  if (it != desired_layouts.end()) {
+    Layout desired_data_layout = (*it).second[0];
+    Layout desired_weight_layout = (*it).second[1];
+    Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+
+    Layout input_layout = Layout(attrs->data_layout);
+    Layout kernel_layout = Layout(attrs->kernel_layout);
+    Layout out_layout = Layout(attrs->out_layout);
+
+    if (desired_data_layout.ndim_primal() == input_layout.ndim() &&
+        desired_weight_layout.ndim_primal() == kernel_layout.ndim() &&
+        desired_output_layout.ndim_primal() == out_layout.ndim()) {
+      data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout);
+      weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout);
+      output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout);
+      new_attrs->data_layout = (*it).second[0];
+      new_attrs->kernel_layout = (*it).second[1];
+      new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+      return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+    } else {
+      auto data_si = GetStructInfo(call->args[0]);
+      auto kernel_si = GetStructInfo(call->args[1]);
+      TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
+      TensorStructInfo kernel_sinfo = kernel_si.as<TensorStructInfo>().value();
+      ffi::Optional<ShapeExpr> data_shape =
+          ffi::GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+      ffi::Optional<ShapeExpr> kernel_shape =
+          ffi::GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>());
+
+      bool can_data_proved =
+          CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values);
+      bool can_kernel_proved = CanProveLayoutTransform(kernel_layout, desired_weight_layout,
+                                                       kernel_shape.value()->values);
+
+      if (can_data_proved && can_kernel_proved) {
+        data_layout = TransposeSubLayoutLike(InitialLayout(4), input_layout, desired_data_layout);
+        weight_layout =
+            TransposeSubLayoutLike(InitialLayout(4), kernel_layout, desired_weight_layout);
+        output_layout = TransposeSubLayoutLike(InitialLayout(4), out_layout, desired_output_layout);
+        new_attrs->data_layout = (*it).second[0];
+        new_attrs->kernel_layout = (*it).second[1];
+        new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+        return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+      } else {
+        data_layout = LayoutDecision(InitialLayout(4));
+        weight_layout = LayoutDecision(InitialLayout(4));
+      }
+    }
+  }
+
+  output_layout = data_layout;
+  new_attrs->data_layout =
+      TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name();
+  new_attrs->kernel_layout =
+      TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name();
+  new_attrs->out_layout =
+      TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name();
+  return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+}
+
+Call InferMixedPrecisionConv2dTranspose(const Call& call, const DataType& out_dtype) {
+  const auto* conv2d_transpose_attrs = call->attrs.as<Conv2DTransposeAttrs>();
+  return Downcast<Call>(
+      conv2d_transpose(call->args[0], call->args[1], conv2d_transpose_attrs->strides,
+                       conv2d_transpose_attrs->padding, conv2d_transpose_attrs->output_padding,
+                       conv2d_transpose_attrs->dilation, conv2d_transpose_attrs->groups,
+                       conv2d_transpose_attrs->data_layout, conv2d_transpose_attrs->kernel_layout,
+                       conv2d_transpose_attrs->out_layout, out_dtype));
+}
+
 TVM_REGISTER_OP("relax.nn.conv2d_transpose")
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("weight", "Tensor", "The weight tensor.")
     .set_attrs_type<Conv2DTransposeAttrs>()
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2dTranspose)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv2dTranspose)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+    .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv2dTranspose)
     .set_attr<Bool>("FPurity", Bool(true));
 
 }  // namespace relax
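Because both ops are now registered with TMixedPrecisionPolicy kAlways and an FInferMixedPrecision that simply re-emits the call with a new out_dtype, the ToMixedPrecision pass can convert them to float16 compute with a wider accumulator. A sketch under the usual pass flow; the module is illustrative, not from this patch:

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 4, 3, 3), "float32")
    ) -> R.Tensor((2, 4, 30, 30), "float32"):
        with R.dataflow():
            gv = R.nn.conv2d_transpose(x, w)
            R.output(gv)
        return gv

# Operands are cast to float16 and the call is re-created with out_dtype="float32"
# through InferMixedPrecisionConv2dTranspose.
mod = relax.transform.ToMixedPrecision(out_dtype="float32")(Module)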
relax.Var("x", R.Tensor((2, 3, 28), "int8")) + w1 = relax.Var("w", R.Tensor((3, 4, 3), "int8")) + + _check_inference( + bb, + relax.op.nn.conv1d_transpose(x0, w0, out_dtype="float32"), + relax.TensorStructInfo((2, 4, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d_transpose(x1, w1, out_dtype="int32"), + relax.TensorStructInfo((2, 4, 30), "int32"), + ) + + def test_conv2d_infer_struct_info(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") @@ -1571,6 +1590,25 @@ def test_conv2d_transpose_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.conv2d_transpose(x1, w0)) +def test_conv2d_transpose_infer_struct_info_mixed_precision(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) + w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) + w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8")) + + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x0, w0, out_dtype="float32"), + relax.TensorStructInfo((2, 4, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv2d_transpose(x1, w1, out_dtype="int32"), + relax.TensorStructInfo((2, 4, 30, 30), "int32"), + ) + + def test_conv3d_infer_struct_info(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm")