[#11146][feat] AutoDeploy: Add triton paged attention #11355
nvchenghaoz wants to merge 2 commits into NVIDIA:main
Conversation
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
📝 Walkthrough

This PR removes explicit module exports from the attention package.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant MHAOp as MHA with Cache Op
    participant MetaPrep as Metadata Prep
    participant KVCache as KV Cache
    participant UpdateKernel as KV Update Kernel
    participant PrefillKernel as Prefill Kernel
    participant DecodeKernel as Decode Kernel
    Client->>MHAOp: q, k, v, cache_config
    MHAOp->>MetaPrep: position_ids, batch_info, cu_seqlen
    MetaPrep->>MetaPrep: Compute batch_indices, positions
    MetaPrep-->>MHAOp: metadata (batch_indices, positions)
    MHAOp->>UpdateKernel: k, v, batch_indices, positions, page_table
    UpdateKernel->>KVCache: Write k, v to paged blocks
    KVCache-->>UpdateKernel: Cache updated
    alt Has prefill tokens
        MHAOp->>PrefillKernel: q, kv_cache, page_indices, causal_mask
        PrefillKernel->>PrefillKernel: Online softmax over KV pages
        PrefillKernel-->>MHAOp: Prefill output
    end
    alt Has decode tokens
        MHAOp->>DecodeKernel: q, kv_cache, page_indices
        DecodeKernel->>DecodeKernel: Attention computation (multi-stage or single-stage)
        DecodeKernel-->>MHAOp: Decode output
    end
    MHAOp-->>Client: Combined attention output
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention_one_stage.py`:
- Around line 662-663: The custom op declaration for
triton_paged_one_stage_mha_with_cache incorrectly specifies mutates_args=() even
though the op updates the kv cache in-place; update the decorator to include
"kv_cache" in the mutates_args tuple so PyTorch knows the op mutates that
argument (mirror the fix used in the two-stage variant and the in-place update
performed by update_paged_kv_cache).
- Around line 192-199: The autotune key list in the triton.autotune decorator
includes "MAX_SEQ_LEN", which causes frequent recompiles because max_seq_len
varies per batch (computed from kv_indices.shape[0] * page_size); remove
"MAX_SEQ_LEN" from the key array in the `@triton.autotune` call and keep only
stable keys (e.g., "HEAD_DIM" and "PAGE_SIZE" or add "HEAD_RATIO" like the
two-stage variant), then pass max_seq_len as a regular constexpr/kernel argument
(not as a tuning key) so the kernel receives it at launch but Triton does not
retune on every batch; update any call sites and kernel signature that
referenced MAX_SEQ_LEN as an autotune key to accept it as a runtime/constexpr
parameter instead.
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 771-772: The custom op declaration incorrectly claims no inputs
are mutated; update the decorators to declare that kv_cache is mutated so Torch
can preserve ordering and avoid CSE issues: change
`@torch.library.custom_op`("auto_deploy::triton_paged_mha_with_cache",
mutates_args=()) to include the kv_cache argument (e.g.
mutates_args=("kv_cache",) or the appropriate positional index for kv_cache) and
apply the same change to flashinfer_mha_with_cache,
torch_backend_mha_with_cache, and triton_paged_one_stage_mha_with_cache so any
op that calls update_paged_kv_cache(...) explicitly reports the in-place
mutation.
In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/attention/test_triton_paged_attention.py`:
- Around line 62-66: The loop in create_page_table incorrectly computes
remaining capacity using sum(kv_indptr[: i + 1].tolist()) which double-counts
cumulative entries; replace that expression with the current cumulative value
int(kv_indptr[i].item()) so remaining pages = num_blocks -
int(kv_indptr[i].item()), and keep the rest of the logic (num_pages =
min(max_pages_per_seq, ...), pages range construction, updating kv_indptr[i +
1]) unchanged; update references to kv_indptr usage in the loop to use
int(kv_indptr[i].item()) for clarity.
- Around line 356-466: The tests import flashinfer inside
test_decode_vs_flashinfer and misuse pytest.importorskip in
test_prefill_vs_flashinfer; add a top-level availability flag (try: import
flashinfer; _HAS_FLASHINFER = True except ImportError: _HAS_FLASHINFER = False)
near the module imports, remove the inline import from
test_decode_vs_flashinfer, and annotate both test_decode_vs_flashinfer and
test_prefill_vs_flashinfer with `@pytest.mark.skipif`(not _HAS_FLASHINFER,
reason="FlashInfer not installed") so both tests are skipped when FlashInfer is
absent (referencing the test function names test_decode_vs_flashinfer and
test_prefill_vs_flashinfer and the _HAS_FLASHINFER symbol).
🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (1)

163-166: `_get_num_splits` hardcodes device 0 for SM count.

`torch.cuda.get_device_properties(0)` always queries GPU 0. On multi-GPU setups the model may run on a different device, yielding the wrong `multi_processor_count` and a suboptimal split count. Consider accepting the device (or a tensor to infer it from) so the correct GPU is queried.

Proposed fix:

```diff
-def _get_num_splits(max_seq_len: int, batch_size: int, n_kv_heads: int, page_size: int) -> int:
+def _get_num_splits(
+    max_seq_len: int, batch_size: int, n_kv_heads: int, page_size: int,
+    device: torch.device = torch.device("cuda", 0),
+) -> int:
     ...
-    num_sms = torch.cuda.get_device_properties(0).multi_processor_count
+    num_sms = torch.cuda.get_device_properties(device).multi_processor_count
```

And at the call site in `triton_paged_decode`:

```diff
-    num_splits = _get_num_splits(max_seq_len, batch_size, n_kv_heads, page_size)
+    num_splits = _get_num_splits(max_seq_len, batch_size, n_kv_heads, page_size, device=q.device)
```

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention_one_stage.py (1)

49-61: Significant code duplication with `triton_paged_attention.py`.

`KV_LAYOUT`, `_get_sm_scale`, `_update_paged_kv_cache_kernel`, `update_paged_kv_cache`, the metadata preparation ops, and most of the descriptor boilerplate are copied verbatim from the two-stage file. Consider extracting these shared utilities into a common module (e.g., `_triton_paged_common.py`) to avoid maintaining identical code in two places. Given the one-stage variant is described as potentially being removed later, this can be deferred, but it's worth tracking to avoid future divergence.

Also applies to: 64-182
```python
@triton.autotune(
    configs=[
        triton.Config({"SEQ_BLOCK_SIZE": 32}, num_warps=4, num_stages=2),
        triton.Config({"SEQ_BLOCK_SIZE": 64}, num_warps=4, num_stages=3),
        triton.Config({"SEQ_BLOCK_SIZE": 128}, num_warps=8, num_stages=3),
        triton.Config({"SEQ_BLOCK_SIZE": 256}, num_warps=8, num_stages=4),
    ],
    key=["HEAD_DIM", "MAX_SEQ_LEN", "PAGE_SIZE"],
```
MAX_SEQ_LEN as an autotune key will trigger frequent recompilation.
max_seq_len is computed as kv_indices.shape[0] * page_size, where kv_indices is the flat list of all page indices across all sequences. This total changes with every batch composition, causing Triton to retune the kernel repeatedly. The two-stage variant avoids this by using only HEAD_DIM, PAGE_SIZE, and HEAD_RATIO as autotune keys.
Consider removing MAX_SEQ_LEN from the key list and passing it only as a regular constexpr.
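The recompilation cost is easy to see with a toy model of an autotune cache that, like Triton's, keeps one compiled kernel per distinct tuple of key values (a hypothetical sketch, not Triton's actual implementation):

```python
# Hypothetical sketch of an autotune cache keyed on tuning-key values.
compile_count = 0
cache = {}

def launch(head_dim, page_size, max_seq_len, key_fields):
    global compile_count
    values = {"HEAD_DIM": head_dim, "PAGE_SIZE": page_size, "MAX_SEQ_LEN": max_seq_len}
    key = tuple(values[f] for f in key_fields)
    if key not in cache:
        compile_count += 1  # a cache miss triggers (re)tuning
        cache[key] = f"kernel{key}"
    return cache[key]

# With MAX_SEQ_LEN in the key, every new batch shape retunes.
for max_seq_len in (512, 768, 1024, 768, 2048):
    launch(128, 16, max_seq_len, ["HEAD_DIM", "MAX_SEQ_LEN", "PAGE_SIZE"])
print(compile_count)  # 4 distinct max_seq_len values -> 4 compiles

# With stable keys only, max_seq_len is just a plain argument: one compile.
compile_count, cache = 0, {}
for max_seq_len in (512, 768, 1024, 768, 2048):
    launch(128, 16, max_seq_len, ["HEAD_DIM", "PAGE_SIZE"])
print(compile_count)  # 1
```

The same five launches go from four tuning passes to one once the volatile value is dropped from the key.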
Proposed fix:

```diff
 @triton.autotune(
     configs=[
         triton.Config({"SEQ_BLOCK_SIZE": 32}, num_warps=4, num_stages=2),
         triton.Config({"SEQ_BLOCK_SIZE": 64}, num_warps=4, num_stages=3),
         triton.Config({"SEQ_BLOCK_SIZE": 128}, num_warps=8, num_stages=3),
         triton.Config({"SEQ_BLOCK_SIZE": 256}, num_warps=8, num_stages=4),
     ],
-    key=["HEAD_DIM", "MAX_SEQ_LEN", "PAGE_SIZE"],
+    key=["HEAD_DIM", "PAGE_SIZE"],
 )
```
```python
@torch.library.custom_op("auto_deploy::triton_paged_one_stage_mha_with_cache", mutates_args=())
def triton_paged_one_stage_mha_with_cache(
```
Same critical issue: mutates_args=() should include "kv_cache".
Identical to the two-stage variant — this op mutates kv_cache in-place via update_paged_kv_cache.
Proposed fix:

```diff
-@torch.library.custom_op("auto_deploy::triton_paged_one_stage_mha_with_cache", mutates_args=())
+@torch.library.custom_op("auto_deploy::triton_paged_one_stage_mha_with_cache", mutates_args=("kv_cache",))
```
```python
@torch.library.custom_op("auto_deploy::triton_paged_mha_with_cache", mutates_args=())
def triton_paged_mha_with_cache(
```
🧩 Analysis chain — verification scripts executed against the repository (output lengths omitted):

```shell
# Locate and examine the file in question
find . -type f -name "triton_paged_attention.py" | head -5

# Read the lines flagged in the review (771-772) with context
fd triton_paged_attention.py -x cat -n {} | sed -n '765,790p'

# Search for mutates_args declarations across all attention implementations
rg -n "mutates_args" --type=py -B2 -A2 -g "**/attention/**"

# Check what triton_paged_mha_with_cache returns
fd triton_paged_attention.py -x sed -n '771,850p' | head -100

# Check whether triton_paged_one_stage_mha_with_cache also mutates the cache
fd triton_paged_attention_one_stage.py -x sed -n '662,750p'

# Look for update_paged_kv_cache to understand the mutation
rg -n "update_paged_kv_cache" --type=py -B2 -A5 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py

# Check flashinfer_mha_with_cache for cache mutation
rg -n "def flashinfer_mha_with_cache" -A50 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | head -60

# Check torch_backend_mha_with_cache for cache mutation
rg -n "def torch_backend_mha_with_cache" -A100 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py | head -120

# Check flattened_mha_with_cache in triton_attention.py
rg -n "def flattened_mha_with_cache" -A50 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | head -60

# Search for register_fake implementations for these mha_with_cache ops
rg -n "@torch.library.register_fake|_mha_with_cache_fake" --type=py tensorrt_llm/_torch/auto_deploy/custom_ops/attention/ | head -20

# Check whether flashinfer.page.append_paged_kv_cache mutates its kv_cache argument
rg -n "append_paged_kv_cache" -B2 -A8 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py | head -30

# Check whether _torch_generate_mha and _torch_context_mha mutate their cache arguments
rg -n "_torch_generate_mha|_torch_context_mha" --type=py tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
rg -n "def _torch_generate_mha" -A30 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py | head -40

# See how mutates_args is used elsewhere in the codebase
rg -n "mutates_args=\(" --type=py | grep -v "mutates_args=()" | head -20

# Verify the one-stage variant has the same issue
rg -n "@torch.library.custom_op.*triton_paged_one_stage_mha" -A2 tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention_one_stage.py
```
mutates_args=() is incorrect — kv_cache is mutated in-place.
triton_paged_mha_with_cache calls update_paged_kv_cache(...) which writes to kv_cache via a Triton tl.store kernel. Declaring mutates_args=() tells torch.compile and FX tracing that no inputs are modified, which can lead to stale reads, incorrect CSE, or reordered operations when this op is compiled. This same issue affects other attention backends in this directory that call cache mutation functions (flashinfer_mha_with_cache, torch_backend_mha_with_cache, triton_paged_one_stage_mha_with_cache).
Proposed fix:

```diff
-@torch.library.custom_op("auto_deploy::triton_paged_mha_with_cache", mutates_args=())
+@torch.library.custom_op("auto_deploy::triton_paged_mha_with_cache", mutates_args=("kv_cache",))
```
```python
for i in range(batch_size):
    num_pages = min(max_pages_per_seq, num_blocks - sum(kv_indptr[: i + 1].tolist()))
    pages = list(range(int(kv_indptr[i].item()), int(kv_indptr[i].item()) + num_pages))
    all_indices.extend(pages)
    kv_indptr[i + 1] = kv_indptr[i] + num_pages
```
Bug in create_page_table: sum(kv_indptr[:i+1].tolist()) ≠ remaining pages.
Line 63 computes remaining capacity as num_blocks - sum(kv_indptr[:i+1].tolist()). Since kv_indptr holds cumulative page counts, summing all entries double-counts previous sequences (e.g., for i=2 it sums [0, p0, p0+p1] = 2*p0 + p1 instead of p0 + p1). The correct remaining count is num_blocks - int(kv_indptr[i].item()).
This utility is currently unused in the tests, so impact is nil, but it would silently produce wrong page tables if called with batch_size > 2.
Proposed fix:

```diff
 for i in range(batch_size):
-    num_pages = min(max_pages_per_seq, num_blocks - sum(kv_indptr[: i + 1].tolist()))
+    num_pages = min(max_pages_per_seq, num_blocks - int(kv_indptr[i].item()))
```
In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/attention/test_triton_paged_attention.py`
around lines 62 - 66, The loop in create_page_table incorrectly computes
remaining capacity using sum(kv_indptr[: i + 1].tolist()) which double-counts
cumulative entries; replace that expression with the current cumulative value
int(kv_indptr[i].item()) so remaining pages = num_blocks -
int(kv_indptr[i].item()), and keep the rest of the logic (num_pages =
min(max_pages_per_seq, ...), pages range construction, updating kv_indptr[i +
1]) unchanged; update references to kv_indptr usage in the loop to use
int(kv_indptr[i].item()) for clarity.
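The double-count is easy to reproduce with plain Python lists (illustrative numbers, not taken from the test file):

```python
# kv_indptr holds *cumulative* page counts: [0, p0, p0 + p1, ...].
# Summing a prefix of it therefore re-counts earlier sequences.
num_blocks = 10
kv_indptr = [0, 3, 5]  # seq0 used 3 pages, seq1 used 2 (cumulative)

i = 2  # planning pages for the third sequence
buggy_remaining = num_blocks - sum(kv_indptr[: i + 1])  # 10 - (0 + 3 + 5) = 2
correct_remaining = num_blocks - kv_indptr[i]           # 10 - 5 = 5

assert buggy_remaining == 2
assert correct_remaining == 5  # 3 free pages silently lost by the buggy formula
```

Since the entries are cumulative, `kv_indptr[i]` alone is already the total number of pages consumed by sequences `0..i-1`.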
```python
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seq_len", [64, 128])
def test_decode_vs_flashinfer(self, batch_size: int, seq_len: int):
    """Compare decode output against FlashInfer."""
    import flashinfer

    from tensorrt_llm._torch.auto_deploy.custom_ops.attention.triton_paged_attention import (
        triton_paged_decode,
        update_paged_kv_cache,
    )

    n_heads = 32
    n_kv_heads = 8
    head_dim = 128
    page_size = 16

    num_pages_per_seq = (seq_len + page_size - 1) // page_size
    num_blocks = batch_size * num_pages_per_seq + 10

    # Create shared K, V data
    k = torch.randn(
        batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda"
    )
    v = torch.randn(
        batch_size, seq_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda"
    )

    # Query for decode
    q = torch.randn(batch_size, n_heads, head_dim, dtype=torch.float16, device="cuda")

    # Page table metadata
    kv_indptr = torch.arange(
        0,
        (batch_size + 1) * num_pages_per_seq,
        num_pages_per_seq,
        dtype=torch.int32,
        device="cuda",
    )[: batch_size + 1]
    kv_indices = torch.arange(
        0, batch_size * num_pages_per_seq, dtype=torch.int32, device="cuda"
    )
    last_token_in_page = seq_len % page_size
    kv_last_page_len = torch.full(
        (batch_size,),
        last_token_in_page if last_token_in_page > 0 else page_size,
        dtype=torch.int32,
        device="cuda",
    )

    sm_scale = 1.0 / math.sqrt(head_dim)

    # ===== Triton =====
    kv_cache_triton = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim)
    k_flat = k.reshape(batch_size * seq_len, n_kv_heads, head_dim)
    v_flat = v.reshape(batch_size * seq_len, n_kv_heads, head_dim)
    batch_indices = torch.repeat_interleave(
        torch.arange(batch_size, device="cuda", dtype=torch.int32), seq_len
    )
    positions = torch.tile(
        torch.arange(seq_len, device="cuda", dtype=torch.int32), (batch_size,)
    )
    update_paged_kv_cache(
        k_flat, v_flat, batch_indices, positions, kv_cache_triton, kv_indices, kv_indptr
    )
    output_triton = triton_paged_decode(
        q, kv_cache_triton, kv_indices, kv_indptr, kv_last_page_len, sm_scale
    )

    # ===== FlashInfer =====
    kv_cache_fi = create_paged_kv_cache(num_blocks, page_size, n_kv_heads, head_dim)
    # Use FlashInfer's cache append
    fi_batch_indices = batch_indices.clone()
    fi_positions = positions.clone()
    flashinfer.page.append_paged_kv_cache(
        append_key=k_flat,
        append_value=v_flat,
        batch_indices=fi_batch_indices,
        positions=fi_positions,
        paged_kv_cache=kv_cache_fi,
        kv_indices=kv_indices,
        kv_indptr=kv_indptr,
        kv_last_page_len=kv_last_page_len,
        kv_layout="HND",
    )

    # Use FlashInfer decode
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace, "HND", use_tensor_cores=True
    )
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        n_heads,
        n_kv_heads,
        head_dim,
        page_size,
        q_data_type=q.dtype,
        kv_data_type=kv_cache_fi.dtype,
        sm_scale=sm_scale,
    )
    output_fi = wrapper.run(q, kv_cache_fi)

    # Compare
    torch.testing.assert_close(output_triton.float(), output_fi.float(), rtol=1e-2, atol=1e-2)

@pytest.mark.skipif(
    not pytest.importorskip("flashinfer", reason="FlashInfer not installed"),
    reason="FlashInfer not installed",
)
```
test_decode_vs_flashinfer will crash when FlashInfer is not installed; test_prefill_vs_flashinfer skipif pattern is broken.
test_decode_vs_flashinfer has no skip logic and will raise ImportError if flashinfer is unavailable. test_prefill_vs_flashinfer wraps pytest.importorskip(...) inside pytest.mark.skipif(), but importorskip raises pytest.skip.Exception at collection time rather than returning a boolean — so the decorator either disrupts collection or is a no-op when the package is present.
Use a top-level availability flag for both tests.

Proposed fix — at the top of the file (after existing imports):

```python
try:
    import flashinfer  # noqa: F401

    _HAS_FLASHINFER = True
except ImportError:
    _HAS_FLASHINFER = False
```

Then apply to both test methods:
```diff
 class TestFlashInferComparison:
     """Tests comparing Triton implementation against FlashInfer."""

+    @pytest.mark.skipif(not _HAS_FLASHINFER, reason="FlashInfer not installed")
     @pytest.mark.parametrize("batch_size", [1, 4])
     @pytest.mark.parametrize("seq_len", [64, 128])
     def test_decode_vs_flashinfer(self, batch_size: int, seq_len: int):
         ...

-    @pytest.mark.skipif(
-        not pytest.importorskip("flashinfer", reason="FlashInfer not installed"),
-        reason="FlashInfer not installed",
-    )
+    @pytest.mark.skipif(not _HAS_FLASHINFER, reason="FlashInfer not installed")
     def test_prefill_vs_flashinfer(self):
         ...
```
```python
def update_paged_kv_cache(
    k: torch.Tensor,
```
Here and elsewhere: if the kernels assume the input tensors are contiguous, we should add asserts. When these kernels are integrated into a model, a violated contiguity requirement can lead to subtle numerical issues and/or IMAs (illegal memory accesses).
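A lightweight guard along these lines (a sketch; the helper name and keyword-argument style are illustrative, not from the PR) makes the contiguity assumption explicit at the op boundary instead of failing deep inside the kernel:

```python
import torch

def _check_contiguous(**tensors: torch.Tensor) -> None:
    # Fail fast with the offending tensor's name rather than risking an
    # illegal memory access inside the Triton kernel.
    for name, t in tensors.items():
        assert t.is_contiguous(), (
            f"{name} must be contiguous, got strides {t.stride()}"
        )

k = torch.randn(4, 8, 16)
_check_contiguous(k=k)      # passes: freshly allocated tensors are contiguous

kt = k.transpose(0, 1)      # a view with permuted strides
assert not kt.is_contiguous()
# _check_contiguous(k=kt) would raise, naming `k` and its strides
```

Calling such a guard at the top of `update_paged_kv_cache` (and the attention entry points) turns a silent stride bug into an immediate, named assertion failure.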
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
aef9446 force-pushed to 5ab4062
This is part 1 of adding triton paged attention as the AutoDeploy attention backend. #11146
The changes in this PR:
The perf for Llama-3.1-8B-Instruct for triton_paged vs. flashinfer
Summary by CodeRabbit
New Features
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help

`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user-friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.

Details

`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages, and sanity check stages. Note: does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Note: does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Note: does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages that don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Example: "pytorch, cpp" does not run test stages with the tensorrt or triton backend. Note: does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline plus the specified test stages.
- `--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

`kill`: Kill all running builds associated with the pull request.

`skip --comment COMMENT`: Skip testing for the latest commit on the pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: this is dangerous, since lack of user care and validation can cause top of tree to break.

`reuse-pipeline`: Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: this is dangerous, since lack of user care and validation can cause top of tree to break.