Added column-major storage of weights and scales in INT4 quantization for model load time improvement in TRT-RTX#811

Merged
hthadicherla merged 5 commits into main from hthadicherla/column_major_trt_rtx
Feb 2, 2026

Conversation

@hthadicherla
Contributor

@hthadicherla hthadicherla commented Jan 23, 2026

What does this PR do?

Type of change: ? New feature

Overview:
TensorRT-RTX requires the weights and scales in ONNX models to be in column-major format, so at load time the TRT-RTX JIT transposes the weights and scales, increasing model load time.

The proposed feature: after quantization, transpose the weights and scales in the DQ node and add a Transpose node right after, i.e.,
A × B = A × ((Bᵀ)ᵀ)

The transformation is a post-processing step and is disabled by default. It can be enabled by quantizing with --use_column_major.
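The identity this transformation relies on can be sketched in NumPy (shapes here are illustrative, not taken from any particular model):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 8)).astype(np.float32)   # activations [M, K]
B = rng.standard_normal((8, 16)).astype(np.float32)  # weights [K, N]

B_stored = B.T             # what the DQ node would hold after the transform: [N, K]
B_recovered = B_stored.T   # what the inserted Transpose node computes: back to [K, N]

# A @ B and A @ ((B^T)^T) are the same computation
assert np.array_equal(A @ B, A @ B_recovered)
```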

Usage

python -m modelopt.onnx.quantization --onnx_path "model.onnx" --output_path "model_quant.onnx" --quantize_mode int4 --calibration_method awq_lite --use_column_major --skip_shared_constants_duplication

Testing

Tested a few LLMs, comparing their MMLU scores with and without this transformation. No degradation was observed.

Summary by CodeRabbit

  • New Features

    • Added --use_column_major CLI flag to enable column-major weight storage optimization (applies to DQ-only quantization paths).
  • Documentation

    • CLI docs updated to describe the new flag and its applicability.
  • Tests

    • New unit tests validating column-major transformation behavior and output equivalence.

…improvement in TRT-RTX

Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
@hthadicherla hthadicherla requested review from a team as code owners January 23, 2026 11:42
@coderabbitai
Contributor

coderabbitai bot commented Jan 23, 2026


  • ✅ Full review completed - (🔄 Check again to review again)
📝 Walkthrough

Walkthrough

This PR introduces a column-major storage optimization feature for ONNX INT4 quantization targeting the NvTensorRtRtx execution provider. It adds a CLI flag to the quantization script, integrates it through the quantization pipeline, and provides utility functions for applying column-major transformations to GEMM weights and inserting transpose operations in DQ-only quantization modes.

Changes

Cohort / File(s) Summary
CLI & API Integration
examples/windows/onnx_ptq/genai_llm/quantize.py, modelopt/onnx/quantization/int4.py
Adds --use_column_major CLI argument and threads it through quantization function signature. Integrates flag handling into quantize_rtn, quantize, _quantize_awq_clip, and _quantize_awq_lite pathways. When enabled, branches control flow to apply column-major transformation to GEMM weights prior to DQ node creation. Flag is logged and guarded to avoid usage in incompatible modes (e.g., QDQ mode).
Transformation Utilities
modelopt/onnx/quantization/qdq_utils.py
Adds three new public functions: _apply_transpose_perm_to_shape() for computing transposed shapes, apply_column_major_transformation() to transpose quantized weights/scales in-place and return DQ attributes with axis set to 1, and add_transpose_nodes_for_column_major() to conditionally insert Transpose nodes after DQ nodes feeding MatMul/Gemm and update graph connections. Includes safeguards to skip already-processed nodes and avoid altering Gemm when transB is set.
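A plausible sketch of the shape helper described above; this is a hypothetical reimplementation for illustration only, and the actual code in qdq_utils.py may differ:

```python
def apply_transpose_perm_to_shape(shape, perm):
    # Given a tensor shape and a Transpose permutation, return the output shape.
    # Mirrors what a helper like _apply_transpose_perm_to_shape would compute.
    if shape is None:
        return None
    return [shape[p] for p in perm]

# A [K, N] = [8, 16] weight stored column-major has shape [N, K] = [16, 8];
# perm [1, 0] maps it back to the original layout.
```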

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant CLI as quantize.py<br/>(CLI)
    participant API as int4.py<br/>(quantize)
    participant Transform as qdq_utils.py<br/>(apply_column_major)
    participant Graph as Graph<br/>(ONNX)
    
    User->>CLI: --use_column_major flag
    CLI->>API: quantize(...,<br/>use_column_major=True)
    API->>Transform: apply_column_major_transformation(<br/>weights, scales, ...)
    Transform->>Transform: Transpose weights &<br/>scales in-place
    Transform->>API: Return DQ attributes<br/>(axis=1)
    API->>Graph: Create DQ nodes with<br/>column-major attributes
    API->>Transform: add_transpose_nodes_for_column_major(graph)
    Transform->>Graph: Insert Transpose nodes<br/>after DQ nodes
    Transform->>Graph: Update MatMul/Gemm<br/>inputs
    Graph-->>User: Optimized ONNX model

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title clearly and specifically describes the main change: adding column-major storage of weights and scales in INT4 quantization for TRT-RTX model load time improvement. It directly summarizes the primary objective and is well-suited to the changeset.
Docstring Coverage ✅ Passed Docstring coverage is 90.00% which is sufficient. The required threshold is 80.00%.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@modelopt/onnx/quantization/qdq_utils.py`:
- Around line 1083-1091: The current Gemm handling in
apply_column_major_transformation (qdq_utils.py) skips nodes with node.op ==
"Gemm" when node.attrs contains transB=1, which breaks semantics for
column-major weights; instead, when encountering a Gemm with transB set, update
the node.attrs transB to 0 (or remove/normalize it to zero) so the graph expects
B^T (matching the earlier weight transpose) and do not skip inserting the
transpose-back; locate the Gemm handling block (check for node.op == "Gemm" and
the transB logic) and replace the early continue with logic that flips
node.attrs["transB"] to 0 (or deletes the attr) so outputs remain correct while
keeping the transpose-back insertion.

Comment on lines 1083 to 1091
            # For Gemm nodes, check if transB is already set
            if node.op == "Gemm":
                trans_b = False
                if hasattr(node, "attrs") and "transB" in node.attrs:
                    trans_b = node.attrs["transB"] > 0
                if trans_b:
                    logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
                    continue

Contributor

@coderabbitai coderabbitai bot Jan 23, 2026


⚠️ Potential issue | 🔴 Critical

Gemm transB=1 skip breaks correctness with column‑major weights.
apply_column_major_transformation already transposes weights. If a Gemm has transB=1, skipping the transpose‑back makes Gemm consume B instead of B^T, changing outputs. Either always insert the transpose‑back or flip transB to 0 so Gemm consumes B^T directly.

🐛 Proposed fix (flip transB to 0 and keep semantics)
-            # For Gemm nodes, check if transB is already set
-            if node.op == "Gemm":
-                trans_b = False
-                if hasattr(node, "attrs") and "transB" in node.attrs:
-                    trans_b = node.attrs["transB"] > 0
-                if trans_b:
-                    logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
-                    continue
+            # For Gemm nodes with transB=1, flip to 0 since weights are already transposed
+            if node.op == "Gemm":
+                trans_b = bool((node.attrs or {}).get("transB", 0))
+                if trans_b:
+                    node.attrs = node.attrs or {}
+                    node.attrs["transB"] = 0
+                    logger.debug(
+                        f"Gemm node {node.name} has transB=1; setting transB=0 for column-major weights"
+                    )
+                    continue
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
            # For Gemm nodes, check if transB is already set
            if node.op == "Gemm":
                trans_b = False
                if hasattr(node, "attrs") and "transB" in node.attrs:
                    trans_b = node.attrs["transB"] > 0
                if trans_b:
                    logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
                    continue
            # For Gemm nodes with transB=1, flip to 0 since weights are already transposed
            if node.op == "Gemm":
                trans_b = bool((node.attrs or {}).get("transB", 0))
                if trans_b:
                    node.attrs = node.attrs or {}
                    node.attrs["transB"] = 0
                    logger.debug(
                        f"Gemm node {node.name} has transB=1; setting transB=0 for column-major weights"
                    )
                    continue

Contributor

@vishalpandya1990 vishalpandya1990 Jan 30, 2026


@hthadicherla Please review this and close/resolve as needed.

Contributor


🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for Gemm nodes with transB attribute in ONNX files and code
echo "=== Searching for Gemm with transB in Python code ==="
rg -n "transB" --type=py -C2

echo -e "\n=== Searching for test files related to column_major ==="
fd -t f -e py test | xargs rg -l "column_major" 2>/dev/null || echo "No test files found with column_major"

echo -e "\n=== Searching for Gemm usage in test files ==="
fd -t f -e py test | xargs rg -n "Gemm" -C1 2>/dev/null | head -50 || echo "No Gemm usage in tests"

Length of output: 3647


🏁 Script executed:

#!/bin/bash
# Look at the column_major test to see what cases are covered
echo "=== Column major test content ==="
rg -n "column_major" tests/unit/onnx/test_qdq_utils.py -A 30 -B 5

echo -e "\n=== Check if any test creates Gemm with transB ==="
rg -n "transB" tests/ --type=py -C3

Length of output: 5042


@vishalpandya1990, after reviewing the code and test coverage, I can confirm this is a valid correctness bug that should be fixed.

Analysis

The current code has a logic error when handling Gemm nodes with transB=1:

Current buggy behavior:

  1. apply_column_major_transformation transposes weights: W → W^T
  2. DQ node outputs the transposed weight W^T
  3. Code skips inserting transpose-back node for Gemm with transB=1
  4. Gemm with transB=1 receives W^T and applies transpose: (W^T)^T = W
  5. Result: A @ W (incorrect! Original was A @ W^T)

With the proposed fix:

  1. Weights are transposed: W → W^T
  2. DQ outputs W^T
  3. Code flips transB from 1 to 0 and skips transpose node
  4. Gemm with transB=0 receives W^T and uses it directly
  5. Result: A @ W^T (correct! Matches original semantics)
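The two scenarios can be checked numerically (square shapes are used so the buggy path remains shape-valid; values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((4, 8)).astype(np.float32)
W = rng.standard_normal((8, 8)).astype(np.float32)  # Gemm B input; square so A @ W also type-checks

original = A @ W.T       # Gemm(A, W) with transB=1
W_stored = W.T           # column-major transform stores W^T in the DQ node

fixed = A @ W_stored     # transB flipped to 0: Gemm consumes W^T directly
buggy = A @ W_stored.T   # transB left at 1: Gemm re-transposes, yielding A @ W

assert np.allclose(original, fixed)
assert not np.allclose(original, buggy)
```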

Test Coverage Gap

The existing test in test_qdq_utils.py::TestColumnMajorTransformation only covers MatMul nodes, not Gemm with transB=1, which means this edge case is currently untested.

Recommendation

Accept the CodeRabbit suggestion - the fix is correct and preserves semantics by flipping transB to 0 when weights are already transposed.

However, I'd recommend one clarification to the log message in the fix:

# For Gemm nodes with transB=1, flip to 0 since weights are already transposed
if node.op == "Gemm":
    trans_b = bool((node.attrs or {}).get("transB", 0))
    if trans_b:
        node.attrs = node.attrs or {}
        node.attrs["transB"] = 0
        logger.debug(
            f"Gemm node {node.name}: setting transB=0 to preserve semantics with column-major weights"
        )
        continue

The continue is intentional - it skips transpose node insertion because flipping transB achieves the same result more efficiently.


🧬 Code Graph Analysis Results

File: modelopt/onnx/quantization/graph_utils.py (relevant snippet)

Line range: 283-302

def get_tensor_consumer_nodes(
    graph: onnx.GraphProto,
) -> dict[str, list[onnx.NodeProto]]:
    """Returns a dictionary of tensor name and their consumer node object mapping.

    Args:
        graph: ONNX model graph.

    Returns:
        Dictionary, key is tensor name and value is their consumer node object
    """
    # Create a dictionary to store tensor consumer nodes
    tensor_consumers = defaultdict(list)

    # Traverse the graph to find consumer nodes for each tensor
    for node in graph.node:
        for input_name in node.input:
            tensor_consumers[input_name].append(node)

    return tensor_consumers

Line range: 216-234

def get_tensor_from_name(graph: onnx.GraphProto, tensor_name: str) -> onnx.ValueInfoProto | None:
    """Returns a ValueInfoProto given a tensor name.

    Args:
        graph: ONNX model graph
        tensor_name: String with tensor name.

    Returns:
        onnx.ValueInfoProto: actual graph tensor.
    """
    # Search in inputs
    vi = next((vi for vi in graph.input if vi.name == tensor_name), None)
    # If not found, search in outputs
    if vi is None:
        vi = next((vi for vi in graph.output if vi.name == tensor_name), None)
    # If not found, search in value_info (intermediate tensors)
    if vi is None:
        vi = next((vi for vi in graph.value_info if vi.name == tensor_name), None)
    return vi

Line range: 237-280

def get_tensor_producer_nodes(
    graph: onnx.GraphProto,
    get_initializer_producers: bool = False,
) -> dict[str, onnx.NodeProto]:
    """Returns a dictionary of tensor name and their producer node object mapping.

    Note. we create a special Root type node as external inputs producer for ease of implementation.

    Args:
        graph: ONNX model graph.

    Returns:
        Dictionary, key is tensor name and value is their producer node object
    """
    # Create a dictionary to store tensor producer nodes
    tensor_producers = defaultdict(None)

    # Special Root type producer node
    root_node = onnx.helper.make_node(
        op_type="Root",
        inputs=[],
        outputs=[i.name for i in graph.input],
        name="root_0",
    )

    input_names = [graph_input.name for graph_input in graph.input]
    initializer_names = [initializer.name for initializer in graph.initializer]
    external_input_names = list(np.setdiff1d(input_names, initializer_names))

    # Note. We are marking external inputs as non-constant by adding a parent,
    # so that we can quantize the first node of the graph if appropriate
    for graph_input in external_input_names:
        tensor_producers[graph_input] = root_node

    # Traverse the graph to find producer nodes for each tensor
    for node in graph.node:
        for output_name in node.output:
            tensor_producers[output_name] = node

    if get_initializer_producers:
        for initializer in graph.initializer:
            tensor_producers[initializer.name] = initializer

    return tensor_producers

File: modelopt/onnx/quantization/graph_utils.py (relevant snippet)

Line range: 1800-1882

def remove_redundant_cast_nodes(graph: onnx.GraphProto) -> None:
    """Remove redundant Cast nodes from the ONNX graph to optimize model performance.

    This function identifies and removes two types of redundant Cast nodes:

    1. Cast nodes where input and output types are identical
       - Before: t1 (dtype=fp16) -> cast (to=fp16) -> t2 -> Op
       - After:  t1 (dtype=fp16) -> Op

    2. Cast nodes that can be fused with initializers
       - Before: (initializer) t1 (dtype=fp32) -> cast (to=fp16) -> t2 -> Op
       - After:  (initializer) t1 (dtype=fp16) -> Op

    The function preserves Cast nodes that:
    - Have outputs that are graph outputs
    - Are necessary for type conversion
    - Have dynamic inputs (not initializers)

    Args:
        graph: ONNX graph to optimize. The graph will be modified in-place.

    Note:
        - This optimization is particularly useful for models with many Cast operations
        - The function modifies the graph in-place
        - All tensor consumers are updated to maintain graph connectivity
        - Initializer data types are converted when possible to eliminate Cast nodes
    """
    initializers = {init.name: init for init in graph.initializer}
    tensor_consumers = get_tensor_consumer_nodes(graph)
    value_info_map = {info.name: info for info in graph.value_info}
    cast_indices = []
    output_names = {output.name for output in graph.output}

    def _get_tensor_type(tensor_name: str) -> int | None:
        """Get the tensor type for a given tensor name."""
        if tensor_name in value_info_map:
            return value_info_map[tensor_name].type.tensor_type.elem_type
        if tensor_name in initializers:
            return initializers[tensor_name].data_type
        return None

    for node_idx, node in enumerate(graph.node):
        if node.op_type != "Cast":
            continue

        # Skip if output is a graph output
        if any(out_name in output_names for out_name in node.output):
            continue

        input_name = node.input[0]
        input_type = _get_tensor_type(input_name)
        if input_type is None:
            continue

        # Get target type from Cast node attributes
        attr = next((attr for attr in node.attribute if attr.name == "to"), None)
        if attr is None:
            continue

        # Pattern 1: Input and output types are the same
        if input_type == attr.i:
            cast_indices.append(node_idx)
        # Pattern 2: Convert and fuse Cast node for initializers
        elif input_name in initializers:
            cast_indices.append(node_idx)
            cast_input = onnx.numpy_helper.to_array(initializers[input_name])
            dtype = onnx.helper.tensor_dtype_to_np_dtype(attr.i)
            converted_tensor = onnx.numpy_helper.from_array(cast_input.astype(dtype), input_name)
            initializers[input_name].CopyFrom(converted_tensor)
        else:
            continue

        # Update consumer nodes
        for consumer in tensor_consumers.get(node.output[0], []):
            for i, input_tensor in enumerate(consumer.input):
                if input_tensor == node.output[0]:
                    consumer.input[i] = input_name
                    break

    # Remove Cast nodes in reverse order
    logger.info(f"Removing {len(cast_indices)} redundant Cast nodes")
    for node_idx in sorted(cast_indices, reverse=True):
        del graph.node[node_idx]

File: modelopt/onnx/quantization/quant_utils.py (relevant snippet)

Line range: 189-204

def get_num_bits(layer_info: dict[str, dict] | None = None, name: str | None = None) -> int:
    """Determine the layer configuration for quantization from layer_info.

    Args:
        layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names
            to layer configuration dict.
        name (str | None): Name of the tensor.

    Returns:
        int: Number of bits to use for quantization. Defaults to 4 if not specified.
    """
    if layer_info and name in layer_info:
        num_bits = layer_info[name]["precision"]
    else:
        num_bits = 4
    return num_bits

Contributor Author


Added the fix for this case. Basically, transB needs to be 0 here, since the Gemm node expects the matmul as A @ B^T. With transB=1, it takes B as input and internally transposes it to B^T; for column-major, since we are transposing anyway, we can just keep transB=0, and in this case no Transpose node is added.

Contributor



@codecov

codecov bot commented Jan 23, 2026

Codecov Report

❌ Patch coverage is 75.71429% with 17 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.38%. Comparing base (4f4558a) to head (c45e7a1).
⚠️ Report is 18 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/onnx/quantization/int4.py 53.84% 12 Missing ⚠️
modelopt/onnx/quantization/qdq_utils.py 88.63% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #811      +/-   ##
==========================================
- Coverage   74.17%   73.38%   -0.79%     
==========================================
  Files         192      193       +1     
  Lines       19246    19960     +714     
==========================================
+ Hits        14276    14648     +372     
- Misses       4970     5312     +342     

☔ View full report in Codecov by Sentry.

@galagam
Contributor

galagam commented Jan 25, 2026

@tcherckez-nvidia - do you mind reviewing?

@vishalpandya1990
Contributor

vishalpandya1990 commented Jan 28, 2026

Please add a unit test for this.

For reference: https://github.com/NVIDIA/Model-Optimizer/blob/main/tests/unit/onnx

Besides, also check/compare create_test_model_with_int4_dq_reshape_transpose_matmul() in https://github.com/NVIDIA/Model-Optimizer/blob/main/tests/unit/onnx/test_qdq_utils.py.

@hthadicherla
Contributor Author

Please add a unit test for this.

For reference: https://github.com/NVIDIA/Model-Optimizer/blob/main/tests/unit/onnx

Besides, also check/compare create_test_model_with_int4_dq_reshape_transpose_matmul() in https://github.com/NVIDIA/Model-Optimizer/blob/main/tests/unit/onnx/test_qdq_utils.py.

The pattern in the test case you mentioned seems to be DequantizeLinear -> Reshape -> Transpose -> MatMul. I'm not sure why this is being tested; I saw that the Reshape and Transpose nodes are removed by the INT4 quant exporter later anyway. See https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/onnx/export/int4_exporter.py#L33-L121

Regardless, it is different from our pattern, which is DequantizeLinear(W^T) -> Transpose -> MatMul. But what would the test case be? We create the pattern, and then what? One test case I'm thinking of: use dummy weight and activation/layernorm values, create both the DequantizeLinear(W^T) -> Transpose -> MatMul pattern and the DequantizeLinear(W) -> MatMul pattern, and check whether the MatMul outputs are the same.
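The comparison described above can be sketched in NumPy before wiring it into ONNX. The per-tensor scale and the shapes here are illustrative simplifications; real INT4 quantization uses blockwise scales:

```python
import numpy as np

rng = np.random.default_rng(2)
scale = np.float32(0.05)
Wq = rng.integers(-8, 8, size=(8, 16)).astype(np.int8)  # fake INT4 codes in [-8, 7]
A = rng.standard_normal((4, 8)).astype(np.float32)

# Pattern 1: DequantizeLinear(W) -> MatMul
W = Wq.astype(np.float32) * scale
out_ref = A @ W

# Pattern 2: DequantizeLinear(W^T) -> Transpose -> MatMul
Wt = Wq.T.astype(np.float32) * scale  # column-major storage of the same codes
out_cm = A @ Wt.T                     # Transpose node recovers W before the MatMul

assert np.allclose(out_ref, out_cm)
```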

…viders that need it), added use_column_major to log output and README, and renamed add_transpose_nodes_for_column_major to insert_transpose_nodes_for_column_major with inline comments.

Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
@hthadicherla
Contributor Author

@vishalpandya1990 I addressed most of the comments. Can you look at the new changes I made, and also at some of the questions I had regarding the changes you suggested?

@vishalpandya1990
Contributor

But what would the test case be though we create the pattern and then what ? One test case i'm thinking of is have dummy weight values and activation/layernorm values and create DequantizeLinear(W^T)->Transpose->Matmul pattern and DequantizeLinear(W) ->Matmul and see if the matmul output is the same or not.

Yes, we can check that the quantized model resulting from this transformation is valid and as we would expect. For instance, we can do a sanity check on the quantized graph/nodes (layout, shapes) and on the output (if feasible).

You can also skim through some existing unit tests to get further insight on potential test-cases.

Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
Verifies both produce the same output for the same input.
"""
import onnxruntime as ort

Contributor

@vishalpandya1990 vishalpandya1990 Jan 30, 2026


Perhaps we can simplify this a bit by creating a simple one-linear (MatMul) model, running the quantize API twice (once with column-major on and once with it off), and then comparing/validating output1 and output2, with a utility for the model's inference run. Can be done in a follow-up PR.

Contributor Author


I will do this in a follow-up PR then. For now I will leave this as is, since the main goal was to test the column-major functions anyway.

…sB=1 and added test to verify

Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@examples/windows/onnx_ptq/genai_llm/README.md`:
- Line 63: The table row for the `--use_column_major` option is missing a
trailing space before the closing pipe which triggers markdownlint MD060; update
the row in README.md (the line containing "`--use_column_major` | Default:
disabled | Apply column-major storage optimization for execution providers that
need it. Only applicable for DQ-only quantization.|") by adding a single space
before the final `|` so the cell delimiters are properly spaced.

In `@modelopt/onnx/quantization/qdq_utils.py`:
- Around line 1045-1135: The loop in insert_transpose_nodes_for_column_major
incorrectly skips additional MatMul/Gemm consumers because dq_nodes_processed
causes early continue; change to reuse a single Transpose per DequantizeLinear
producer by replacing dq_nodes_processed set with a dict mapping
producer_node.name -> (transpose_out, transpose_node), so for each consumer if
producer_node.name in the map just set node.inputs[1] = transpose_out (and still
flip Gemm transB when applicable) otherwise create transpose_out/transpose_node
using _apply_transpose_perm_to_shape and gs.Variable/gs.Node as currently done,
store them in the map, append the new transpose_node to nodes_to_add once, and
finally extend graph.nodes with nodes_to_add and cleanup/toposort as before.

| `--layers_8bit` | Default: None | Use this option to Overrides default mixed quant strategy|
| `--gather_quantize_axis` | Default: None | Use this option to enable INT4 quantization of Gather nodes - choose 0 or 1|
| `--gather_block_size` | Default: 32 | Block-size for Gather node's INT4 quantization (when its enabled using gather_quantize_axis option)|
| `--use_column_major` | Default: disabled | Apply column-major storage optimization for execution providers that need it. Only applicable for DQ-only quantization.|
Contributor


⚠️ Potential issue | 🟡 Minor

Fix table row spacing to satisfy markdownlint.

Line 63 is missing a space before the closing pipe, triggering MD060.

Proposed fix
-| `--use_column_major` | Default: disabled | Apply column-major storage optimization for execution providers that need it. Only applicable for DQ-only quantization.|
+| `--use_column_major` | Default: disabled | Apply column-major storage optimization for execution providers that need it. Only applicable for DQ-only quantization. |
📝 Committable suggestion


Suggested change
| `--use_column_major` | Default: disabled | Apply column-major storage optimization for execution providers that need it. Only applicable for DQ-only quantization.|
| `--use_column_major` | Default: disabled | Apply column-major storage optimization for execution providers that need it. Only applicable for DQ-only quantization. |
🧰 Tools
🪛 markdownlint-cli2 (0.20.0)

[warning] 63-63: Table column style
Table pipe is missing space to the left for style "compact"

(MD060, table-column-style)


Comment on lines +1045 to +1135
def insert_transpose_nodes_for_column_major(graph: gs.Graph):
"""Add a single Transpose node after each DequantizeLinear for column-major weights.

This implements the simple transformation: A @ B = A @ ((B^T)^T)
where B^T is stored in the DequantizeLinear node, and we add a Transpose
node after DQ to recover B before the MatMul.

Graph transformation:
Before: DQ(W) -> MatMul/Gemm
After: DQ(W^T) -> Transpose -> W -> MatMul/Gemm

Args:
graph: ONNX GraphSurgeon graph to modify in-place
"""
nodes_to_add = []
dq_nodes_processed = set()

for node in graph.nodes:
if node.op in ["MatMul", "Gemm"]:
# Check if second input (weight) is from DequantizeLinear
weight_input = node.inputs[1]
if not isinstance(weight_input, gs.Variable):
continue

# Find the producer of the weight input
producer_nodes = [n for n in graph.nodes if weight_input in n.outputs]
if not producer_nodes:
continue

producer_node = producer_nodes[0]
if producer_node.op != DEQUANTIZE_NODE_NAME:
continue

# Skip if we already processed this DQ node
if producer_node.name in dq_nodes_processed:
continue
dq_nodes_processed.add(producer_node.name)

# For Gemm nodes with transB=1, flip to transB=0 since weights are already transposed
# Original: Gemm expects W and internally computes A @ W^T
# After column-major: weight is W^T, so set transB=0 to use W^T directly -> A @ W^T
if node.op == "Gemm":
if hasattr(node, "attrs") and "transB" in node.attrs and node.attrs["transB"] > 0:
logger.debug(
f"Gemm node {node.name} has transB=1, flipping to transB=0 for column-major"
)
node.attrs["transB"] = 0
continue

# Get weight shape and dtype from DQ output
# DQ outputs W^T (transposed), shape is [N, K] instead of [K, N]
weight_shape = weight_input.shape if hasattr(weight_input, "shape") else None
weight_dtype = weight_input.dtype if hasattr(weight_input, "dtype") else None

# Permutation for 2D weights: [1, 0] to transpose back
# The stored weight is B^T (transposed), we need to get B back
# For 2D [N, K] (stored as transposed): perm [1, 0] -> [K, N] (original)
perm = [1, 0]

# Compute the transposed shape (original weight shape)
transposed_weight_shape = _apply_transpose_perm_to_shape(weight_shape, perm)

# Create output variable for the transpose node
transpose_out = gs.Variable(
f"{producer_node.name}_transposed_back",
dtype=weight_dtype,
shape=transposed_weight_shape,
)

# Create transpose node: (B^T)^T = B
transpose_node = gs.Node(
op="Transpose",
name=f"{producer_node.name}_transpose_back",
inputs=[weight_input],
outputs=[transpose_out],
attrs={"perm": perm},
)

# Update MatMul/Gemm to use the transposed weight
node.inputs[1] = transpose_out

# Add transpose node to list
nodes_to_add.append(transpose_node)

# Add all new nodes to graph
if nodes_to_add:
graph.nodes.extend(nodes_to_add)
logger.info(f"Added {len(nodes_to_add)} transpose nodes for column-major optimization")

# Clean up and reorder graph
graph.cleanup().toposort()
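The identity this pass relies on, A @ B = A @ ((Bᵀ)ᵀ), can be sanity-checked with a short numpy sketch (the shapes and names here are illustrative, not taken from the pass):

```python
import numpy as np

# Toy shapes: A is [M, K], B is [K, N].
A = np.random.rand(4, 8).astype(np.float32)
B = np.random.rand(8, 16).astype(np.float32)

# What the pass stores in the DQ initializer: B transposed, shape [N, K].
B_stored = B.T

# The inserted Transpose node recovers B before the MatMul.
recovered = B_stored.T

# (B^T)^T is exactly B, so the MatMul result is bit-identical.
assert np.array_equal(A @ B, A @ recovered)
```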
Contributor

@coderabbitai coderabbitai bot Feb 2, 2026


⚠️ Potential issue | 🟠 Major

Handle multiple consumers of the same DQ output.

Line 1078 skips processing once a DQ node is marked, so additional MatMul/Gemm consumers won’t be rewired (and won’t get transB flipped when applicable). This breaks correctness if weights are shared.

Proposed fix (reuse a single transpose per DQ output)
-    dq_nodes_processed = set()
+    transpose_out_by_dq: dict[str, gs.Variable] = {}
@@
-            # Skip if we already processed this DQ node
-            if producer_node.name in dq_nodes_processed:
-                continue
-            dq_nodes_processed.add(producer_node.name)
-
             # For Gemm nodes with transB=1, flip to transB=0 since weights are already transposed
             # Original: Gemm expects W and internally computes A @ W^T
             # After column-major: weight is W^T, so set transB=0 to use W^T directly -> A @ W^T
             if node.op == "Gemm":
                 if hasattr(node, "attrs") and "transB" in node.attrs and node.attrs["transB"] > 0:
                     logger.debug(
                         f"Gemm node {node.name} has transB=1, flipping to transB=0 for column-major"
                     )
                     node.attrs["transB"] = 0
                     continue
+
+            existing_transpose_out = transpose_out_by_dq.get(producer_node.name)
+            if existing_transpose_out is not None:
+                node.inputs[1] = existing_transpose_out
+                continue
@@
             # Add transpose node to list
             nodes_to_add.append(transpose_node)
+            transpose_out_by_dq[producer_node.name] = transpose_out
🤖 Prompt for AI Agents
In `@modelopt/onnx/quantization/qdq_utils.py` around lines 1045 - 1135, The loop
in insert_transpose_nodes_for_column_major incorrectly skips additional
MatMul/Gemm consumers because dq_nodes_processed causes early continue; change
to reuse a single Transpose per DequantizeLinear producer by replacing
dq_nodes_processed set with a dict mapping producer_node.name -> (transpose_out,
transpose_node), so for each consumer if producer_node.name in the map just set
node.inputs[1] = transpose_out (and still flip Gemm transB when applicable)
otherwise create transpose_out/transpose_node using
_apply_transpose_perm_to_shape and gs.Variable/gs.Node as currently done, store
them in the map, append the new transpose_node to nodes_to_add once, and finally
extend graph.nodes with nodes_to_add and cleanup/toposort as before.
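The fix the prompt above describes boils down to memoizing one Transpose output per DQ producer. A framework-free sketch (Var/Node are toy stand-ins for graphsurgeon's Variable/Node, and rewire_consumers is a hypothetical name, not part of the PR):

```python
from dataclasses import dataclass, field

@dataclass
class Var:
    name: str

@dataclass
class Node:
    op: str
    name: str
    inputs: list
    outputs: list
    attrs: dict = field(default_factory=dict)

def rewire_consumers(nodes):
    """Insert one shared Transpose after each DQ weight output,
    rewiring every MatMul/Gemm consumer (handles shared weights)."""
    transpose_out_by_dq = {}
    new_nodes = []
    for node in nodes:
        if node.op not in ("MatMul", "Gemm"):
            continue
        weight = node.inputs[1]
        producers = [n for n in nodes if weight in n.outputs]
        if not producers or producers[0].op != "DequantizeLinear":
            continue
        dq = producers[0]
        if node.op == "Gemm" and node.attrs.get("transB", 0) > 0:
            node.attrs["transB"] = 0  # stored weight is already W^T
            continue
        out = transpose_out_by_dq.get(dq.name)
        if out is None:  # first consumer: create the Transpose once
            out = Var(f"{dq.name}_transposed_back")
            new_nodes.append(Node("Transpose", f"{dq.name}_transpose_back",
                                  [weight], [out], {"perm": [1, 0]}))
            transpose_out_by_dq[dq.name] = out
        node.inputs[1] = out  # every consumer reuses the same output
    return nodes + new_nodes

# Two MatMuls sharing one DQ output: only one Transpose should be added.
w = Var("w")
dq = Node("DequantizeLinear", "dq0", [Var("w_q")], [w])
m1 = Node("MatMul", "m1", [Var("a"), w], [Var("y1")])
m2 = Node("MatMul", "m2", [Var("b"), w], [Var("y2")])
rewired = rewire_consumers([dq, m1, m2])
assert sum(n.op == "Transpose" for n in rewired) == 1
assert m1.inputs[1] is m2.inputs[1]
```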

Contributor Author

@coderabbitai, each MatMul node is quantized separately and has its own DQ node, so the weights won't be shared (two MatMul nodes won't share the same DQ node). Is your query still valid considering this?

Contributor

@hthadicherla, thank you for the clarification! You're right that if each MatMul/Gemm node has its own dedicated DQ node (no weight sharing), the scenario I flagged wouldn't occur in your specific use case.

However, I'd like to point out a few considerations:

  1. General utility function: Since this is in qdq_utils.py, it's a utility function that could potentially be used in other contexts or model architectures where weight sharing might occur. ONNX does support weight reuse (multiple operators consuming the same DQ output).

  2. Edge case with Gemm transB: Even without weight sharing, there's a subtle interaction in the code flow. The dq_nodes_processed check happens before the Gemm transB handling (lines 1057-1070). If somehow a DQ node were processed for a MatMul first, then a Gemm with transB=1 using the same DQ would be skipped entirely (not just continuing after flipping transB).

  3. Defensive coding: The proposed fix would handle both scenarios (shared and non-shared weights) correctly with minimal overhead—just a dict lookup instead of a set check.

My recommendation: If you're confident that your current use cases will never have shared weights and you want to keep the code simpler, the current implementation is fine. However, if you'd like more robustness for potential future use cases or edge cases, the suggested fix would be beneficial.

What are your thoughts? Would you prefer to add the defensive handling, or keep it as-is given your use case constraints?
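For the Gemm transB edge case mentioned above, the flip itself can also be checked numerically (a minimal numpy sketch with illustrative shapes):

```python
import numpy as np

A = np.random.rand(4, 8).astype(np.float32)
W = np.random.rand(16, 8).astype(np.float32)  # Gemm weight with transB=1: shape [N, K]

# Original graph: Gemm(A, W) with transB=1 computes A @ W^T.
original = A @ W.T

# After the pass: the stored weight is W^T (shape [K, N]) and transB is
# flipped to 0, so Gemm multiplies by the stored weight directly.
stored = W.T
flipped = A @ stored

assert np.array_equal(original, flipped)
```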

@hthadicherla hthadicherla merged commit fc6a211 into main Feb 2, 2026
39 of 41 checks passed
@hthadicherla hthadicherla deleted the hthadicherla/column_major_trt_rtx branch February 2, 2026 14:54
danielkorzekwa pushed a commit that referenced this pull request Feb 17, 2026
… for model load time improvement in TRT-RTX (#811)

## What does this PR do?

**Type of change:** ? New feature

**Overview:** 
TensorRT-RTX requires the weights and scales in ONNX models to be in
column-major format, so whenever a model loads, the TRT-RTX JIT transposes
the weights and scales at load time, increasing load time.

The proposed feature: after quantization, transpose the weights and scales
in the DQ node and add a Transpose node right after, i.e.,
A × B = A × ((Bᵀ)ᵀ)

The transformation is a post-processing step and is disabled by default.
It can be enabled by quantizing with --use_column_major.

## Usage
```
python -m modelopt.onnx.quantization --onnx_path "model.onnx" --output_path "model_quant.onnx" --quantize_mode int4 --calibration_method awq_lite --use_column_major --skip_shared_constants_duplication
```

## Testing
Tested a few LLMs, comparing their MMLU scores with and without this
transformation; no degradation was observed.



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added --use_column_major CLI flag to enable column-major weight
storage optimization (applies to DQ-only quantization paths).

* **Documentation**
  * CLI docs updated to describe the new flag and its applicability.

* **Tests**
* New unit tests validating column-major transformation behavior and
output equivalence.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>