
[#11146][feat] AutoDeploy: Add triton paged attention#11355

Draft
nvchenghaoz wants to merge 2 commits into NVIDIA:main from nv-auto-deploy:chenghao/triton-paged-attention

Conversation


@nvchenghaoz nvchenghaoz commented Feb 6, 2026

This is part 1 of adding Triton paged attention as an AutoDeploy attention backend. #11146

The changes in this PR:

  1. Add a two-stage (FlashAttention-style) Triton implementation of paged attention.
  2. Add a single-stage Triton implementation of paged attention. Its performance is not competitive, but it may help with fast onboarding; we can delete it if the value proves limited.
  3. Add the related tests for the two-stage Triton attention.
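
For readers new to paged attention: the KV cache is stored in fixed-size pages addressed through an indirection table, so a sequence's keys and values need not be contiguous in memory. A minimal sketch of the lookup (variable names such as kv_indptr/kv_indices follow common paged-KV conventions and are illustrative, not this PR's exact API):

```python
def locate_token(seq_id, token_pos, kv_indptr, kv_indices, page_size):
    """Map (sequence, token position) to (physical page, slot) in a paged KV cache."""
    logical_page = token_pos // page_size
    slot = token_pos % page_size
    # kv_indptr[i] .. kv_indptr[i+1] delimit sequence i's pages in kv_indices
    physical_page = kv_indices[kv_indptr[seq_id] + logical_page]
    return physical_page, slot

# Two sequences: seq 0 owns physical pages [7, 2], seq 1 owns page [5]
kv_indptr = [0, 2, 3]
kv_indices = [7, 2, 5]
page, slot = locate_token(seq_id=0, token_pos=17, kv_indptr=kv_indptr,
                          kv_indices=kv_indices, page_size=16)
# token 17 of seq 0 -> logical page 1 -> physical page 2, slot 1
```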

Performance of Llama-3.1-8B-Instruct, triton_paged vs. flashinfer:

[performance comparison chart]

Summary by CodeRabbit

  • New Features

    • Added two Triton-based paged attention backends with KV cache support: a two-stage implementation with FlashDecoding and a one-stage implementation for efficient attention computation.
  • Tests

    • Added comprehensive unit tests validating paged attention functionality across decode, context/prefill paths, and cache updates with comparisons against PyTorch and FlashInfer implementations.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from the specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option is always ignored. By DEFAULT, the bot reuses build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz nvchenghaoz requested a review from a team as a code owner February 6, 2026 19:31

coderabbitai bot commented Feb 6, 2026

📝 Walkthrough

This PR removes explicit module exports from the attention package's __init__.py, relying instead on implicit discovery via parent-level auto-import, and adds two Triton-based paged attention implementations (multi-stage and one-stage variants) with KV cache management, kernel-level optimizations, and unit tests for validation.

Changes

  • Attention Module API — tensorrt_llm/_torch/auto_deploy/custom_ops/attention/__init__.py: Removes explicit __all__ exports and transitions to implicit module discovery via the parent package's pkgutil.walk_packages, reducing explicit public API declarations.
  • Triton Paged Attention (Multi-stage) — tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py: Implements multi-stage Triton-based paged attention with HND cache layout, including a KV cache update kernel, two-stage FlashDecode (stage1/stage2), a context/prefill path with causal masking, metadata preparation, and descriptor registration for integration.
  • Triton Paged Attention (One-stage) — tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention_one_stage.py: Implements single-stage Triton-based paged attention with HND cache layout, featuring a combined KV cache update, a single online-softmax kernel for decode, a context/prefill path, metadata preparation, and an AttentionDescriptor for registry integration.
  • Unit Tests — tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/attention/test_triton_paged_attention.py: Comprehensive test suite validating Triton paged decode/context kernels and cache updates, with comparisons against PyTorch reference and FlashInfer implementations across varied configurations.

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant MHAOp as MHA with Cache Op
    participant MetaPrep as Metadata Prep
    participant KVCache as KV Cache
    participant UpdateKernel as KV Update Kernel
    participant PrefillKernel as Prefill Kernel
    participant DecodeKernel as Decode Kernel

    Client->>MHAOp: q, k, v, cache_config
    MHAOp->>MetaPrep: position_ids, batch_info, cu_seqlen
    MetaPrep->>MetaPrep: Compute batch_indices, positions
    MetaPrep-->>MHAOp: metadata (batch_indices, positions)
    MHAOp->>UpdateKernel: k, v, batch_indices, positions, page_table
    UpdateKernel->>KVCache: Write k, v to paged blocks
    KVCache-->>UpdateKernel: Cache updated
    alt Has prefill tokens
        MHAOp->>PrefillKernel: q, kv_cache, page_indices, causal_mask
        PrefillKernel->>PrefillKernel: Online softmax over KV pages
        PrefillKernel-->>MHAOp: Prefill output
    end
    alt Has decode tokens
        MHAOp->>DecodeKernel: q, kv_cache, page_indices
        DecodeKernel->>DecodeKernel: Attention computation (multi-stage or single-stage)
        DecodeKernel-->>MHAOp: Decode output
    end
    MHAOp-->>Client: Combined attention output
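
The "online softmax" step in the diagram refers to the streaming rescaling trick that lets attention be computed in one pass over KV blocks without materializing all logits at once. A scalar sketch of the idea (illustrative only, with head_dim = 1; not the kernel code):

```python
import math

def online_softmax_attention(q, keys, values, scale):
    """Single-query attention in one pass over K/V using online-softmax rescaling."""
    m = float("-inf")  # running max of logits seen so far
    d = 0.0            # running softmax denominator
    acc = 0.0          # running numerator (weighted sum of values)
    for k, v in zip(keys, values):
        s = q * k * scale              # attention logit for this key
        m_new = max(m, s)
        alpha = math.exp(m - m_new)    # rescale earlier partial sums
        p = math.exp(s - m_new)
        d = d * alpha + p
        acc = acc * alpha + p * v
        m = m_new
    return acc / d
```

Each incoming block only requires rescaling the running partial sums by exp(m - m_new), which is what makes splitting the KV range across stages (FlashDecoding) possible: partial (m, d, acc) triples from different splits can be merged the same way.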

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
  • Description check — ⚠️ Warning: The PR description lacks the required sections from the template (Description, Test Coverage, and PR Checklist are incomplete or missing), though it provides background on the feature and includes performance data. Resolution: complete the Description, Test Coverage, and PR Checklist sections to match the template requirements.
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 70.83%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (1 passed)
  • Title check — ✅ Passed: The title clearly and specifically references the main change (adding Triton paged attention to AutoDeploy), matching the core implementation work in the PR.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention_one_stage.py`:
- Around line 662-663: The custom op declaration for
triton_paged_one_stage_mha_with_cache incorrectly specifies mutates_args=() even
though the op updates the kv cache in-place; update the decorator to include
"kv_cache" in the mutates_args tuple so PyTorch knows the op mutates that
argument (mirror the fix used in the two-stage variant and the in-place update
performed by update_paged_kv_cache).
- Around line 192-199: The autotune key list in the triton.autotune decorator
includes "MAX_SEQ_LEN", which causes frequent recompiles because max_seq_len
varies per batch (computed from kv_indices.shape[0] * page_size); remove
"MAX_SEQ_LEN" from the key array in the `@triton.autotune` call and keep only
stable keys (e.g., "HEAD_DIM" and "PAGE_SIZE" or add "HEAD_RATIO" like the
two-stage variant), then pass max_seq_len as a regular constexpr/kernel argument
(not as a tuning key) so the kernel receives it at launch but Triton does not
retune on every batch; update any call sites and kernel signature that
referenced MAX_SEQ_LEN as an autotune key to accept it as a runtime/constexpr
parameter instead.

In `tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 771-772: The custom op declaration incorrectly claims no inputs
are mutated; update the decorators to declare that kv_cache is mutated so Torch
can preserve ordering and avoid CSE issues: change
`@torch.library.custom_op`("auto_deploy::triton_paged_mha_with_cache",
mutates_args=()) to include the kv_cache argument (e.g.
mutates_args=("kv_cache",) or the appropriate positional index for kv_cache) and
apply the same change to flashinfer_mha_with_cache,
torch_backend_mha_with_cache, and triton_paged_one_stage_mha_with_cache so any
op that calls update_paged_kv_cache(...) explicitly reports the in-place
mutation.

In `tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/attention/test_triton_paged_attention.py`:
- Around line 62-66: The loop in create_page_table incorrectly computes
remaining capacity using sum(kv_indptr[: i + 1].tolist()) which double-counts
cumulative entries; replace that expression with the current cumulative value
int(kv_indptr[i].item()) so remaining pages = num_blocks -
int(kv_indptr[i].item()), and keep the rest of the logic (num_pages =
min(max_pages_per_seq, ...), pages range construction, updating kv_indptr[i +
1]) unchanged; update references to kv_indptr usage in the loop to use
int(kv_indptr[i].item()) for clarity.
- Around line 356-466: The tests import flashinfer inside
test_decode_vs_flashinfer and misuse pytest.importorskip in
test_prefill_vs_flashinfer; add a top-level availability flag (try: import
flashinfer; _HAS_FLASHINFER = True except ImportError: _HAS_FLASHINFER = False)
near the module imports, remove the inline import from
test_decode_vs_flashinfer, and annotate both test_decode_vs_flashinfer and
test_prefill_vs_flashinfer with `@pytest.mark.skipif`(not _HAS_FLASHINFER,
reason="FlashInfer not installed") so both tests are skipped when FlashInfer is
absent (referencing the test function names test_decode_vs_flashinfer and
test_prefill_vs_flashinfer and the _HAS_FLASHINFER symbol).
🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (1)

163-166: _get_num_splits hardcodes device 0 for SM count.

torch.cuda.get_device_properties(0) always queries GPU 0. On multi-GPU setups the model may run on a different device, yielding the wrong multi_processor_count and a suboptimal split count. Consider accepting the device (or a tensor to infer it from) so the correct GPU is queried.

Proposed fix
-def _get_num_splits(max_seq_len: int, batch_size: int, n_kv_heads: int, page_size: int) -> int:
+def _get_num_splits(
+    max_seq_len: int, batch_size: int, n_kv_heads: int, page_size: int, device: torch.device = torch.device("cuda", 0)
+) -> int:
     ...
-    num_sms = torch.cuda.get_device_properties(0).multi_processor_count
+    num_sms = torch.cuda.get_device_properties(device).multi_processor_count

And at the call site in triton_paged_decode:

-    num_splits = _get_num_splits(max_seq_len, batch_size, n_kv_heads, page_size)
+    num_splits = _get_num_splits(max_seq_len, batch_size, n_kv_heads, page_size, device=q.device)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention_one_stage.py (1)

49-61: Significant code duplication with triton_paged_attention.py.

KV_LAYOUT, _get_sm_scale, _update_paged_kv_cache_kernel, update_paged_kv_cache, the metadata preparation ops, and most of the descriptor boilerplate are copied verbatim from the two-stage file. Consider extracting these shared utilities into a common module (e.g., _triton_paged_common.py) to avoid maintaining identical code in two places.

Given the one-stage variant is described as potentially being removed later, this can be deferred — but it's worth tracking to avoid future divergence.

Also applies to: 64-182

Comment on lines +192 to +199
@triton.autotune(
configs=[
triton.Config({"SEQ_BLOCK_SIZE": 32}, num_warps=4, num_stages=2),
triton.Config({"SEQ_BLOCK_SIZE": 64}, num_warps=4, num_stages=3),
triton.Config({"SEQ_BLOCK_SIZE": 128}, num_warps=8, num_stages=3),
triton.Config({"SEQ_BLOCK_SIZE": 256}, num_warps=8, num_stages=4),
],
key=["HEAD_DIM", "MAX_SEQ_LEN", "PAGE_SIZE"],

⚠️ Potential issue | 🟠 Major

MAX_SEQ_LEN as an autotune key will trigger frequent recompilation.

max_seq_len is computed as kv_indices.shape[0] * page_size, where kv_indices is the flat list of all page indices across all sequences. This total changes with every batch composition, causing Triton to retune the kernel repeatedly. The two-stage variant avoids this by using only HEAD_DIM, PAGE_SIZE, and HEAD_RATIO as autotune keys.

Consider removing MAX_SEQ_LEN from the key list and passing it only as a regular constexpr.

Proposed fix
 @triton.autotune(
     configs=[
         triton.Config({"SEQ_BLOCK_SIZE": 32}, num_warps=4, num_stages=2),
         triton.Config({"SEQ_BLOCK_SIZE": 64}, num_warps=4, num_stages=3),
         triton.Config({"SEQ_BLOCK_SIZE": 128}, num_warps=8, num_stages=3),
         triton.Config({"SEQ_BLOCK_SIZE": 256}, num_warps=8, num_stages=4),
     ],
-    key=["HEAD_DIM", "MAX_SEQ_LEN", "PAGE_SIZE"],
+    key=["HEAD_DIM", "PAGE_SIZE"],
 )
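
To illustrate why the choice of keys matters: Triton caches the best config per unique tuple of key values and re-benchmarks on a cache miss. A toy model of that caching behavior (illustrative only, not Triton's actual implementation):

```python
# Sketch of key-based autotune caching: the tuner re-benchmarks whenever
# the tuple of key values changes, so a per-batch key like MAX_SEQ_LEN
# forces a retune on nearly every launch.
class Autotuner:
    def __init__(self, configs, key_names):
        self.configs = configs
        self.key_names = key_names
        self.cache = {}       # key tuple -> best config
        self.tune_count = 0   # number of (expensive) benchmark passes

    def best_config(self, **kwargs):
        key = tuple(kwargs[k] for k in self.key_names)
        if key not in self.cache:
            self.tune_count += 1          # benchmarking would happen here
            self.cache[key] = self.configs[0]
        return self.cache[key]

tuner = Autotuner(configs=["SEQ_BLOCK_SIZE=32"], key_names=["HEAD_DIM", "MAX_SEQ_LEN"])
for max_seq_len in (1024, 2048, 4096):    # changes with every batch composition
    tuner.best_config(HEAD_DIM=128, MAX_SEQ_LEN=max_seq_len)
# tuner.tune_count == 3: one retune per distinct MAX_SEQ_LEN value
```

Dropping MAX_SEQ_LEN from key_names collapses all of those launches onto a single cache entry, which is the behavior the two-stage variant already has.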

Comment on lines +662 to +663
@torch.library.custom_op("auto_deploy::triton_paged_one_stage_mha_with_cache", mutates_args=())
def triton_paged_one_stage_mha_with_cache(

⚠️ Potential issue | 🔴 Critical

Same critical issue: mutates_args=() should include "kv_cache".

Identical to the two-stage variant — this op mutates kv_cache in-place via update_paged_kv_cache.

Proposed fix
-@torch.library.custom_op("auto_deploy::triton_paged_one_stage_mha_with_cache", mutates_args=())
+@torch.library.custom_op("auto_deploy::triton_paged_one_stage_mha_with_cache", mutates_args=("kv_cache",))

Comment on lines +771 to +772
@torch.library.custom_op("auto_deploy::triton_paged_mha_with_cache", mutates_args=())
def triton_paged_mha_with_cache(

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, let's locate and examine the file in question
find . -type f -name "triton_paged_attention.py" | head -5

Repository: NVIDIA/TensorRT-LLM

Length of output: 143


🏁 Script executed:

#!/bin/bash
# Read the specific lines mentioned in the review (771-772 and surrounding context)
fd triton_paged_attention.py -x cat -n {} | sed -n '765,790p'



🏁 Script executed:

#!/bin/bash
# Search for mutates_args declarations across all attention implementations
rg -n "mutates_args" --type=py -B2 -A2 -g "**/attention/**"



🏁 Script executed:

#!/bin/bash
# Check what triton_paged_mha_with_cache returns
fd triton_paged_attention.py -x sed -n '771,850p' | head -100



🏁 Script executed:

#!/bin/bash
# Check if triton_paged_one_stage_mha_with_cache (similar function) also mutates cache
fd triton_paged_attention_one_stage.py -x sed -n '662,750p'



🏁 Script executed:

#!/bin/bash
# Look for update_paged_kv_cache in the file to understand the mutation
rg -n "update_paged_kv_cache" --type=py -B2 -A5 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py



🏁 Script executed:

#!/bin/bash
# Check flashinfer_mha_with_cache to see if it also mutates cache
rg -n "def flashinfer_mha_with_cache" -A50 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | head -60



🏁 Script executed:

#!/bin/bash
# Check torch_backend_mha_with_cache to see if it also mutates cache
rg -n "def torch_backend_mha_with_cache" -A50 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py | head -60



🏁 Script executed:

#!/bin/bash
# Check flattened_mha_with_cache in triton_attention.py
rg -n "def flattened_mha_with_cache" -A50 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | head -60



🏁 Script executed:

#!/bin/bash
# Search for register_fake implementations for these mha_with_cache ops
rg -n "register_fake\|_fake" --type=py -B5 -A10 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py | grep -A15 "mha_with_cache"



🏁 Script executed:

#!/bin/bash
# Check if there are register_fake for flashinfer and torch_backend
rg -n "@torch.library.register_fake|_mha_with_cache_fake" --type=py tensorrt_llm/_torch/auto_deploy/custom_ops/attention/ | head -20



🏁 Script executed:

#!/bin/bash
# Let's check if flashinfer.page.append_paged_kv_cache actually mutates its kv_cache argument
rg -n "append_paged_kv_cache" -B2 -A8 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | head -30



🏁 Script executed:

#!/bin/bash
# Check what torch_backend_mha_with_cache does with k_cache and v_cache
rg -n "def torch_backend_mha_with_cache" -A100 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py | head -120



🏁 Script executed:

#!/bin/bash
# Check if _torch_generate_mha and _torch_context_mha mutate their cache arguments
rg -n "_torch_generate_mha|_torch_context_mha" --type=py tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py



🏁 Script executed:

#!/bin/bash
# Let's check the PyTorch documentation pattern by looking at how mutates_args is used elsewhere in the codebase
rg -n "mutates_args=\(" --type=py | grep -v "mutates_args=()" | head -20



🏁 Script executed:

#!/bin/bash
# Final verification: Check if torch_backend and flattened also have cache mutation issues
rg -n "def _torch_generate_mha" -A30 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py | head -40



🏁 Script executed:

#!/bin/bash
# Verify triton_paged_one_stage also has the same issue
rg -n "@torch.library.custom_op.*triton_paged_one_stage_mha" -A2 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention_one_stage.py



mutates_args=() is incorrect — kv_cache is mutated in-place.

triton_paged_mha_with_cache calls update_paged_kv_cache(...) which writes to kv_cache via a Triton tl.store kernel. Declaring mutates_args=() tells torch.compile and FX tracing that no inputs are modified, which can lead to stale reads, incorrect CSE, or reordered operations when this op is compiled. This same issue affects other attention backends in this directory that call cache mutation functions (flashinfer_mha_with_cache, torch_backend_mha_with_cache, triton_paged_one_stage_mha_with_cache).

Proposed fix
-@torch.library.custom_op("auto_deploy::triton_paged_mha_with_cache", mutates_args=())
+@torch.library.custom_op("auto_deploy::triton_paged_mha_with_cache", mutates_args=("kv_cache",))

Comment on lines +62 to +66
for i in range(batch_size):
num_pages = min(max_pages_per_seq, num_blocks - sum(kv_indptr[: i + 1].tolist()))
pages = list(range(int(kv_indptr[i].item()), int(kv_indptr[i].item()) + num_pages))
all_indices.extend(pages)
kv_indptr[i + 1] = kv_indptr[i] + num_pages

⚠️ Potential issue | 🟡 Minor

Bug in create_page_table: sum(kv_indptr[:i+1].tolist()) ≠ remaining pages.

Line 63 computes remaining capacity as num_blocks - sum(kv_indptr[:i+1].tolist()). Since kv_indptr holds cumulative page counts, summing all entries double-counts previous sequences (e.g., for i=2 it sums [0, p0, p0+p1] = 2*p0 + p1 instead of p0 + p1). The correct remaining count is num_blocks - int(kv_indptr[i].item()).

This utility is currently unused in the tests, so impact is nil, but it would silently produce wrong page tables if called with batch_size > 2.

Proposed fix
     for i in range(batch_size):
-        num_pages = min(max_pages_per_seq, num_blocks - sum(kv_indptr[: i + 1].tolist()))
+        num_pages = min(max_pages_per_seq, num_blocks - int(kv_indptr[i].item()))
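
The double-counting is easy to see with concrete numbers (a plain-Python restatement of the indexing, not the test's actual tensors):

```python
# kv_indptr holds cumulative page counts, so summing a prefix of it
# re-adds earlier sequences' pages.
kv_indptr = [0, 3, 5, 0]   # seq 0 used 3 pages, seq 1 used 5 - 3 = 2
num_blocks = 10
i = 2                      # allocating pages for the third sequence
buggy_remaining = num_blocks - sum(kv_indptr[: i + 1])  # 10 - (0 + 3 + 5) = 2
fixed_remaining = num_blocks - kv_indptr[i]             # 10 - 5 = 5
```

The buggy expression underestimates the remaining capacity (2 vs. the true 5), so with larger batch sizes it would silently shrink or empty later sequences' page lists.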

Comment on lines +356 to +466
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seq_len", [64, 128])
def test_decode_vs_flashinfer(self, batch_size: int, seq_len: int):
"""Compare decode output against FlashInfer."""
import flashinfer

from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import (
triton_paged_decode,
update_paged_kv_cache,
)

n_heads = 32
n_kv_heads = 8
head_dim = 128
page_size = 16

num_pages_per_seq = (seq_len + page_size - 1) // page_size
num_blocks = batch_size * num_pages_per_seq + 10

# Create shared K, V data
k = torch.randn(
    batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda"
)
v = torch.randn(
    batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda"
)

# Query for decode
q = torch.randn(batch_size, n_heads, head_dim, dtype=torch.float16, device="cuda")

# Page table metadata
kv_indptr = torch.arange(
    0,
    (batch_size + 1) * num_pages_per_seq,
    num_pages_per_seq,
    dtype=torch.int32,
    device="cuda",
)[: batch_size + 1]
kv_indices = torch.arange(
    0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda"
)
last_token_in_page = seq_len % page_size
kv_last_page_len = torch.full(
    (batch_size,),
    last_token_in_page if last_token_in_page > 0 else page_size,
    dtype=torch.int32,
    device="cuda",
)

sm_scale = 1.0 / math.sqrt(head_dim)

# ===== Triton =====
kv_cache_triton = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim)
k_flat = k.reshape(batch_size * seq_len, n_kv_heads, head_dim)
v_flat = v.reshape(batch_size * seq_len, n_kv_heads, head_dim)
batch_indices = torch.repeat_interleave(
    torch.arange(batch_size, device="cuda", dtype=torch.int32), seq_len
)
positions = torch.tile(
    torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,)
)
update_paged_kv_cache(
    k_flat, v_flat, batch_indices, positions, kv_cache_triton, kv_indices, kv_indptr
)
output_triton = triton_paged_decode(
    q, kv_cache_triton, kv_indices, kv_indptr, kv_last_page_len, sm_scale
)

# ===== FlashInfer =====
kv_cache_fi = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim)
# Use FlashInfer's cache append
fi_batch_indices = batch_indices.clone()
fi_positions = positions.clone()
flashinfer.page.append_paged_kv_cache(
    append_key=k_flat,
    append_value=v_flat,
    batch_indices=fi_batch_indices,
    positions=fi_positions,
    paged_kv_cache=kv_cache_fi,
    kv_indices=kv_indices,
    kv_indptr=kv_indptr,
    kv_last_page_len=kv_last_page_len,
    kv_layout="HND",
)

# Use FlashInfer decode
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace, "HND", use_tensor_cores=True
)
wrapper.plan(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    n_heads,
    n_kv_heads,
    head_dim,
    page_size,
    q_data_type=q.dtype,
    kv_data_type=kv_cache_fi.dtype,
    sm_scale=sm_scale,
)
output_fi = wrapper.run(q, kv_cache_fi)

# Compare
torch.testing.assert_close(output_triton.float(), output_fi.float(), rtol=1e-2, atol=1e-2)
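The page-table arithmetic used above (ceil-divide for the per-sequence page count, remainder for the fill of the last page, with a full last page mapping to `page_size` rather than 0) can be isolated in plain Python; `page_metadata` is a hypothetical helper name for illustration, not part of this PR:

```python
import math

def page_metadata(seq_len: int, page_size: int):
    """Return (num_pages, last_page_len) for one sequence in a paged KV cache."""
    num_pages = math.ceil(seq_len / page_size)  # same as (seq_len + page_size - 1) // page_size
    rem = seq_len % page_size
    last_page_len = rem if rem > 0 else page_size  # a fully filled last page counts as page_size
    return num_pages, last_page_len

print(page_metadata(128, 16))  # → (8, 16): 8 full pages
print(page_metadata(100, 16))  # → (7, 4): 6 full pages + 4 tokens in page 7
```

This mirrors the `kv_last_page_len` computation in the test, where `seq_len % page_size == 0` must yield `page_size`, matching FlashInfer's convention that every mapped page holds at least one token.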

@pytest.mark.skipif(
    not pytest.importorskip("flashinfer", reason="FlashInfer not installed"),
    reason="FlashInfer not installed",
)
Contributor

⚠️ Potential issue | 🟠 Major

test_decode_vs_flashinfer will crash when FlashInfer is not installed; test_prefill_vs_flashinfer skipif pattern is broken.

test_decode_vs_flashinfer has no skip logic and will raise ImportError if flashinfer is unavailable. test_prefill_vs_flashinfer wraps pytest.importorskip(...) inside pytest.mark.skipif(), but importorskip raises pytest.skip.Exception at collection time rather than returning a boolean — so the decorator either disrupts collection or is a no-op when the package is present.

Use a top-level availability flag for both tests:

Proposed fix

At the top of the file (after existing imports):

try:
    import flashinfer  # noqa: F401
    _HAS_FLASHINFER = True
except ImportError:
    _HAS_FLASHINFER = False

Then apply to both test methods:

 class TestFlashInferComparison:
     """Tests comparing Triton implementation against FlashInfer."""

+    @pytest.mark.skipif(not _HAS_FLASHINFER, reason="FlashInfer not installed")
     @pytest.mark.parametrize("batch_size", [1, 4])
     @pytest.mark.parametrize("seq_len", [64, 128])
     def test_decode_vs_flashinfer(self, batch_size: int, seq_len: int):
         ...

-    @pytest.mark.skipif(
-        not pytest.importorskip("flashinfer", reason="FlashInfer not installed"),
-        reason="FlashInfer not installed",
-    )
+    @pytest.mark.skipif(not _HAS_FLASHINFER, reason="FlashInfer not installed")
     def test_prefill_vs_flashinfer(self):
         ...
🤖 Prompt for AI Agents
In
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/attention/test_triton_paged_attention.py
around lines 356-466: the tests import flashinfer inside
test_decode_vs_flashinfer and misuse pytest.importorskip in
test_prefill_vs_flashinfer. Add a top-level availability flag (try: import
flashinfer; _HAS_FLASHINFER = True; except ImportError: _HAS_FLASHINFER =
False) near the module imports, remove the inline import from
test_decode_vs_flashinfer, and annotate both test_decode_vs_flashinfer and
test_prefill_vs_flashinfer with @pytest.mark.skipif(not _HAS_FLASHINFER,
reason="FlashInfer not installed") so both tests are skipped when FlashInfer is
absent.



def update_paged_kv_cache(
    k: torch.Tensor,
Collaborator

Here and elsewhere: if the kernel assumes its input tensors are contiguous, we should add asserts. Typically, when these kernels are integrated into the model, a violated contiguity requirement can lead to subtle numerical issues and/or illegal memory accesses (IMAs).

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz nvchenghaoz marked this pull request as draft February 7, 2026 00:08
@nvchenghaoz nvchenghaoz force-pushed the chenghao/triton-paged-attention branch from aef9446 to 5ab4062 Compare February 10, 2026 18:29