add local hessian calibration #788

Merged
Fridah-nv merged 11 commits into main from fridah/local-hessian
Feb 20, 2026

Conversation

@Fridah-nv
Contributor

@Fridah-nv Fridah-nv commented Jan 16, 2026

What does this PR do?

Type of change: new feature

Overview:
Add a new calibration method for weight scale search. It incorporates activation information by weighting scale candidates with a local Hessian matrix. Initial experiments with Qwen3 8B NVFP4 show improvements.

Usage

Use the NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG quantization config for quantization and evaluation.

e.g., add this line to QUANT_CFG_CHOICES in examples/llm_ptq/hf_ptq.py: "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,

cd examples/llm_ptq
python hf_ptq.py --pyt_ckpt_path /path/to/hf/checkpoint --qformat nvfp4_local_hessian --export_path path/to/save/quantized/checkpoint --kv_cache_qformat none --trust_remote_code
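With that entry in place, the relevant part of hf_ptq.py would look roughly like this (a sketch; the existing dict entries are abbreviated, and the key name is arbitrary as long as it matches the --qformat value passed on the command line):

```python
# examples/llm_ptq/hf_ptq.py (sketch; other entries elided)
import modelopt.torch.quantization as mtq

QUANT_CFG_CHOICES = {
    # ... existing choices ...
    "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,
}
```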

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added local Hessian-weighted MSE calibration pathway for NVFP4 per-block quantization with configurable amax search parameters and FP8 scale sweep support.
  • Tests

    • Added test coverage for the new local Hessian weight-only quantization configuration.

@Fridah-nv Fridah-nv requested a review from realAsma January 16, 2026 00:17
@Fridah-nv Fridah-nv self-assigned this Jan 16, 2026
@copy-pr-bot

copy-pr-bot bot commented Jan 16, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Jan 16, 2026

📝 Walkthrough

Walkthrough

This pull request introduces a new local Hessian-based calibration method for PyTorch quantization. It adds a new quantization configuration, calibration function with helper logic, mode descriptor for registry integration, and corresponding GPU tests.

Changes

Cohort / File(s) Summary
Quantization Configuration
modelopt/torch/quantization/config.py
Added NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG config and LocalHessianCalibConfig class to support local Hessian-weighted MSE calibration with configurable step size, multipliers, FP8 scale sweep, block size, and distributed sync controls.
Calibration Mode Registration
modelopt/torch/quantization/mode.py
Added LocalHessianModeDescriptor for calibration mode registry, linking the new local Hessian config to its calibration function implementation.
Calibration Algorithm Implementation
modelopt/torch/quantization/model_calib.py
Implemented local_hessian_calibrate() function with internal LocalHessianHelper to collect activations, compute per-block Hessians, and perform MSE-based calibration with activation caching and cleanup logic.
CUDA Quantization Tests
tests/gpu/torch/quantization/test_quantize_cuda.py
Added NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG test configuration and integrated it into parameterized test cases.

Sequence Diagram

sequenceDiagram
    participant Model as PyTorch Model
    participant LHC as local_hessian_calibrate()
    participant Helper as LocalHessianHelper
    participant Calib as MSE Calibrator

    Model->>LHC: Initialize with forward_loop
    LHC->>LHC: Run initial max_calibrate()
    LHC->>Helper: Create helpers for quantized modules
    LHC->>Model: Register forward hooks
    LHC->>LHC: Cache activations via forward passes
    Helper->>Helper: Compute per-block Hessians
    LHC->>Calib: Replace with local Hessian MSE calibrator
    Calib->>Calib: Search amax with Hessian-weighted loss
    LHC->>LHC: Cleanup hooks and caches
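To make the "Search amax with Hessian-weighted loss" step in the diagram concrete, here is a minimal NumPy sketch of the idea (all names, shapes, the fake_quant helper, and the multiplier sweep are illustrative assumptions, not the ModelOpt implementation):

```python
import numpy as np

def fake_quant(w, amax, num_bits=4):
    # Illustrative symmetric fake-quantization with the scale derived from amax.
    qmax = 2 ** (num_bits - 1) - 1
    scale = amax / qmax
    return np.clip(np.round(w / scale), -qmax, qmax) * scale

def hessian_weighted_amax_search(w, hessian, multipliers):
    """Pick the amax candidate minimizing dw^T H dw for one weight block."""
    base_amax = np.abs(w).max()
    best_amax, best_loss = base_amax, np.inf
    for m in multipliers:
        dw = w - fake_quant(w, base_amax * m)
        loss = dw @ hessian @ dw  # Hessian-weighted MSE for this block
        if loss < best_loss:
            best_amax, best_loss = base_amax * m, loss
    return best_amax

rng = np.random.default_rng(0)
w = rng.standard_normal(16)           # one weight block
x = rng.standard_normal((64, 16))     # cached calibration activations
H = x.T @ x                           # local Hessian proxy
amax = hessian_weighted_amax_search(w, H, np.arange(0.5, 1.01, 0.05))
```

The plain max calibrator would always return np.abs(w).max(); the search instead shrinks the range when the activation-weighted quantization error favors it.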

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: 3 passed
  • Title check ✅ Passed: The title 'add local hessian calibration' accurately summarizes the main change, which introduces a new local Hessian-based calibration method across multiple modules and test files.
  • Docstring Coverage ✅ Passed: Docstring coverage is 85.71%, which meets the required threshold of 80.00%.
  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.




Contributor

@realAsma realAsma left a comment


overall looks great, have you run any experiments with this?

@Fridah-nv
Contributor Author

overall looks great, have you run any experiments with this?

I'm still getting the numbers; I'll update here when they're ready.

@Fridah-nv Fridah-nv force-pushed the fridah/local-hessian branch from baef63e to b6fdc75 Compare January 23, 2026 00:37
Base automatically changed from fridah/mse-fp8-sweep to main January 26, 2026 22:39
@Fridah-nv Fridah-nv force-pushed the fridah/local-hessian branch from b6fdc75 to 4d1380a Compare January 27, 2026 00:43
@codecov

codecov bot commented Jan 27, 2026

Codecov Report

❌ Patch coverage is 19.76744% with 138 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.03%. Comparing base (ac7c985) to head (569bbc6).
⚠️ Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/quantization/model_calib.py 10.96% 138 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #788      +/-   ##
==========================================
- Coverage   73.45%   73.03%   -0.43%     
==========================================
  Files         205      205              
  Lines       22034    22200     +166     
==========================================
+ Hits        16185    16213      +28     
- Misses       5849     5987     +138     


@Fridah-nv Fridah-nv force-pushed the fridah/local-hessian branch from 4d1380a to c7589d1 Compare February 5, 2026 22:45
@copy-pr-bot

copy-pr-bot bot commented Feb 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Fridah-nv Fridah-nv force-pushed the fridah/local-hessian branch from c7589d1 to 8b5da94 Compare February 5, 2026 22:51
@Fridah-nv Fridah-nv changed the base branch from main to asma/refactor-scale-sweep February 5, 2026 22:55
@Fridah-nv Fridah-nv marked this pull request as ready for review February 5, 2026 23:01
@Fridah-nv Fridah-nv requested a review from a team as a code owner February 5, 2026 23:01
@Fridah-nv Fridah-nv requested review from meenchen and sugunav14 and removed request for a team February 5, 2026 23:01
Base automatically changed from asma/refactor-scale-sweep to main February 6, 2026 19:47
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv force-pushed the fridah/local-hessian branch from e35ebb0 to 2931f61 Compare February 6, 2026 20:03
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

In `@modelopt/torch/quantization/config.py`:
- Around line 391-411: The config dict name NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG
does not match the tests which expect mtq.NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG;
either rename the dict to NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG or create an alias
assignment (NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG =
NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG) and ensure the config name is included in
the module's choices set (add "NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG" to the
choices collection) so tests and any selection logic can find it.

In `@modelopt/torch/quantization/model_calib.py`:
- Around line 620-623: The cleanup loop only iterates weight_quantizers_info and
misses modules that had their forward patched in setup() but were later disabled
(is_enabled set to False); update setup() (the code path that patches
is_quantized_linear modules and attaches module.local_hessian/patches forward)
to record every module whose forward is patched into a new list (e.g.,
patched_modules) or append such modules into weight_quantizers_info regardless
of is_enabled, then change the cleanup block (after setting
LocalHessianHelper.cache_mode = False) to iterate over that recorded list and
call module.local_hessian.cleanup() (and restore/unpatch the forward if needed)
to ensure all patched modules are cleaned up.
- Around line 486-497: The code in local_hessian_error creates a huge temporary
via hessian.repeat(cout,1,1); instead compute the Hessian-weighted quadratic
form without materializing the repeated tensor by leveraging broadcasting or
einsum. Replace the repeat + matrix-mult sequence (hessian_expanded =
hessian.repeat(...); block_loss = (dw @ hessian_expanded @
dw.transpose(...)).squeeze(...)) with a memory-efficient einsum or broadcasted
matmul, e.g. compute block_loss using torch.einsum('nbk,bkl,nbl->n',
dw.squeeze(1), hessian, dw.squeeze(1)) or align dims with unsqueeze on hessian
and rely on broadcasting so no hessian.repeat is created; update use sites for
local_hessian_error, hessian, dw and block_loss accordingly.

In `@tests/gpu/torch/quantization/test_quantize_cuda.py`:
- Line 90: The test references a non-existent/mismatched config name
NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG; update the test to use the correct exported
config NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG (or add an explicit alias/export in
the module where configs are defined) and remove the misleading "WEIGHT_ONLY"
wording if the config enables input quantization; locate occurrences in the test
(e.g., where mtq.NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG is used) and replace them
with mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG or add the alias in the config
package to avoid AttributeError during test collection.
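The repeat-versus-einsum point in the model_calib.py comment can be verified with a small NumPy sketch (torch.einsum shares these semantics; the shapes below are assumptions inferred from the comment, not the actual code):

```python
import numpy as np

rng = np.random.default_rng(0)
cout, nblocks, k = 4, 3, 8
dw = rng.standard_normal((cout, nblocks, k))    # per-channel weight deltas
hessian = rng.standard_normal((nblocks, k, k))  # per-block local Hessians

# Naive version: materializes a (cout, nblocks, k, k) copy of the Hessian.
hessian_expanded = np.repeat(hessian[None], cout, axis=0)
naive = (dw[:, :, None, :] @ hessian_expanded @ dw[:, :, :, None])[:, :, 0, 0].sum(axis=1)

# Broadcasting via einsum: same per-channel losses, no expanded temporary.
efficient = np.einsum("nbk,bkl,nbl->n", dw, hessian, dw)

assert np.allclose(naive, efficient)
```

The einsum contracts the block and feature indices directly, so peak memory stays proportional to the Hessian itself rather than cout copies of it.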
🧹 Nitpick comments (2)
modelopt/torch/quantization/model_calib.py (2)

552-572: quant_func duplicates the existing _mse_quant_func (line 240).

The closure at lines 552–572 is nearly identical to the top-level _mse_quant_func. Reusing it via partial(_mse_quant_func, quantizer=weight_quantizer) (as done in mse_calibrate at line 330) would eliminate the duplication.

Reuse _mse_quant_func
-        def quant_func(x, amax, quantizer=weight_quantizer):
-            original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
-            quantizer._amax = amax
-
-            with (
-                enable_quant(quantizer),
-                disable_calib(quantizer),
-                enable_fake_quant(quantizer),
-            ):
-                if hasattr(quantizer, "_original_shape"):
-                    x = quantizer._reset_to_original_shape(x)
-                xq = quantizer(x)
-                if hasattr(quantizer, "_block_reshape_size"):
-                    xq = xq.reshape(quantizer._block_reshape_size)
-
-            if original_amax is not None:
-                quantizer._amax = original_amax
-            else:
-                delattr(quantizer, "_amax")
-
-            return xq
+        quant_func = partial(_mse_quant_func, quantizer=weight_quantizer)

417-477: LocalHessianHelper and accumulate_hessian are well-structured.

The pattern of a nested helper class with a class-level cache_mode flag follows the established AWQLiteHelper design. Minor note: the matmul at line 475 operates in the input tensor's dtype before converting to float32 — the existing update_hessian function (line 1478) converts to float() before the matmul. On GPU with TensorCores this is fine, but for consistency you may want to cast x to float32 before the matmul.
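The cast-before-matmul consistency point can be sketched as follows (a NumPy stand-in; the real accumulate_hessian signature and shapes may differ):

```python
import numpy as np

def accumulate_hessian(hessian, x):
    # Accumulate the local Hessian proxy H += x^T x for one batch of
    # activations x with shape (n_tokens, in_features). Casting to float32
    # BEFORE the matmul avoids precision loss from low-precision inputs,
    # matching the existing update_hessian behavior noted above.
    x32 = x.astype(np.float32)
    return hessian + x32.T @ x32

rng = np.random.default_rng(1)
H = np.zeros((4, 4), dtype=np.float32)
for _ in range(3):  # three calibration batches
    x = rng.standard_normal((8, 4)).astype(np.float16)
    H = accumulate_hessian(H, x)
# H is symmetric positive semidefinite by construction.
```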

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Contributor

@realAsma realAsma left a comment


Looks great!

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv enabled auto-merge (squash) February 19, 2026 23:18
@Fridah-nv Fridah-nv merged commit adcce61 into main Feb 20, 2026
48 of 50 checks passed
@Fridah-nv Fridah-nv deleted the fridah/local-hessian branch February 20, 2026 00:24