Skip to content

Conversation

@grf53
Copy link
Contributor

@grf53 grf53 commented Jan 2, 2026

Description

This PR improves the sliding window, previously considered only in _attention_prefil() and _attention_decode()(CrossAttention), to also be applied in the _attention_prefill_ragged()(SelfAttention) function.
This allows us to calculate the attention value more properly for the requests with long prefill lengths.

Background

(Note that in TVM, CrossAttention refers to the attention of Q in a new seq for K and V already stored in the KV Cache, while SelfAttention refers to the self-attention for Q, K, and V within a new seq.)

In layers applying a sliding window in models with SWA (e.g., Mistral) or per-layer SWA (e.g., Gemma 3), the CrossAttention uses a function with sliding_window=True.

However, SelfAttention is not affected by applying sliding window or not in the current implementation. Since the attention of a new seq is the result of merging SelfAttention and CrossAttention, if the newly input sequence is longer than the sliding window size, the resulting attention cannot be considered appropriate.

In past SWA-based models, their sliding_window_sizes were often large enough, which may not have been a problem when each prefill input was not long. However, with the advent of per-layer SWA, models using shorter sliding_window_sizes (such as gpt-oss with 128) emerged, and with the prevalence of agentic usage, a single prefill can become very long, which can lead to problems in performance or stability.

Fixes

In those functions with sliding_window=True in CrossAttention, the length of the KV chunk to be referenced is adjusted with sliding_window_offset and sink_size additionally provided in length_info, thereby applying the sliding window.

However, _attention_prefill_ragged() in SelfAttention must apply a sliding window by applying an appropriate causal mask to the given tensors. Currently, _causal_mask() only uses the appropriate upper bound condition, forming a lower triangular matrix. To apply a sliding window, a lower bound condition must be added to form a diagonal band matrix.

AS-IS

condition = col <= row

TO-BE

condition = (col <= row) and (col >= row - window_size)

However, since the sliding_window_size value is not provided for the function, we needed to provide it additionally. (I confirmed that using the sliding_window_offset value directly is not appropriate.)

Notes

There were some hard-coded values for gemma 3 models added in #17928, and I removed them in this PR.
There is a mention of introducing optional parameters for those values, so the way I chose could be not match to the existing plan or intention.
Or there could be a more preferred way to add lower bound condition in causal mask, I think.

Please kindly let me know if there is a comment.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @grf53, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the KVCache's ability to handle sliding window attention for large language models, particularly when processing long prefill sequences. By extending sliding window logic to self-attention operations within ragged prefill, it ensures more accurate attention calculations and improves compatibility with models utilizing per-layer sliding window attention, addressing potential performance and stability issues with shorter window sizes.

Highlights

  • Sliding Window for Ragged Prefill: Extended sliding window attention to the _attention_prefill_ragged() function, which handles self-attention for new sequences, ensuring proper attention calculation for requests with long prefill lengths.
  • Enhanced Causal Masking: Modified the _causal_mask function to incorporate a lower bound condition based on sliding_window_size, transforming it into a diagonal band matrix for accurate sliding window application.
  • Dynamic Sliding Window Size: Introduced sliding_window_size as a parameter to relevant attention functions and classes, allowing for dynamic configuration rather than hardcoded values.
  • Gemma 3 Hardcode Removal: Eliminated specific hard-coded values related to Gemma 3 models, promoting a more generalized and flexible implementation.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request is a valuable contribution that extends sliding window attention support to ragged prefill operations, which is crucial for modern language models. The changes primarily involve plumbing the sliding_window_size parameter through the attention kernels and cleaning up hardcoded values. I've identified a critical bug related to layer indexing, a potential off-by-one error in the sliding window logic, and an opportunity to improve the robustness of handling the sliding window size. Addressing these points will help ensure the correctness and stability of the implementation.

Comment on lines +772 to +773
if (sliding_window_size_ == -1)
sliding_window_size_ = sliding_window_size;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic to set sliding_window_size_ is a bit fragile. If EnableSlidingWindowForSeq is called for different sequences with different sliding_window_size values, only the first one will take effect, and subsequent calls with different sizes will be silently ignored. This could lead to unexpected behavior. To make this assumption explicit and prevent silent errors, I'd suggest adding a check to ensure that all calls use a consistent sliding_window_size.

Suggested change
if (sliding_window_size_ == -1)
sliding_window_size_ = sliding_window_size;
if (sliding_window_size_ == -1) {
sliding_window_size_ = sliding_window_size;
} else {
ICHECK_EQ(sliding_window_size_, sliding_window_size)
<< "Inconsistent sliding window sizes are not supported. Previously got "
<< sliding_window_size_ << ", but now got " << sliding_window_size;
}

grf53 and others added 2 commits January 5, 2026 19:32
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Adjust causal mask calculation to include an additional condition.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant