[PyTorch] ONNX test fix + export for FP8 attention #2598
Conversation
Greptile Overview

Greptile Summary

This PR fixes ONNX export issues and adds FP8 attention support for ONNX export.

Key Changes:

The implementation follows best practices for ONNX export by using operations with defined ONNX translations and properly handling the export mode detection.

Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant User
participant DotProductAttention
participant AttentionBackend as get_attention_backend
participant UnfusedDPA as UnfusedDotProductAttention
participant FP8Emulation as FP8EmulationFunc
participant Quantizer as Float8Quantizer
participant ONNX as ONNX Export
User->>DotProductAttention: forward(query, key, value)
DotProductAttention->>AttentionBackend: get_attention_backend()
alt ONNX Export Mode
AttentionBackend->>AttentionBackend: is_in_onnx_export_mode() == True
AttentionBackend->>AttentionBackend: allow_emulation = True
AttentionBackend-->>DotProductAttention: use UnfusedDotProductAttention
DotProductAttention->>UnfusedDPA: forward(Q, K, V)
UnfusedDPA->>FP8Emulation: apply(Q, K, V, quantizer, "QKV_quantizer")
FP8Emulation->>FP8Emulation: is_in_onnx_export_mode() == True
FP8Emulation->>FP8Emulation: onnx_forward()
FP8Emulation->>FP8Emulation: flatten Q, K, V
FP8Emulation->>FP8Emulation: concatenate tensors
FP8Emulation->>Quantizer: onnx_quantize(combined)
Quantizer-->>FP8Emulation: FP8 tensor
FP8Emulation->>Quantizer: onnx_dequantize(fp8_tensor)
Quantizer-->>FP8Emulation: dequantized tensor
FP8Emulation->>FP8Emulation: split and reshape Q, K, V
FP8Emulation-->>UnfusedDPA: emulated FP8 Q, K, V
UnfusedDPA->>UnfusedDPA: compute attention(Q, K, V)
UnfusedDPA-->>DotProductAttention: attention output
DotProductAttention->>ONNX: export to ONNX graph
ONNX-->>User: ONNX model with FP8 attention
else Normal Training/Inference
AttentionBackend->>AttentionBackend: check env var NVTE_UnfusedDPA_Emulate_FP8
AttentionBackend-->>DotProductAttention: select appropriate backend
DotProductAttention-->>User: attention output
    end
```
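Below is a minimal sketch of the emulation path the diagram describes. It assumes a quantizer object exposing the `onnx_quantize`/`onnx_dequantize` methods named above and is an illustration of the flow, not the PR's actual implementation:

```python
import torch


def fp8_emulation_onnx_forward(q, k, v, quantizer):
    """Sketch of the ONNX-export FP8 emulation path: flatten Q/K/V, run one
    quantize/dequantize round trip through the Float8 quantizer's ONNX-friendly
    ops, then split and reshape back to the original shapes."""
    q_shape, k_shape, v_shape = q.shape, k.shape, v.shape

    # Flatten and concatenate so a single quantize/dequantize pair covers Q, K and V
    combined = torch.cat([q.reshape(-1), k.reshape(-1), v.reshape(-1)], dim=0)

    # Quantize and dequantize with operations that have defined ONNX translations
    fp8 = quantizer.onnx_quantize(combined)
    dequantized = quantizer.onnx_dequantize(fp8)

    # Split back into the three tensors and restore their original shapes
    q_flat, k_flat, v_flat = torch.split(
        dequantized, [q.numel(), k.numel(), v.numel()], dim=0
    )
    return (
        q_flat.reshape(q_shape),
        k_flat.reshape(k_shape),
        v_flat.reshape(v_shape),
    )
```

In normal (non-export) runs, the `else` branch of the diagram shows backend selection consulting the `NVTE_UnfusedDPA_Emulate_FP8` environment variable instead.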
No files reviewed, no comments
1 file reviewed, 1 comment
/te-ci pytorch L1
timmoon10 left a comment:
LGTM
```python
# Flatten and concatenate
combined = torch.cat(
    [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0
)
```
This is fine for FP8 attention, although we'll need to revisit whenever we support MXFP8 or NVFP4. Why can't we concatenate the 2D tensors?
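(For reference, a hypothetical sketch of the 2D concatenation being suggested, assuming Q/K/V share the same head dimension and following the `tensor1`–`tensor3` names in the snippet above:)

```python
# Keep the last (head) dimension and concatenate as 2D tensors, which would make
# block-scaled formats such as MXFP8/NVFP4 easier to handle later.
combined_2d = torch.cat(
    [
        tensor1.reshape(-1, tensor1.shape[-1]),
        tensor2.reshape(-1, tensor2.shape[-1]),
        tensor3.reshape(-1, tensor3.shape[-1]),
    ],
    dim=0,
)
```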
I'm not sure if this will work for all layouts and different max_q_length and max_kv_length. I added assertions that it's not MXFP8, because I want to merge this fast. I will rethink it when adding MXFP8 support.
1 file reviewed, 1 comment
```python
assert isinstance(
    quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
), "ONNX FP8 emulation path supports only Float8 quantizers."
```
The assert statement will cause ONNX export to fail if non-Float8 quantizers are used. Consider replacing it with a runtime check that raises a more descriptive error or logs a warning.
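A minimal sketch of the suggested alternative, using the same quantizer classes as the snippet above (the exact error message is illustrative):

```python
if not isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
    raise NotImplementedError(
        "ONNX FP8 emulation path supports only Float8 quantizers, "
        f"got {type(quantizer).__name__}."
    )
```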
/te-ci pytorch L1
* jjit bug fix
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* fix'
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* fix
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* fixes
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* lint fixes
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
This PR fixes the `no_torch_dynamo` decorator, which results in errors for the newest PyTorch: the decorator was not correctly disabled during export to ONNX.

Fixes #2588
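A minimal, hypothetical sketch of the idea behind the fix: check the ONNX export flag at call time and skip the Dynamo-disabling wrapper during export. The export-mode check below is a stand-in, not Transformer Engine's own flag, and the sketch may differ from the code in the PR:

```python
import functools

import torch


def is_in_onnx_export_mode() -> bool:
    # Stand-in for the library's export-mode flag (normally set by an
    # onnx_export context manager); torch.onnx.is_in_onnx_export() is only a proxy here.
    return torch.onnx.is_in_onnx_export()


def no_torch_dynamo(recursive=True):
    """Sketch: keep TorchDynamo disabled for the wrapped function, but fall back
    to the plain function while exporting to ONNX so the exporter can trace it."""
    def decorator(func):
        dynamo_disabled = torch._dynamo.disable(func, recursive=recursive)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if is_in_onnx_export_mode():  # checked at call time, not at decoration time
                return func(*args, **kwargs)
            return dynamo_disabled(*args, **kwargs)

        return wrapper
    return decorator
```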
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: