
Conversation

@pggPL
Collaborator

@pggPL pggPL commented Jan 14, 2026

Description

  1. Fixes the incorrect implementation of the no_torch_dynamo decorator, which caused errors with the newest PyTorch versions: the decorator was not correctly disabled during ONNX export.
  2. Adds support for FP8 attention ONNX export (a usage sketch follows below).

Fixes #2588
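
To make the second item concrete, below is a rough, hypothetical usage sketch of exporting FP8 attention to ONNX. It is not code from this PR; it assumes the te.onnx_export and te.fp8_autocast context managers and the fp8_dpa recipe flag behave as exercised in tests/pytorch/test_onnx_export.py.

# Hypothetical sketch, not from this PR: export DotProductAttention under an
# FP8 recipe. Assumes te.onnx_export, te.fp8_autocast and
# recipe.DelayedScaling(fp8_dpa=True) behave as in the existing ONNX export tests.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)

# Dummy inputs in the default "sbhd" layout: (seq_len, batch, heads, head_dim).
q = torch.randn(128, 2, 16, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

fp8_recipe = recipe.DelayedScaling(fp8_dpa=True)  # FP8 dot-product attention

with torch.inference_mode(), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    with te.onnx_export(enabled=True):
        torch.onnx.export(attn, (q, k, v), "fp8_attention.onnx")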

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Fixed the no_torch_dynamo decorator to check ONNX export mode at call time instead of at decoration time
  • Added an ONNX-compatible FP8 emulation path for attention export (onnx_forward in FP8EmulationFunc)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 7 commits January 14, 2026 19:46
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL marked this pull request as ready for review January 27, 2026 17:09
@greptile-apps
Contributor

greptile-apps bot commented Jan 27, 2026

Greptile Overview

Greptile Summary

This PR fixes ONNX export issues and adds FP8 attention support for ONNX export.

Key Changes:

  • Fixed no_torch_dynamo decorator bug: The previous lambda-based implementation incorrectly evaluated is_in_onnx_export_mode() at decoration time instead of at runtime, causing errors with newer PyTorch versions during ONNX export. The new implementation uses a proper wrapper function that checks the export mode at each call (see the sketch after this summary).

  • Added FP8 attention ONNX export: Implemented onnx_forward method in FP8EmulationFunc that uses ONNX-compatible operations (flatten, concatenate, quantize, split) to emulate FP8 quantization during export.

  • Enabled FP8 emulation during ONNX export: Modified attention backend selection logic to allow FP8 emulation when is_in_onnx_export_mode() returns true, even without the environment variable (a small sketch follows the sequence diagram below).

  • Updated tests: Added parameterization for FP8 recipes (DelayedScaling and Float8CurrentScaling with fp8_dpa=True), removed attention_dropout=0.5 to avoid non-deterministic outputs, and adjusted tolerance to 5e-1 for FP8 tests.

  • Updated CI: Set NVTE_UnfusedDPA_Emulate_FP8=1 environment variable in test script to enable FP8 emulation in CI environments without native FP8 hardware.

The implementation follows best practices for ONNX export by using operations with defined ONNX translations and properly handling the export mode detection.
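
To illustrate the decorator fix, here is a minimal sketch of the runtime-check pattern described above. It assumes TE's is_in_onnx_export_mode() helper and is not the verbatim implementation from transformer_engine/pytorch/jit.py.

# Minimal sketch (assumptions noted above); names and structure may differ
# from the actual transformer_engine/pytorch/jit.py.
import functools
import torch
from transformer_engine.pytorch.export import is_in_onnx_export_mode

# Buggy pattern (before): the export mode was evaluated once, at decoration
# time, so functions remained dynamo-disabled during ONNX export:
# no_torch_dynamo = lambda fn: fn if is_in_onnx_export_mode() else torch._dynamo.disable(fn)

def no_torch_dynamo(recursive=True):
    """Disable TorchDynamo for the wrapped function, except during ONNX export."""
    def decorator(fn):
        dynamo_disabled_fn = torch._dynamo.disable(fn, recursive=recursive)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Check the export mode on every call rather than at decoration time.
            if is_in_onnx_export_mode():
                return fn(*args, **kwargs)
            return dynamo_disabled_fn(*args, **kwargs)

        return wrapper

    return decorator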

Confidence Score: 4/5

  • This PR is safe to merge with minor considerations
  • The changes correctly address the no_torch_dynamo decorator bug that was causing issues with newer PyTorch versions, and properly implement FP8 attention ONNX export. The implementation follows best practices by using ONNX-compatible operations and runtime mode detection. Minor consideration: assert statement in onnx_forward could be improved with better error handling.
  • Pay attention to transformer_engine/pytorch/attention/dot_product_attention/backends.py - the assert statement could cause issues if non-Float8 quantizers are passed during ONNX export

Important Files Changed

  • transformer_engine/pytorch/jit.py: Fixed the no_torch_dynamo decorator to check ONNX export mode at runtime instead of at decoration time, preventing errors with newer PyTorch versions
  • transformer_engine/pytorch/attention/dot_product_attention/backends.py: Added an onnx_forward method to FP8EmulationFunc for ONNX-compatible FP8 quantization/dequantization using flatten + concat + quantize + split operations
  • tests/pytorch/test_onnx_export.py: Added FP8 recipe parameterization to core attention tests, removed the attention_dropout=0.5 parameter, and adjusted tolerance for FP8 tests

Sequence Diagram

sequenceDiagram
    participant User
    participant DotProductAttention
    participant AttentionBackend as get_attention_backend
    participant UnfusedDPA as UnfusedDotProductAttention
    participant FP8Emulation as FP8EmulationFunc
    participant Quantizer as Float8Quantizer
    participant ONNX as ONNX Export

    User->>DotProductAttention: forward(query, key, value)
    DotProductAttention->>AttentionBackend: get_attention_backend()
    
    alt ONNX Export Mode
        AttentionBackend->>AttentionBackend: is_in_onnx_export_mode() == True
        AttentionBackend->>AttentionBackend: allow_emulation = True
        AttentionBackend-->>DotProductAttention: use UnfusedDotProductAttention
        
        DotProductAttention->>UnfusedDPA: forward(Q, K, V)
        UnfusedDPA->>FP8Emulation: apply(Q, K, V, quantizer, "QKV_quantizer")
        FP8Emulation->>FP8Emulation: is_in_onnx_export_mode() == True
        FP8Emulation->>FP8Emulation: onnx_forward()
        
        FP8Emulation->>FP8Emulation: flatten Q, K, V
        FP8Emulation->>FP8Emulation: concatenate tensors
        FP8Emulation->>Quantizer: onnx_quantize(combined)
        Quantizer-->>FP8Emulation: FP8 tensor
        FP8Emulation->>Quantizer: onnx_dequantize(fp8_tensor)
        Quantizer-->>FP8Emulation: dequantized tensor
        FP8Emulation->>FP8Emulation: split and reshape Q, K, V
        FP8Emulation-->>UnfusedDPA: emulated FP8 Q, K, V
        
        UnfusedDPA->>UnfusedDPA: compute attention(Q, K, V)
        UnfusedDPA-->>DotProductAttention: attention output
        DotProductAttention->>ONNX: export to ONNX graph
        ONNX-->>User: ONNX model with FP8 attention
    else Normal Training/Inference
        AttentionBackend->>AttentionBackend: check env var NVTE_UnfusedDPA_Emulate_FP8
        AttentionBackend-->>DotProductAttention: select appropriate backend
        DotProductAttention-->>User: attention output
    end
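
As a concrete, hypothetical reading of the branch at the top of the diagram (and of the "Enabled FP8 emulation during ONNX export" item above), the selection condition amounts to something like the following; the function and variable names are illustrative, not the exact ones used by get_attention_backend.

# Illustrative sketch only; the real backend-selection code differs.
import os
from transformer_engine.pytorch.export import is_in_onnx_export_mode

def allow_fp8_emulation() -> bool:
    """Allow FP8 emulation via the env var, or unconditionally during ONNX export."""
    env_enabled = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
    return env_enabled or is_in_onnx_export_mode()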

Contributor

@greptile-apps greptile-apps bot left a comment


No files reviewed, no comments


pggPL and others added 2 commits January 27, 2026 17:39
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment


@pggPL
Collaborator Author

pggPL commented Jan 27, 2026

/te-ci pytorch L1

timmoon10
timmoon10 previously approved these changes Jan 27, 2026
Collaborator

@timmoon10 timmoon10 left a comment


LGTM

Comment on lines +223 to +226
# Flatten and concatenate
combined = torch.cat(
    [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0
)
Collaborator


This is fine for FP8 attention, although we'll need to revisit whenever we support MXFP8 or NVFP4. Why can't we concatenate the 2D tensors?

Collaborator Author


I'm not sure if this will work for all layouts and different max_q_length and max_kv_length. I added assertions that it's not MXFP8, because I want to merge this fast. I will rethink it when adding support for MXFP8.
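
For readers following this thread, here is a rough sketch of the round trip being discussed (flatten, concatenate, quantize, dequantize, split). The helper is hypothetical, and the onnx_quantize/onnx_dequantize calls are the quantizer entry points named in the summary above; their exact signatures in backends.py may differ.

# Hypothetical helper sketching the ONNX FP8 emulation round trip; not the
# actual onnx_forward from backends.py.
import torch

def fp8_emulate_qkv(q, k, v, quantizer):
    """Apply one FP8 quantize/dequantize round trip jointly to Q, K and V."""
    shapes = (q.shape, k.shape, v.shape)
    numels = (q.numel(), k.numel(), v.numel())

    # Flatten and concatenate so a single FP8 scale covers all three tensors.
    combined = torch.cat([q.reshape(-1), k.reshape(-1), v.reshape(-1)], dim=0)

    # Quantize to FP8 and immediately dequantize (emulation only).
    fp8_combined = quantizer.onnx_quantize(combined)
    dequantized = quantizer.onnx_dequantize(fp8_combined)

    # Split back into Q, K, V and restore the original layouts.
    q_out, k_out, v_out = torch.split(dequantized, numels, dim=0)
    return q_out.reshape(shapes[0]), k_out.reshape(shapes[1]), v_out.reshape(shapes[2])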

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment


Comment on lines +217 to +219
assert isinstance(
    quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
), "ONNX FP8 emulation path supports only Float8 quantizers."
Contributor


The assert statement will cause ONNX export to fail if non-Float8 quantizers are used. Consider replacing it with a runtime check that raises a more descriptive error or logs a warning.
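
A hedged sketch of that suggestion, reusing the names from the quoted snippet (not the code actually merged):

# Possible replacement for the assert; a sketch, not the merged implementation.
if not isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
    raise NotImplementedError(
        "ONNX FP8 emulation supports only Float8 quantizers, got "
        f"{type(quantizer).__name__}. Export without FP8 attention or use a "
        "Float8 delayed/current scaling recipe."
    )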

@pggPL
Collaborator Author

pggPL commented Jan 27, 2026

/te-ci pytorch L1

@pggPL pggPL merged commit f04b094 into NVIDIA:main Jan 28, 2026
27 of 31 checks passed
KshitijLakhani pushed a commit that referenced this pull request Jan 28, 2026
* jjit bug fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix'

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint fixes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Successfully merging this pull request may close these issues.

Multihead Attention fails fp8 ONNX export
