From 649e2c32851b03fa7e8d79bd410323406c6e76f4 Mon Sep 17 00:00:00 2001
From: cchung100m
Date: Sun, 23 Nov 2025 17:38:24 +0800
Subject: [PATCH 1/6] [#17876] Fix TVM crashes with default relax pipeline when
 opt_level=1: InternalError: Check failed: (slot->value_computed) is false

---
 src/relax/backend/vm/vm_shape_lower.cc | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index bbc227d1d559..3b192700e3ec 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -399,6 +399,23 @@ class VMShapeLowerMutator
       return ffi::GetRef<Expr>(op);
     }
 
+    // Check whether all expressions have been computed; if not, mark variables as ready and trigger computation
+    for (const PrimExpr& expr : op->values) {
+      if (!expr->IsInstance<IntImmNode>()) {
+        auto it = slot_map_.find(expr);
+        if (it != slot_map_.end() && !it->second->value_computed) {
+          // If it's a variable, mark it as ready for computation
+          if (expr.as<tir::VarNode>()) {
+            it->second->value_computed = true;
+            ready_vars_.push_back(it->second);
+          }
+        }
+      }
+    }
+
+    // Trigger computation for any expressions that are now ready
+    this->EmitOutstandingPrimExprCompute();
+
     ffi::Array<Expr> args = {shape_heap_, PrimValue::Int64(static_cast<int64_t>(op->values.size()))};
     for (PrimExpr expr : op->values) {

From dd374b744ab7fc0d9d2054175a6b9e23ba2e2a6f Mon Sep 17 00:00:00 2001
From: cchung100m
Date: Mon, 24 Nov 2025 18:40:06 +0800
Subject: [PATCH 2/6] Add test case: test_composite_shape_expression

---
 .../test_backend_transform_shape_lower.py     | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py
index 177d036107c1..277f8de638ff 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -893,5 +893,32 @@ def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"):
     assert_structural_equal(Expected, After)
 
 
+def test_composite_shape_expression():
+    """When a ShapeExpr contains composite PrimExprs that haven't been computed yet,
+    VMShapeLower should trigger computation before processing the shape.
+    """
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(("x_0", "x_1", "x_2", "x_3"), "float32")) -> R.Tensor:
+            R.func_attr({"relax.force_pure": True})
+            x_0 = T.int64()
+            x_1 = T.int64()
+            x_2 = T.int64()
+            x_3 = T.int64()
+            # This creates a composite expression that was causing the crash:
+            # T.int64(4) * (x_0 * x_1 * x_2 * x_3)
+            new_shape = R.shape([T.int64(4) * (x_0 * x_1 * x_2 * x_3)])
+            return R.reshape(x, new_shape)
+
+    # The test should not crash during VMShapeLower
+    # We don't need to validate the exact output, just that it doesn't crash
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(Before)
+
+    # The actual output structure is not as important as not crashing
+    assert after is not None
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 35faf3d096c53b68702839c838acd1ca3c528926 Mon Sep 17 00:00:00 2001
From: cchung100m
Date: Tue, 2 Dec 2025 20:04:30 +0800
Subject: [PATCH 3/6] Revert the incorrect solution

---
 src/relax/backend/vm/vm_shape_lower.cc        | 17 -----------------
 .../test_backend_transform_shape_lower.py     | 27 ---------------------
 2 files changed, 44 deletions(-)

diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index 3b192700e3ec..bbc227d1d559 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -399,23 +399,6 @@ class VMShapeLowerMutator
       return ffi::GetRef<Expr>(op);
     }
 
-    // Check whether all expressions have been computed; if not, mark variables as ready and trigger computation
-    for (const PrimExpr& expr : op->values) {
-      if (!expr->IsInstance<IntImmNode>()) {
-        auto it = slot_map_.find(expr);
-        if (it != slot_map_.end() && !it->second->value_computed) {
-          // If it's a variable, mark it as ready for computation
-          if (expr.as<tir::VarNode>()) {
-            it->second->value_computed = true;
-            ready_vars_.push_back(it->second);
-          }
-        }
-      }
-    }
-
-    // Trigger computation for any expressions that are now ready
-    this->EmitOutstandingPrimExprCompute();
-
     ffi::Array<Expr> args = {shape_heap_, PrimValue::Int64(static_cast<int64_t>(op->values.size()))};
     for (PrimExpr expr : op->values) {
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py
index 277f8de638ff..177d036107c1 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -893,32 +893,5 @@ def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"):
     assert_structural_equal(Expected, After)
 
 
-def test_composite_shape_expression():
-    """When a ShapeExpr contains composite PrimExprs that haven't been computed yet,
-    VMShapeLower should trigger computation before processing the shape.
-    """
-
-    @tvm.script.ir_module
-    class Before:
-        @R.function
-        def main(x: R.Tensor(("x_0", "x_1", "x_2", "x_3"), "float32")) -> R.Tensor:
-            R.func_attr({"relax.force_pure": True})
-            x_0 = T.int64()
-            x_1 = T.int64()
-            x_2 = T.int64()
-            x_3 = T.int64()
-            # This creates a composite expression that was causing the crash:
-            # T.int64(4) * (x_0 * x_1 * x_2 * x_3)
-            new_shape = R.shape([T.int64(4) * (x_0 * x_1 * x_2 * x_3)])
-            return R.reshape(x, new_shape)
-
-    # The test should not crash during VMShapeLower
-    # We don't need to validate the exact output, just that it doesn't crash
-    after = relax.transform.VMShapeLower(emit_err_ctx=False)(Before)
-
-    # The actual output structure is not as important as not crashing
-    assert after is not None
-
-
 if __name__ == "__main__":
     tvm.testing.main()

From a14d98ae1bbbb2e621a959ad0dfd070ee670f940 Mon Sep 17 00:00:00 2001
From: cchung100m
Date: Sun, 14 Dec 2025 13:06:27 +0800
Subject: [PATCH 4/6] Add new pass: ShapeExprCanonicalizer

---
 .../tvm/relax/backend/cpu_generic/pipeline.py |   1 +
 python/tvm/relax/backend/cuda/pipeline.py     |   1 +
 .../tvm/relax/backend/gpu_generic/pipeline.py |   1 +
 python/tvm/relax/backend/rocm/pipeline.py     |   1 +
 python/tvm/relax/transform/__init__.py        |   1 +
 python/tvm/relax/transform/transform.py       |  20 ++
 .../transform/canonicalize_shape_expr.cc      | 253 ++++++++++++++++++
 7 files changed, 278 insertions(+)
 create mode 100644 src/relax/transform/canonicalize_shape_expr.cc

diff --git a/python/tvm/relax/backend/cpu_generic/pipeline.py b/python/tvm/relax/backend/cpu_generic/pipeline.py
index 74d951b817b1..f36efd788ece 100644
--- a/python/tvm/relax/backend/cpu_generic/pipeline.py
+++ b/python/tvm/relax/backend/cpu_generic/pipeline.py
@@ -53,6 +53,7 @@ def finalize_passes(target: tvm.target.Target):  # pylint: disable=unused-argume
             relax.transform.KillAfterLastUse(),
             relax.transform.LowerRuntimeBuiltin(),
             relax.transform.ComputePrimValue(),
+            relax.transform.CanonicalizeShapeExpr(),
             relax.transform.VMShapeLower(),
             relax.transform.AttachGlobalSymbol(),
         ]
diff --git a/python/tvm/relax/backend/cuda/pipeline.py b/python/tvm/relax/backend/cuda/pipeline.py
index d5c4c0856165..f0d3fb6f5ce7 100644
--- a/python/tvm/relax/backend/cuda/pipeline.py
+++ b/python/tvm/relax/backend/cuda/pipeline.py
@@ -65,6 +65,7 @@ def finalize_passes(target: tvm.target.Target):  # pylint: disable=unused-argume
             relax.transform.KillAfterLastUse(),
             relax.transform.LowerRuntimeBuiltin(),
             relax.transform.ComputePrimValue(),
+            relax.transform.CanonicalizeShapeExpr(),
             relax.transform.VMShapeLower(),
             relax.transform.AttachGlobalSymbol(),
         ]
diff --git a/python/tvm/relax/backend/gpu_generic/pipeline.py b/python/tvm/relax/backend/gpu_generic/pipeline.py
index 86c60114c699..d9c69f5ef786 100644
--- a/python/tvm/relax/backend/gpu_generic/pipeline.py
+++ b/python/tvm/relax/backend/gpu_generic/pipeline.py
@@ -64,6 +64,7 @@ def finalize_passes(target: tvm.target.Target):  # pylint: disable=unused-argume
             relax.transform.KillAfterLastUse(),
             relax.transform.LowerRuntimeBuiltin(),
             relax.transform.ComputePrimValue(),
+            relax.transform.CanonicalizeShapeExpr(),
             relax.transform.VMShapeLower(),
             relax.transform.AttachGlobalSymbol(),
         ]
diff --git a/python/tvm/relax/backend/rocm/pipeline.py b/python/tvm/relax/backend/rocm/pipeline.py
index e74039ca8634..2edb135a785f 100644
--- a/python/tvm/relax/backend/rocm/pipeline.py
+++ b/python/tvm/relax/backend/rocm/pipeline.py
@@ -64,6 +64,7 @@ def finalize_passes(target: tvm.target.Target):  # pylint: disable=unused-argume
             relax.transform.KillAfterLastUse(),
relax.transform.LowerRuntimeBuiltin(), relax.transform.ComputePrimValue(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), ] diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index dacbc667be2b..72e23e089519 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -28,6 +28,7 @@ BundleModelParams, CallTIRRewrite, CanonicalizeBindings, + CanonicalizeShapeExpr, CombineParallelMatmul, ComputePrimValue, ConvertLayout, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 46efc17e3d4f..1e59b3c0804c 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -735,6 +735,26 @@ def FoldConstant() -> tvm.ir.transform.Pass: return _ffi_api.FoldConstant() # type: ignore +def CanonicalizeShapeExpr() -> tvm.ir.transform.Pass: + """Canonicalize ShapeExpr by lifting compound PrimExpr into separate bindings. + + VMShapeLower can only handle ShapeExpr where each dimension is either: + - IntImm (concrete integer constant) + - tir::Var (symbolic variable) + + This pass lifts compound PrimExpr (e.g., n+1, 4*n*m, etc.) into separate shape bindings + with MatchCast to extract symbolic variables, ensuring VMShapeLower receives only + canonical shape expressions. + + This pass should be applied after ComputePrimValue and before VMShapeLower. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CanonicalizeShapeExpr() # type: ignore + + def ExpandTupleArguments() -> tvm.ir.transform.Pass: """Expand tuple arguments to internal functions diff --git a/src/relax/transform/canonicalize_shape_expr.cc b/src/relax/transform/canonicalize_shape_expr.cc new file mode 100644 index 000000000000..1f035228cd85 --- /dev/null +++ b/src/relax/transform/canonicalize_shape_expr.cc @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/canonicalize_shape_expr.cc + * \brief Cannonicalize ShapeExpr by lifting compound PrimExpr into separate bindings. + * + * VMShapeLower can only handle ShapeExpr where each dimension is either: + * - IntImm (concrete integer constant) + * - tir::Var (symbolic variable) + * + * This pass lifts compound PrimExpr (e.g., n+1, 4*n*m, etc.) into separate shape bindings + * with MatchCast to extract symbolic variables, ensuring VMShapeLower receives only + * cannonical shape expressions. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +namespace { + +/*! 
+ * \brief Check if a PrimExpr is cannonical for VMShapeLower + * + * VMShapeLower can only handle: + * - IntImm: concrete integer constants + * - tir::Var: symbolic variables that can be stored/loaded at runtime + * + * Any other expression (arithmetic, casts, etc.) is compound and needs canonicalization. + */ +bool IsCanonicalPrimExpr(const PrimExpr& expr) { + return expr->IsInstance() || expr->IsInstance(); +} + +/*! + * \brief Mutator to canonicalize ShapeExpr in struct info + * + * This pass handles ShapeExpr canonicalization by: + * 1. Detecting compound PrimExpr in ShapeExpr dimensions + * 2. Lifting them into separate ShapeExpr bindings + * 3. Using MatchCast to extract values into fresh symbolic tir::Var + * 4. Replacing compound expressions with these canonical vars + */ +class ShapeExprCanonicalizer : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const FunctionNode* func) override { + // Reset state for each function + auto cached_compound_to_var = compound_expr_to_var_; + auto cached_counter = symbolic_var_counter_; + + auto result = ExprMutator::VisitExpr_(func); + + compound_expr_to_var_ = cached_compound_to_var; + symbolic_var_counter_ = cached_counter; + + return result; + } + + /*! + * \brief Override VisitVarDef to canonicalize struct_info + * + * This is where we intercept variable definitions and canonicalize any + * compound PrimExpr in their TensorStructInfo shapes. + */ + Var VisitVarDef(const Var& var) override { + auto sinfo = GetStructInfo(var); + + // Check if we need to canonicalize the struct_info + auto canonical_sinfo = CanonicalizeStructInfo(sinfo); + + if (canonical_sinfo.same_as(sinfo)) { + // No changes needed + return ExprMutator::VisitVarDef(var); + } + + // Create a new var with canonicalized strcut_info + if (var->IsInstance()) { + return DataflowVar(var->vid, canonical_sinfo, var->span); + } + return Var(var->vid, canonical_sinfo, var->span); + } + + private: + /*! + * \brief Canonicalize struct info by lifting compound shape expressions + */ + StructInfo CanonicalizeStructInfo(const StructInfo& sinfo) { + if (auto tensor_sinfo = sinfo.as()) { + return CanonicalizeTensorStructInfo(ffi::GetRef(tensor_sinfo)); + } else if (auto tuple_sinfo = sinfo.as()) { + return CanonicalizeTupleStructInfo(ffi::GetRef(tuple_sinfo)); + } + return sinfo; + } + + /*! + * \brief Canonicalize TensorStructInfo by handling compound shape expressions + */ + TensorStructInfo CanonicalizeTensorStructInfo(const TensorStructInfo& sinfo) { + if (!sinfo->shape.defined()) { + return sinfo; + } + + auto shape_expr = sinfo->shape.as(); + if (!shape_expr) { + // Shape is Var, not a ShapeExpr - no canonicalization needed + return sinfo; + } + + // Canonicalize each dimension + ffi::Array canonical_dims; + bool changed = false; + + for (const PrimExpr& dim : shape_expr->values) { + PrimExpr canonical_dim = CanonicalizeDimension(dim); + canonical_dims.push_back(canonical_dim); + changed |= !canonical_dim.same_as(dim); + } + + if (!changed) { + return sinfo; + } + + // Create new TensorStructInfo with canonicalized shape + return TensorStructInfo(ShapeExpr(canonical_dims), sinfo->dtype, sinfo->vdevice, sinfo->span); + } + + /*! 
+ * \brief Canonicalize TupleStructInfo recursively + */ + TupleStructInfo CanonicalizeTupleStructInfo(const TupleStructInfo& sinfo) { + ffi::Array canonical_fields; + bool changed = false; + + for (const StructInfo& field : sinfo->fields) { + StructInfo canonical_field = CanonicalizeStructInfo(field); + canonical_fields.push_back(canonical_field); + changed |= !canonical_field.same_as(field); + } + + if (!changed) { + return sinfo; + } + + return TupleStructInfo(canonical_fields, sinfo->span); + } + + /*! + * \brief Canonicalize a single shape dimension + * + * If the dimension is a compound PrimExpr: + * 1. Emit a ShapeExpr binding containing the compound expression + * 2. Create a fresh symbolic tir::Var + * 3. Emit a MatchCast to bind the computed value to the symbolic var + * 4. Return the symbolic var + */ + PrimExpr CanonicalizeDimension(const PrimExpr& dim) { + // If already canonical, return as is + if (IsCanonicalPrimExpr(dim)) { + return dim; + } + + // Check if we've already canonicalized this expression + if (auto it = compound_expr_to_var_.find(dim); it != compound_expr_to_var_.end()) { + return it->second; + } + + // Create a fresh symbolic variable + tir::Var symbolic_var = CreateFreshSymbolicVar(dim->dtype); + + // Emit shape binding: shape_var = R.shape([compound_expr]) + ShapeExpr shape_value({dim}); + Var shape_var = builder_->Emit(shape_value); + + // Emit MatchCast to extract the computed value into the symbolic variable + // match_cast_var: R.Shape([symbolic_var]) = shape_var + ShapeStructInfo match_sinfo(ffi::Array{symbolic_var}); + Var match_cast_var("_", match_sinfo); + builder_->EmitNormalized(MatchCast(match_cast_var, shape_var, match_sinfo)); + + // Cache the mapping to avoid duplicate bindings + compound_expr_to_var_[dim] = symbolic_var; + + return symbolic_var; + } + + /*! 
+ * \brief Create a fresh symbolic TIR variable + */ + tir::Var CreateFreshSymbolicVar(DataType dtype) { + std::string name = "s" + std::to_string(symbolic_var_counter_++); + return tir::Var(name, dtype); + } + + // Cache to avoid creating duplicate bindings for the same compound expression + std::unordered_map compound_expr_to_var_; + + // Counter for generating unique symbolic variable names + int symbolic_var_counter_ = 0; +}; +} // namespace + +Expr CanonicalizeShapeExpr(Expr expr) { return ShapeExprCanonicalizer()(std::move(expr)); } + +namespace transform { + +Pass CanonicalizeShapeExpr() { + auto pass_func = [=](Function f, IRModule m, PassContext pc) { + return Downcast(relax::CanonicalizeShapeExpr(f)); + }; + return CreateFunctionPass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"CanonicalizeShapeExpr", + /*required=*/{}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.CanonicalizeShapeExpr", CanonicalizeShapeExpr); +} + +} // namespace transform + +} // namespace relax +} // namespace tvm From 25bbab105cd0d1d1f2631ff8c3e1178272daca07 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 14 Dec 2025 16:37:50 +0800 Subject: [PATCH 5/6] Add unit test case for new pass ShapeExprCanonicalizer --- .../test_transform_canonicalize_shape_expr.py | 216 ++++++++++++++++++ 1 file changed, 216 insertions(+) create mode 100644 tests/python/relax/test_transform_canonicalize_shape_expr.py diff --git a/tests/python/relax/test_transform_canonicalize_shape_expr.py b/tests/python/relax/test_transform_canonicalize_shape_expr.py new file mode 100644 index 000000000000..9a9a1532b6ef --- /dev/null +++ b/tests/python/relax/test_transform_canonicalize_shape_expr.py @@ -0,0 +1,216 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Unit tests for the CanonicalizeShapeExpr pass""" + +import pytest +import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R +from tvm.script import tir as T + + +def test_simple_compound_shape(): + """Test canonicalization of simple compound shape expression""" + + @R.function + def before(x: R.Tensor(("n",), "float32")): + n = T.int64() + # Compound expression: n + 1 + y: R.Tensor((n + 1,), "float32") = R.zeros(R.shape([n + 1]), dtype="float32") + return y + + mod = tvm.IRModule.from_expr(before) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + + # After canonicalization, the shape should use a symbolic var instead of n + 1 + # Check that VMShapeLower can process it + mod = relax.transform.VMShapeLower()(mod) + + # If we got here without error, the test passed + assert "main" in mod + + +def test_compound_shape_in_constant(): + """Test canonicalization when compound shape appears in constant variable struct_info""" + + @R.function + def before(x: R.Tensor(("n", "m"), "float32")): + n = T.int64() + m = T.int64() + # This pattern can occur after FoldConstant inlines shapes + # The constant variable has compound expression in its struct_info + y: R.Tensor((n * m,), "float32") = R.zeros(R.shape([n * m]), dtype="float32") + return y + + mod = tvm.IRModule.from_expr(before) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + + mod = relax.transform.VMShapeLower()(mod) + + # If we got here without error, the test passed + assert "main" in mod + + +def test_multiply_compound_shape(): + """Test the original issue case: 4 * x_0 * x_1 * x_2 * x_3""" + + @R.function + def before(x: R.Tensor(("n", "m", "p", "q"), "float32")): + n = T.int64() + m = T.int64() + p = T.int64() + q = T.int64() + # Compound expression: 4 * n * m * p * q + y: R.Tensor((4 * n * m * p * q,), "float32") = R.zeros( + R.shape([4 * n * m * p * q]), dtype="float32" + ) + return y + + mod = tvm.IRModule.from_expr(before) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + + mod = relax.transform.VMShapeLower()(mod) + + # If we got here without error, the test passed + assert "main" in mod + + +def test_no_change_for_canonical_shape(): + """Test that already canonical shapes are not modified""" + + @R.function + def before(x: R.Tensor(("n",), "float32")): + n = T.int64() + # Already canonical shape + y: R.Tensor((n,), "float32") = R.zeros(R.shape([n]), dtype="float32") + return y + + mod_before = tvm.IRModule.from_expr(before) + mod_after = relax.transform.CanonicalizeShapeExpr()(mod_before) + + # The mod should be unchanged (or minimally changed) + # Both should work with VMShapeLower + mode_before_lower = relax.transform.VMShapeLower()(mod_before) + mode_after_lower = relax.transform.VMShapeLower()(mod_after) + + # If we got here without error, the test passed + assert "main" in mod_before_lower + assert "main" in mode_after_lower + + +def test_no_change_for_concrete_shape(): + """Test that concrete integer shapes are not modified""" + + @R.function + def before(x: R.Tensor((10,), "float32")): + # Concrete shape + y: R.Tensor((10,), "float32") = R.zeros(R.shape([10]), dtype="float32") + return y + + mod = tvm.IRModule.from_expr(before) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + + mod = relax.transform.VMShapeLower()(mod) + + # If we got here without error, the test passed + assert "main" in mod + + +def test_tuple_struct_info(): + """Test canonicalization with tuple struct info containing compound shapes""" + + @R.function + def before(x: R.Tensor(("n",), "float32")): + 
n = T.int64() + # Tuple with compound shapes + y: R.Tuple(R.Tensor((n + 1,), "float32"), R.Tensor((n * 2,), "float32")) = ( + R.zeros(R.shape([n + 1]), dtype="float32"), + R.zeros(R.shape([n * 2]), dtype="float32"), + ) + return y + + mod = tvm.IRModule.from_expr(before) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + + mod = relax.transform.VMShapeLower()(mod) + + # If we got here without error, the test passed + assert "main" in mod + + +def test_full_pipeline_with_opt_level_1(): + """Test the full pipeline with opt_level=1""" + + @R.function + def before(x: R.Tensor(("n", "m"), "float32")): + n = T.int64() + m = T.int64() + y: R.Tensor((n * m,), "float32") = R.reshape(x, R.shape([n * m])) + return y + + mod = tvm.IRModule.from_expr(before) + + with tvm.transform.PassContext(opt_level=1): + # Apply the passes in order + mod = relax.transform.LegalizeOps()(mod) + mod = relax.transform.AnnotateTIROpPattern()(mod) + mod = relax.transform.FoldConstant()(mod) + mod = relax.transform.ComputePrimValue()(mod) + mod = relax.transform.CanonicalizeShapeExpr()(mod) + mod = relax.transform.VMShapeLower()(mod) + + assert "main" in mod + + +if __name__ == "__main__": + import sys + + print("Running CanonicalizeShapeExpr unit tests...") + print("=" * 80) + + tests = [ + ("Simple compound shape", test_simple_compound_shape), + ("Compound shape in constant", test_compound_shape_in_constant), + ("Multiply compound shape", test_multiply_compound_shape), + ("No change for canonical shape", test_no_change_for_canonical_shape), + ("No change for concrete shape", test_no_change_for_concrete_shape), + ("Tuple struct info", test_tuple_struct_info), + ("Full pipeline with opt_level=1", test_full_pipeline_with_opt_level_1), + ] + + passed = 0 + failed = 0 + + for name, test_func in tests: + try: + print(f"\nTest: {name}") + test_func() + print("Result: PASSED") + passed += 1 + except Exception as e: + print(f"Result: FAILED: {e}") + import traceback + + traceback.print_exc() + failed += 1 + + print("\n" + "=" * 80) + print(f"Total tests run: {passed + failed}, Passed: {passed}, Failed: {failed}") + + sys.exit(0 if failed == 0 else 1) From 0002c69c1f08d40fb2ba11eedbb81b398961d545 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Fri, 26 Dec 2025 13:06:45 +0800 Subject: [PATCH 6/6] Refactor VisitExpr_ --- .../transform/canonicalize_shape_expr.cc | 236 +++++++++++------- 1 file changed, 139 insertions(+), 97 deletions(-) diff --git a/src/relax/transform/canonicalize_shape_expr.cc b/src/relax/transform/canonicalize_shape_expr.cc index 1f035228cd85..1edcef17605e 100644 --- a/src/relax/transform/canonicalize_shape_expr.cc +++ b/src/relax/transform/canonicalize_shape_expr.cc @@ -62,10 +62,10 @@ bool IsCanonicalPrimExpr(const PrimExpr& expr) { * \brief Mutator to canonicalize ShapeExpr in struct info * * This pass handles ShapeExpr canonicalization by: - * 1. Detecting compound PrimExpr in ShapeExpr dimensions - * 2. Lifting them into separate ShapeExpr bindings + * 1. Detecting compound PrimExpr in variable struct_info + * 2. Emitting ShapeExpr bindings to compute expressions * 3. Using MatchCast to extract values into fresh symbolic tir::Var - * 4. Replacing compound expressions with these canonical vars + * 4. 
Replacing compound expressions with these canonical vars in struct_info */ class ShapeExprCanonicalizer : public ExprMutator { public: @@ -73,116 +73,118 @@ class ShapeExprCanonicalizer : public ExprMutator { Expr VisitExpr_(const FunctionNode* func) override { // Reset state for each function - auto cached_compound_to_var = compound_expr_to_var_; - auto cached_counter = symbolic_var_counter_; - - auto result = ExprMutator::VisitExpr_(func); - - compound_expr_to_var_ = cached_compound_to_var; - symbolic_var_counter_ = cached_counter; - - return result; - } - - /*! - * \brief Override VisitVarDef to canonicalize struct_info - * - * This is where we intercept variable definitions and canonicalize any - * compound PrimExpr in their TensorStructInfo shapes. - */ - Var VisitVarDef(const Var& var) override { - auto sinfo = GetStructInfo(var); - - // Check if we need to canonicalize the struct_info - auto canonical_sinfo = CanonicalizeStructInfo(sinfo); - - if (canonical_sinfo.same_as(sinfo)) { - // No changes needed - return ExprMutator::VisitVarDef(var); + symbolic_var_counter_ = 0; + compound_expr_to_var_.clear(); + emitted_bindings_.clear(); + + // Visit params to populate var_remap_ + ffi::Array params; + bool all_params_unchanged = true; + for (Var param : func->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + if (!param.same_as(new_param)) { + var_remap_[param->vid] = new_param; + all_params_unchanged = false; + } } - // Create a new var with canonicalized strcut_info - if (var->IsInstance()) { - return DataflowVar(var->vid, canonical_sinfo, var->span); - } - return Var(var->vid, canonical_sinfo, var->span); - } + // Process the function body with proper scope setup + Expr new_body = this->VisitWithNewScope(func->body, params); - private: - /*! - * \brief Canonicalize struct info by lifting compound shape expressions - */ - StructInfo CanonicalizeStructInfo(const StructInfo& sinfo) { - if (auto tensor_sinfo = sinfo.as()) { - return CanonicalizeTensorStructInfo(ffi::GetRef(tensor_sinfo)); - } else if (auto tuple_sinfo = sinfo.as()) { - return CanonicalizeTupleStructInfo(ffi::GetRef(tuple_sinfo)); + if (all_params_unchanged && new_body.same_as(func->body)) { + return ffi::GetRef(func); } - return sinfo; + + return Function(params, new_body, func->ret_struct_info, func->is_pure, func->attrs, + func->span); } - /*! 
- * \brief Canonicalize TensorStructInfo by handling compound shape expressions - */ - TensorStructInfo CanonicalizeTensorStructInfo(const TensorStructInfo& sinfo) { - if (!sinfo->shape.defined()) { - return sinfo; - } + Expr VisitExpr_(const ShapeExprNode* op) override { + // Just cannonicalize ShapeExpr values by replacing compound expression with symbolic vars + // The bindings should have been emitted earlier by EmitBindingsForExpr - auto shape_expr = sinfo->shape.as(); - if (!shape_expr) { - // Shape is Var, not a ShapeExpr - no canonicalization needed - return sinfo; + // Mark a copy of values to avoid any reference issues + std::vector original_dims; + for (const PrimExpr& dim : op->values) { + original_dims.push_back(dim); } - // Canonicalize each dimension - ffi::Array canonical_dims; + ffi::Array canonical_values; bool changed = false; - for (const PrimExpr& dim : shape_expr->values) { - PrimExpr canonical_dim = CanonicalizeDimension(dim); - canonical_dims.push_back(canonical_dim); + for (const PrimExpr& dim : original_dims) { + PrimExpr canonical_dim = GetCanonicalDimension(dim); + canonical_values.push_back(canonical_dim); changed |= !canonical_dim.same_as(dim); } if (!changed) { - return sinfo; + return ffi::GetRef(op); } - // Create new TensorStructInfo with canonicalized shape - return TensorStructInfo(ShapeExpr(canonical_dims), sinfo->dtype, sinfo->vdevice, sinfo->span); + return ShapeExpr(canonical_values, op->span); } /*! - * \brief Canonicalize TupleStructInfo recursively + * \brief Scan an expression for ShapeExprs and emit bindings for compound expressions. + * This must be called BEFORE visiting the expression to ensure bindings are emitted first. */ - TupleStructInfo CanonicalizeTupleStructInfo(const TupleStructInfo& sinfo) { - ffi::Array canonical_fields; - bool changed = false; + void EmitBindingsForExpr(const Expr& expr) { + // Use a simple visitor to find ShapeExpr nodes + class ShapeExprScanner : public ExprVisitor { + public: + explicit ShapeExprScanner(ShapeExprCanonicalizer* canonicalizer) + : canonicalizer_(canonicalizer) {} + + void VisitExpr_(const ShapeExprNode* op) override { + // Make a copy of values to avoid reference issues during emission + std::vector dims; + for (const PrimExpr& dim : op->values) { + dims.push_back(dim); + } + for (const PrimExpr& dim : dims) { + if (!IsCanonicalPrimExpr(dim)) { + canonicalizer_->CanonicalizeDimension(dim); + } + } + } + + private: + ShapeExprCanonicalizer* canonicalizer_; + }; + + ShapeExprScanner scanner(this); + scanner.VisitExpr(expr); + } - for (const StructInfo& field : sinfo->fields) { - StructInfo canonical_field = CanonicalizeStructInfo(field); - canonical_fields.push_back(canonical_field); - changed |= !canonical_field.same_as(field); - } + void VisitBinding_(const VarBindingNode* binding) override { + // Emit canonicalization bindings before processing the binding. + // Scan the binding's value for ShapeExprs with compound expressions. 
+ EmitBindingsForExpr(binding->value); - if (!changed) { - return sinfo; - } + // Let the base class handle the rest + ExprMutator::VisitBinding_(binding); + } - return TupleStructInfo(canonical_fields, sinfo->span); + void VisitBinding_(const MatchCastNode* binding) override { + // Scan the binding's value for ShapeExprs with compound expressions + EmitBindingsForExpr(binding->value); + + // Delegate to base handling + ExprMutator::VisitBinding_(binding); + } + + Var VisitVarDef(const Var& var) override { + // Don't canonicalize struct_info - just delegate to base + return ExprMutator::VisitVarDef(var); } + private: /*! - * \brief Canonicalize a single shape dimension - * - * If the dimension is a compound PrimExpr: - * 1. Emit a ShapeExpr binding containing the compound expression - * 2. Create a fresh symbolic tir::Var - * 3. Emit a MatchCast to bind the computed value to the symbolic var - * 4. Return the symbolic var + * \brief Get the canonical form of a dimension (returns the symbolic var if already emitted) */ - PrimExpr CanonicalizeDimension(const PrimExpr& dim) { + PrimExpr GetCanonicalDimension(const PrimExpr& dim) { // If already canonical, return as is if (IsCanonicalPrimExpr(dim)) { return dim; @@ -193,25 +195,62 @@ class ShapeExprCanonicalizer : public ExprMutator { return it->second; } - // Create a fresh symbolic variable + // Create a fresh symbolic variable, but don't emit yet tir::Var symbolic_var = CreateFreshSymbolicVar(dim->dtype); - // Emit shape binding: shape_var = R.shape([compound_expr]) - ShapeExpr shape_value({dim}); - Var shape_var = builder_->Emit(shape_value); - - // Emit MatchCast to extract the computed value into the symbolic variable - // match_cast_var: R.Shape([symbolic_var]) = shape_var - ShapeStructInfo match_sinfo(ffi::Array{symbolic_var}); - Var match_cast_var("_", match_sinfo); - builder_->EmitNormalized(MatchCast(match_cast_var, shape_var, match_sinfo)); - - // Cache the mapping to avoid duplicate bindings compound_expr_to_var_[dim] = symbolic_var; return symbolic_var; } + /*! + * \brief Emit bindings for a single compound dimension + * + * If the dimension is a compound PrimExpr: + * 1. Create a fresh symbolic tir::Var for the compound expression + * 2. 
Emit a MatchCast from a PrimValue to define the symbolic var + */ + void CanonicalizeDimension(const PrimExpr& dim) { + // If already canonical, nothing to emit + if (IsCanonicalPrimExpr(dim)) { + return; + } + + // Check if we've already emitted the bindings + if (emitted_bindings_.count(dim)) { + return; + } + + // Mark as emitted BEFORE emitting to prevent infinite recursion + emitted_bindings_.insert(dim); + + // Get or create the symbolic var for this compound expression + tir::Var symbolic_var; + auto it = compound_expr_to_var_.find(dim); + if (it != compound_expr_to_var_.end()) { + symbolic_var = it->second; + } else { + DataType dtype = dim->dtype; + symbolic_var = CreateFreshSymbolicVar(dtype); + compound_expr_to_var_[dim] = symbolic_var; + } + + // Emit a PrimValue binding with the compound expression + // This will be processed by VMShapeLower to compute the value + PrimValue prim_value(dim); + PrimStructInfo prim_sinfo(dim->dtype); + std::string prim_var_name = "_prim" + std::to_string(symbolic_var_counter_ - 1); + Var prim_var(prim_var_name, prim_sinfo); + builder_->EmitNormalized(VarBinding(prim_var, prim_value)); + + // Emit MatchCast to extract the computed value into the symbolic variable + // The pattern uses the symbolic var which will be defined by this MatchCast + PrimStructInfo match_sinfo(symbolic_var); + std::string match_var_name = "_match" + std::to_string(symbolic_var_counter_ - 1); + Var match_cast_var(match_var_name, match_sinfo); + builder_->EmitNormalized(MatchCast(match_cast_var, prim_var, match_sinfo)); + } + /*! * \brief Create a fresh symbolic TIR variable */ @@ -223,6 +262,9 @@ class ShapeExprCanonicalizer : public ExprMutator { // Cache to avoid creating duplicate bindings for the same compound expression std::unordered_map compound_expr_to_var_; + // Track which compound expressions have had their bindings emitted + std::unordered_set emitted_bindings_; + // Counter for generating unique symbolic variable names int symbolic_var_counter_ = 0; };
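
Usage note (appended for reviewers; not part of any patch above): the series schedules
CanonicalizeShapeExpr between ComputePrimValue and VMShapeLower in every backend
pipeline, which is what resolves the original opt_level=1 crash
(InternalError: Check failed: (slot->value_computed) is false). The sketch below is
adapted from test_full_pipeline_with_opt_level_1 in patch 5 and shows the intended
manual pass ordering on a module whose target shape contains a compound dimension.
It only uses names introduced by this series plus existing TVM APIs, but it is an
illustrative sketch rather than an authoritative recipe.

    import tvm
    from tvm import relax
    from tvm.script import relax as R
    from tvm.script import tir as T


    @R.function
    def before(x: R.Tensor(("n", "m"), "float32")):
        n = T.int64()
        m = T.int64()
        # A compound dimension (n * m) is exactly what used to trip VMShapeLower,
        # which only knows how to store/load IntImm and bare tir::Var dimensions.
        y: R.Tensor((n * m,), "float32") = R.reshape(x, R.shape([n * m]))
        return y


    mod = tvm.IRModule.from_expr(before)
    with tvm.transform.PassContext(opt_level=1):
        mod = relax.transform.LegalizeOps()(mod)
        mod = relax.transform.AnnotateTIROpPattern()(mod)
        mod = relax.transform.FoldConstant()(mod)
        mod = relax.transform.ComputePrimValue()(mod)
        # New in this series: lift n * m out of the ShapeExpr so that
        # VMShapeLower only ever sees IntImm or tir::Var dimensions.
        mod = relax.transform.CanonicalizeShapeExpr()(mod)
        mod = relax.transform.VMShapeLower()(mod)

Per the patch-6 implementation, each compound dimension is computed once through a
PrimValue binding and then re-introduced as a fresh symbolic variable via MatchCast,
roughly: a binding _prim0 = R.prim_value(n * m) followed by a MatchCast of _prim0
against R.Prim(value="s0"), with downstream shapes referring to s0 instead of n * m.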