-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Relax] Add FInferMixedPrecision and FRelaxInferLayout for conv transpose ops #18629
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -707,14 +707,62 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& | |
| return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); | ||
| } | ||
|
|
||
| // TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv1d_transpose | ||
| // and unit test for mixed_precision | ||
| InferLayoutOutput InferLayoutConv1dTranspose( | ||
| const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts, | ||
| const VarLayoutMap& var_layout_map) { | ||
| const auto* attrs = call->attrs.as<Conv1DTransposeAttrs>(); | ||
| LayoutDecision data_layout, weight_layout, output_layout; | ||
| ObjectPtr<Conv1DTransposeAttrs> new_attrs = ffi::make_object<Conv1DTransposeAttrs>(*attrs); | ||
|
|
||
| auto it = desired_layouts.find("relax.nn.conv1d_transpose"); | ||
| if (it != desired_layouts.end()) { | ||
| Layout desired_data_layout = (*it).second[0]; | ||
| Layout desired_weight_layout = (*it).second[1]; | ||
| Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; | ||
| ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only"; | ||
| ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) | ||
| << "Axis swap only"; | ||
| ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) | ||
| << "Axis swap only"; | ||
| data_layout = TransposeLike(InitialLayout(3), attrs->data_layout, desired_data_layout); | ||
| weight_layout = TransposeLike(InitialLayout(3), attrs->kernel_layout, desired_weight_layout); | ||
| output_layout = TransposeLike(InitialLayout(3), attrs->out_layout, desired_output_layout); | ||
| new_attrs->data_layout = (*it).second[0]; | ||
| new_attrs->kernel_layout = (*it).second[1]; | ||
| new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; | ||
| } else { | ||
| data_layout = GetLayoutDecision(var_layout_map, call->args[0]); | ||
| weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); | ||
| output_layout = data_layout; | ||
| new_attrs->data_layout = | ||
| TransposeLike(attrs->data_layout, InitialLayout(3), data_layout->layout).name(); | ||
| new_attrs->kernel_layout = | ||
| TransposeLike(attrs->kernel_layout, InitialLayout(3), weight_layout->layout).name(); | ||
| new_attrs->out_layout = | ||
| TransposeLike(attrs->out_layout, InitialLayout(3), output_layout->layout).name(); | ||
| } | ||
| return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); | ||
| } | ||
|
|
||
| Call InferMixedPrecisionConv1dTranspose(const Call& call, const DataType& out_dtype) { | ||
| const auto* conv1d_transpose_attrs = call->attrs.as<Conv1DTransposeAttrs>(); | ||
| return Downcast<Call>( | ||
| conv1d_transpose(call->args[0], call->args[1], conv1d_transpose_attrs->strides, | ||
| conv1d_transpose_attrs->padding, conv1d_transpose_attrs->output_padding, | ||
| conv1d_transpose_attrs->dilation, conv1d_transpose_attrs->groups, | ||
| conv1d_transpose_attrs->data_layout, conv1d_transpose_attrs->kernel_layout, | ||
| conv1d_transpose_attrs->out_layout, out_dtype)); | ||
| } | ||
|
|
||
| TVM_REGISTER_OP("relax.nn.conv1d_transpose") | ||
| .set_num_inputs(2) | ||
| .add_argument("data", "Tensor", "The input tensor.") | ||
| .add_argument("weight", "Tensor", "The weight tensor.") | ||
| .set_attrs_type<Conv1DTransposeAttrs>() | ||
| .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv1dTranspose) | ||
| .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv1dTranspose) | ||
| .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) | ||
| .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv1dTranspose) | ||
| .set_attr<Bool>("FPurity", Bool(true)); | ||
|
|
||
| /* relax.nn.conv2d_transpose */ | ||
|
|
@@ -857,14 +905,95 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& | |
| return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); | ||
| } | ||
|
|
||
| // TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv2d_transpose | ||
| // and unit test for mixed_precision | ||
| InferLayoutOutput InferLayoutConv2dTranspose( | ||
| const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts, | ||
| const VarLayoutMap& var_layout_map) { | ||
| const auto* attrs = call->attrs.as<Conv2DTransposeAttrs>(); | ||
| LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]); | ||
| LayoutDecision weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); | ||
| LayoutDecision output_layout; | ||
| ObjectPtr<Conv2DTransposeAttrs> new_attrs = ffi::make_object<Conv2DTransposeAttrs>(*attrs); | ||
|
|
||
| auto it = desired_layouts.find("relax.nn.conv2d_transpose"); | ||
| if (it != desired_layouts.end()) { | ||
| Layout desired_data_layout = (*it).second[0]; | ||
| Layout desired_weight_layout = (*it).second[1]; | ||
| Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; | ||
|
Comment on lines
+919
to
+921
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is repeated code for accessing desired layouts. To improve readability and maintainability, you can store the desired layout strings in local variables at the beginning of the |
||
|
|
||
| Layout input_layout = Layout(attrs->data_layout); | ||
| Layout kernel_layout = Layout(attrs->kernel_layout); | ||
| Layout out_layout = Layout(attrs->out_layout); | ||
|
|
||
| if (desired_data_layout.ndim_primal() == input_layout.ndim() && | ||
| desired_weight_layout.ndim_primal() == kernel_layout.ndim() && | ||
| desired_output_layout.ndim_primal() == out_layout.ndim()) { | ||
| data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout); | ||
| weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout); | ||
| output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout); | ||
| new_attrs->data_layout = (*it).second[0]; | ||
| new_attrs->kernel_layout = (*it).second[1]; | ||
| new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; | ||
| return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); | ||
| } else { | ||
| auto data_si = GetStructInfo(call->args[0]); | ||
| auto kernel_si = GetStructInfo(call->args[1]); | ||
| TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value(); | ||
| TensorStructInfo kernel_sinfo = kernel_si.as<TensorStructInfo>().value(); | ||
| ffi::Optional<ShapeExpr> data_shape = | ||
| ffi::GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>()); | ||
| ffi::Optional<ShapeExpr> kernel_shape = | ||
| ffi::GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>()); | ||
|
|
||
| bool can_data_proved = | ||
| CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values); | ||
| bool can_kernel_proved = CanProveLayoutTransform(kernel_layout, desired_weight_layout, | ||
| kernel_shape.value()->values); | ||
|
|
||
| if (can_data_proved && can_kernel_proved) { | ||
| data_layout = TransposeSubLayoutLike(InitialLayout(4), input_layout, desired_data_layout); | ||
| weight_layout = | ||
| TransposeSubLayoutLike(InitialLayout(4), kernel_layout, desired_weight_layout); | ||
| output_layout = TransposeSubLayoutLike(InitialLayout(4), out_layout, desired_output_layout); | ||
| new_attrs->data_layout = (*it).second[0]; | ||
| new_attrs->kernel_layout = (*it).second[1]; | ||
| new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; | ||
| return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); | ||
| } else { | ||
| data_layout = LayoutDecision(InitialLayout(4)); | ||
| weight_layout = LayoutDecision(InitialLayout(4)); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| output_layout = data_layout; | ||
| new_attrs->data_layout = | ||
| TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name(); | ||
| new_attrs->kernel_layout = | ||
| TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name(); | ||
| new_attrs->out_layout = | ||
| TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name(); | ||
| return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); | ||
| } | ||
|
|
||
| Call InferMixedPrecisionConv2dTranspose(const Call& call, const DataType& out_dtype) { | ||
| const auto* conv2d_transpose_attrs = call->attrs.as<Conv2DTransposeAttrs>(); | ||
| return Downcast<Call>( | ||
| conv2d_transpose(call->args[0], call->args[1], conv2d_transpose_attrs->strides, | ||
| conv2d_transpose_attrs->padding, conv2d_transpose_attrs->output_padding, | ||
| conv2d_transpose_attrs->dilation, conv2d_transpose_attrs->groups, | ||
| conv2d_transpose_attrs->data_layout, conv2d_transpose_attrs->kernel_layout, | ||
| conv2d_transpose_attrs->out_layout, out_dtype)); | ||
| } | ||
|
|
||
| TVM_REGISTER_OP("relax.nn.conv2d_transpose") | ||
| .set_num_inputs(2) | ||
| .add_argument("data", "Tensor", "The input tensor.") | ||
| .add_argument("weight", "Tensor", "The weight tensor.") | ||
| .set_attrs_type<Conv2DTransposeAttrs>() | ||
| .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2dTranspose) | ||
| .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv2dTranspose) | ||
| .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) | ||
| .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv2dTranspose) | ||
| .set_attr<Bool>("FPurity", Bool(true)); | ||
|
|
||
| } // namespace relax | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -782,6 +782,25 @@ def test_conv1d_transpose_infer_struct_info_wrong_input_type(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bb.normalize(relax.op.nn.conv1d_transpose(x1, w0)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_conv1d_transpose_infer_struct_info_mixed_precision(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bb = relax.BlockBuilder() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| w0 = relax.Var("w", R.Tensor((3, 4, 3), "float16")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x1 = relax.Var("x", R.Tensor((2, 3, 28), "int8")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| w1 = relax.Var("w", R.Tensor((3, 4, 3), "int8")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _check_inference( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bb, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| relax.op.nn.conv1d_transpose(x0, w0, out_dtype="float32"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| relax.TensorStructInfo((2, 4, 30), "float32"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _check_inference( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bb, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| relax.op.nn.conv1d_transpose(x1, w1, out_dtype="int32"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| relax.TensorStructInfo((2, 4, 30), "int32"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+786
to
+801
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For better readability, consider using more descriptive variable names that indicate the data type, like
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_conv2d_infer_struct_info(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bb = relax.BlockBuilder() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vdev0 = VDevice("llvm") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1571,6 +1590,25 @@ def test_conv2d_transpose_infer_struct_info_wrong_input_type(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bb.normalize(relax.op.nn.conv2d_transpose(x1, w0)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_conv2d_transpose_infer_struct_info_mixed_precision(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bb = relax.BlockBuilder() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float16")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _check_inference( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bb, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| relax.op.nn.conv2d_transpose(x0, w0, out_dtype="float32"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| relax.TensorStructInfo((2, 4, 30, 30), "float32"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _check_inference( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bb, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| relax.op.nn.conv2d_transpose(x1, w1, out_dtype="int32"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| relax.TensorStructInfo((2, 4, 30, 30), "int32"), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1594
to
+1609
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For better readability, consider using more descriptive variable names that indicate the data type, like
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_conv3d_infer_struct_info(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bb = relax.BlockBuilder() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vdev0 = VDevice("llvm") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is repeated code for accessing desired layouts. To improve readability and maintainability, you can store the desired layout strings in local variables and reuse them. This avoids repeating
(*it).secondaccess and the ternary logic for the output layout. You can then use these variables on lines 730-732.