
Conversation

@yueshen2016 yueshen2016 (Contributor) commented Jan 30, 2026

What does this PR do?

Type of change: Bug fix

Overview:

Fix te_grouped_quantized_linear_fn argument parsing for TEGroupedLinear quantization when parallelism configuration results in fewer local experts per GPU.

Problem

When running MoE models with TEGroupedLinear quantization using high expert parallelism, the forward pass fails with:
```
AttributeError: 'tuple' object has no attribute 'numel'
```

Root Cause

The original code assumed len(args) >= 2 * num_gemms + idx + 2. This assumption holds when there are many local experts per GPU, but fails when experts are highly distributed.

Taking Qwen3-30B-A3B (with num_gemms=21, threshold=44) as an example:

| Parallelism Config | Local Experts | len(args) | Threshold (44) | Result |
| --- | --- | --- | --- | --- |
| pp=2, tp=4, ep=4 | More per GPU | ≥ 44 | ✓ Pass | Works |
| tp=8, ep=8 | Fewer per GPU | 34 | ✗ Fail | Crashes |

The bug triggers whenever len(args) < threshold, regardless of which script is running:

  • quantize.py with high EP → Fails during calibration
  • ptq_generate.py with high EP → Fails during inference

When len(args) < threshold:

  • args[-2*num_gemms:] reaches past the start of the argument list and captures ALL elements instead of just the trailing weights and biases
  • args[idx+1:-2*num_gemms] becomes empty, losing the critical non_tensor_args
  • The non_tensor_args tuple then gets treated as a weight tensor, which is what raises the AttributeError above (see the sketch below)
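
For concreteness, here is a minimal Python sketch of that slicing behavior. The argument layout (an input, a non-tensor metadata tuple, then per-GEMM weights and biases) is an assumption for illustration only, not the real te_grouped_quantized_linear_fn signature; the numbers follow the Qwen3-30B-A3B example (num_gemms=21, idx=0, threshold 44).

```python
# Minimal sketch of the slicing failure, assuming num_gemms = 21 and idx = 0
# (old threshold: 2 * 21 + 0 + 2 = 44). Layout is illustrative only.
num_gemms, idx = 21, 0

# Enough local experts: len(args) == 44, slicing behaves as intended.
args_ok = ["inp", ("non_tensor_args",)] \
    + [f"w{i}" for i in range(num_gemms)] + [f"b{i}" for i in range(num_gemms)]
tail = args_ok[-2 * num_gemms:]              # exactly the 42 trailing weights/biases
middle = args_ok[idx + 1:-2 * num_gemms]     # the metadata tuple is preserved
print(len(tail), middle)                     # 42 [('non_tensor_args',)]

# High EP: only 34 args reach the function, far below the threshold of 44.
args_bad = args_ok[:34]
tail = args_bad[-2 * num_gemms:]             # -42 reaches past the start: all 34 args
middle = args_bad[idx + 1:-2 * num_gemms]    # empty slice: non_tensor_args is lost
print(len(tail), middle)                     # 34 []
# The non_tensor_args tuple now sits among the "weights", so calling .numel() on it
# raises: AttributeError: 'tuple' object has no attribute 'numel'
```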

Usage

With this fix, you can use any parallelism configuration:

```bash
# High EP quantization - previously failed, now works
torchrun --nproc_per_node 8 examples/quantization/quantize.py \
  --hf-model-id /models/Qwen3-30B-A3B \
  --export-quant-cfg fp8 \
  --megatron-save-path /models/Qwen3-30B-A3B_fp8_mlm \
  --tp 8 \
  --ep 8

# High EP inference - previously failed, now works
torchrun --nproc_per_node 8 examples/quantization/ptq_generate.py \
  --megatron-load-path /models/Qwen3-30B-A3B_fp8_mlm \
  --hf-model-id /models/Qwen3-30B-A3B \
  --tp 8 \
  --ep 8
```

Testing

Verified with the two commands shown under Usage: FP8 quantization of Qwen3-30B-A3B with --tp 8 --ep 8, followed by PTQ generation from the exported checkpoint. Both previously failed with the AttributeError and now complete successfully.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • Bug Fixes

    • Enhanced Mixture of Experts (MoE) calibration validation and synchronization to ensure consistency across distributed training setups.
    • Improved grouped linear quantization robustness to handle varying input patterns and tensor dimensions.
  • Improvements

    • Better error handling for incomplete MoE expert calibration detection.
    • More flexible argument parsing for quantization operations.


@yueshen2016 yueshen2016 self-assigned this Jan 30, 2026
@yueshen2016 yueshen2016 requested a review from a team as a code owner January 30, 2026 22:09
coderabbitai bot (Contributor) commented Jan 30, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

📝 Walkthrough

The changes refactor Mixture-of-Experts (MoE) calibration handling in PyTorch quantization across three modules. They add explicit MoE calibration validation and local expert amax synchronization in model_calib.py, remove the specialized _QuantMoELayer class from megatron.py, and improve argument parsing robustness in transformer_engine.py's grouped linear quantization path for varying input configurations.

Changes

  • MoE Calibration Validation & Synchronization (modelopt/torch/quantization/model_calib.py): Introduces _has_expert_parallelism() and _check_moe_calibration_complete() to detect expert parallelism and validate calibration completeness across distributed groups. Adds local expert amax synchronization in max_calibrate() before distributed sync, with validation checks ensuring calibration consistency before proceeding (a hedged sketch of this synchronization pattern follows this list).
  • MoE Layer Quantization Removal (modelopt/torch/quantization/plugins/megatron.py): Removes an unused import and deletes the entire _QuantMoELayer class, eliminating specialized token-dispatch calibration handling for MoE layers during quantization.
  • Grouped Linear Quantization Robustness (modelopt/torch/quantization/plugins/transformer_engine.py): Reworks TE grouped linear quantization to robustly parse weights/biases from varying argument positions (tail vs. remaining_args). Introduces flexible argument splitting and reconstruction logic to handle both single-partition (ep=1) and multi-partition (ep>1) invocation patterns, improving compatibility with different argument list lengths.
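
As a rough illustration of the calibration-synchronization idea mentioned for model_calib.py above: this is a hedged sketch, not the actual modelopt implementation; the helper name, the quantizer interface, and the .amax attribute are assumptions made for the example.

```python
import torch
import torch.distributed as dist


def sync_expert_amax(expert_quantizers, ep_group=None):
    """Hypothetical helper: share one amax across local experts, then across EP ranks.

    expert_quantizers: quantizer objects for the local experts, each exposing an
    optional .amax tensor populated during calibration (illustrative interface).
    """
    amaxes = [q.amax for q in expert_quantizers if q.amax is not None]
    if not amaxes:
        return  # no expert on this rank has been calibrated yet

    # 1) Local sync: every local expert adopts the largest observed amax.
    shared_amax = torch.stack(amaxes).amax(dim=0)

    # 2) Distributed sync: also take the max across the expert-parallel group,
    #    so experts living on other ranks agree on the calibration range.
    if ep_group is not None and dist.is_initialized():
        dist.all_reduce(shared_amax, op=dist.ReduceOp.MAX, group=ep_group)

    for q in expert_quantizers:
        q.amax = shared_amax.clone()
```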

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 3 passed

  • Title check: ✅ Passed. The title accurately reflects the primary change: fixing TEGroupedLinear quantization to support expert parallelism when EP > 1, which is the core bug fix addressed across multiple files.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 83.33%, which is sufficient; the required threshold is 80.00%.
  • Description Check: ✅ Passed. Check skipped - CodeRabbit's high-level summary is enabled.



codecov bot commented Jan 30, 2026

Codecov Report

❌ Patch coverage is 52.17391% with 11 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.34%. Comparing base (81b67dd) to head (a85f04e).
⚠️ Report is 1 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| modelopt/torch/quantization/model_calib.py | 52.17% | 11 Missing ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main     #833      +/-   ##
==========================================
- Coverage   73.82%   73.34%   -0.49%
==========================================
  Files         193      193
  Lines       19745    19913     +168
==========================================
+ Hits        14577    14605      +28
- Misses       5168     5308     +140
```

☔ View full report in Codecov by Sentry.

@yueshen2016 yueshen2016 force-pushed the yueshen/fix-te-grouped-linear-ep-quantization branch from 0deb9b6 to a85f04e on January 30, 2026 23:01
```python
weights_and_biases = args[-2 * num_gemms :]
weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
non_tensor_args = args[idx + 1]
num_gemms = len(non_tensor_args)
```
@realAsma realAsma (Contributor) commented Feb 2, 2026

Suggested change:

```diff
- num_gemms = len(non_tensor_args)
+ num_gemms = len(non_tensor_args[0])
```
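
A hedged reading of this suggestion: if the first element of non_tensor_args is the per-GEMM split list (TE's grouped linear takes an m_splits list with one entry per GEMM), then the GEMM count should come from that inner list rather than from the number of metadata arguments. The tuple layout below is an assumption for illustration, not taken from the PR's code.

```python
# Illustration only: a hypothetical non_tensor_args layout.
non_tensor_args = ([4096, 4096, 4096], True)   # (m_splits-style list, some flag)
print(len(non_tensor_args))                    # 2 -> counts metadata arguments
print(len(non_tensor_args[0]))                 # 3 -> counts GEMMs, one split per GEMM
```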
