[Stateful] Implement length-aware keying to minimize padding in BatchElements (Part 2/3)#37565
Eliaaazzz wants to merge 4 commits into apache:master
Conversation
Activity
Checks are failing. Will not request review until checks are succeeding. If you'd like to override that behavior, comment
Force-pushed 9f6b1c2 to 92b546a
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
@@ Coverage Diff @@
## master #37565 +/- ##
=========================================
Coverage 35.88% 35.88%
Complexity 1691 1691
=========================================
Files 1063 1063
Lines 166721 166752 +31
Branches 1227 1227
=========================================
+ Hits 59832 59844 +12
- Misses 104694 104713 +19
Partials 2195 2195
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Assigning reviewers: R: @tvalentyn for label python.
The PR bot will only process comments in the main thread (not review comments).
Force-pushed 8926e75 to 77b23a7
- Add length_fn and bucket_boundaries parameters to ModelHandler.__init__ to support length-aware bucketed keying for ML inference batching
- Add WithLengthBucketKey DoFn to route elements by length buckets
- Update BatchElements to support length-aware batching when max_batch_duration_secs is set, reducing padding waste for variable-length sequences (e.g., NLP workloads)
- Default bucket boundaries: [16, 32, 64, 128, 256, 512]
- Add comprehensive tests validating bucket assignment, mixed-length batching, and padding efficiency improvements (77% vs 68% on bimodal data)
- All formatting (yapf) and lint (pylint 10/10) checks passed
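The keying step the commit message describes can be sketched in plain Python. This is an illustration only: the helper name `length_bucket_key` is hypothetical (the PR implements this inside the WithLengthBucketKey DoFn), and it uses `bisect_right`, the behavior the review below converged on.

```python
import bisect

# Default boundaries from the commit message above.
DEFAULT_BUCKET_BOUNDARIES = [16, 32, 64, 128, 256, 512]

def length_bucket_key(element, length_fn=len,
                      bucket_boundaries=DEFAULT_BUCKET_BOUNDARIES):
    # Map an element to (bucket_index, element). A stateful batching step
    # keyed on bucket_index then only groups similar-length elements,
    # so padding within a batch is bounded by the bucket's upper edge.
    return bisect.bisect_right(bucket_boundaries, length_fn(element)), element

key, _ = length_bucket_key("x" * 20)  # length 20 falls in bucket 1: [16, 32)
```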
Force-pushed 2f68510 to 8713eac
Force-pushed 8713eac to 47b5a9b
Reminder, please take a look at this PR: @tvalentyn
Force-pushed 6af4d2e to f820874
Hi @damccorm, just a gentle ping on this when you have a spare moment. No rush at all. Just wanted to note that this PR and PR 2 are blocking the final integration PR, so I'd love to get your thoughts on this core logic so I can adjust the downstream work if needed.
Force-pushed a889c3f to d6d33c8
Force-pushed d6d33c8 to 53454f3
/gemini review
damccorm left a comment:
Thanks - this is neat! Just had minor feedback, it generally LGTM
max_batch_weight: Optional[int] = None,
element_size_fn: Optional[Callable[[Any], int]] = None,
length_fn: Optional[Callable[[Any], int]] = None,
bucket_boundaries: Optional[list[int]] = None,
Can we make it clear these are batching parameters? e.g. batch_length_fn and batch_bucket_boundaries?
length_fn: a callable mapping an element to its length. When set with
  max_batch_duration_secs, enables length-aware bucketed keying so
  elements of similar length are batched together.
bucket_boundaries: sorted list of positive boundary values for length
Could we add more data to this description, similar to below?
Code Review
This PR introduces length-aware keying in BatchElements to improve batching efficiency, especially for variable-length inputs. The changes include a new WithLengthBucketKey DoFn, API updates to BatchElements and ModelHandler to support length-based bucketing, and comprehensive tests to validate the new functionality and demonstrate its performance benefits. The implementation is clear and well-tested. I have one suggestion regarding the bucketing logic to make it more intuitive.
load_shared_key, "WithLengthBucketKey").key

def _get_bucket(self, length):
    return bisect.bisect_left(self._bucket_boundaries, length)
Using bisect.bisect_right would be more intuitive for defining bucket boundaries. Typically, boundaries are inclusive on the lower end and exclusive on the upper end (e.g., a bucket for lengths in [10, 50)). bisect.bisect_right achieves this behavior, which seems to be what was intended based on comments and logic in the tests.
With the current bisect.bisect_left, a length of 10 falls into bucket 0 (for lengths <= 10), which can be surprising. With bisect.bisect_right, a length of 10 would fall into bucket 1 (for lengths in [10, 50)).
If you make this change, please also update the assertions in test_length_bucket_assignment in util_test.py to match.
- return bisect.bisect_left(self._bucket_boundaries, length)
+ return bisect.bisect_right(self._bucket_boundaries, length)
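The difference the reviewer describes is easy to verify directly with the standard-library `bisect` functions, using the boundaries from the tests in this thread:

```python
import bisect

boundaries = [10, 50, 100]

# bisect_left keeps a length equal to a boundary in the lower bucket
# (boundaries act as upper-inclusive edges): 10 -> bucket 0.
assert bisect.bisect_left(boundaries, 10) == 0

# bisect_right makes boundaries lower-inclusive, as the review suggests:
# buckets are (-inf, 10), [10, 50), [50, 100), [100, inf), so 10 -> bucket 1.
assert bisect.bisect_right(boundaries, 10) == 1
assert bisect.bisect_right(boundaries, 9) == 0
```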
# bisect_left: length < 10 -> bucket 0, 10 <= length < 50 -> bucket 1, etc.
self.assertEqual(dofn._get_bucket(5), 0)
self.assertEqual(dofn._get_bucket(10), 0)
self.assertEqual(dofn._get_bucket(11), 1)
self.assertEqual(dofn._get_bucket(50), 1)
self.assertEqual(dofn._get_bucket(51), 2)
self.assertEqual(dofn._get_bucket(100), 2)
self.assertEqual(dofn._get_bucket(101), 3)
self.assertEqual(dofn._get_bucket(999), 3)
The comment here seems to describe the behavior of bisect.bisect_right, but the assertions match bisect.bisect_left, which is confusing.
If _get_bucket is changed to use bisect.bisect_right as suggested in my other comment, this test should be updated to reflect the new behavior. The boundaries would be inclusive on the lower end, which is more conventional.
Here is a suggested update for the test to align with bisect.bisect_right:
- # bisect_left: length < 10 -> bucket 0, 10 <= length < 50 -> bucket 1, etc.
- self.assertEqual(dofn._get_bucket(5), 0)
- self.assertEqual(dofn._get_bucket(10), 0)
- self.assertEqual(dofn._get_bucket(11), 1)
- self.assertEqual(dofn._get_bucket(50), 1)
- self.assertEqual(dofn._get_bucket(51), 2)
- self.assertEqual(dofn._get_bucket(100), 2)
- self.assertEqual(dofn._get_bucket(101), 3)
- self.assertEqual(dofn._get_bucket(999), 3)
+ # bisect_right creates buckets where the boundary is the lower-inclusive bound.
+ # e.g., for boundaries [10, 50], buckets are (-inf, 10), [10, 50), [50, inf)
+ self.assertEqual(dofn._get_bucket(5), 0)
+ self.assertEqual(dofn._get_bucket(10), 1)
+ self.assertEqual(dofn._get_bucket(11), 1)
+ self.assertEqual(dofn._get_bucket(50), 2)
+ self.assertEqual(dofn._get_bucket(51), 2)
+ self.assertEqual(dofn._get_bucket(100), 3)
+ self.assertEqual(dofn._get_bucket(999), 3)
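The suggested assertions can be checked with a minimal stand-in for the DoFn. The class here is a sketch, not the PR's WithLengthBucketKey, and the boundaries [10, 50, 100] are inferred from the test values quoted above:

```python
import bisect

class BucketLookupSketch:
    """Minimal sketch of the bucket lookup with bisect_right semantics."""
    def __init__(self, bucket_boundaries):
        self._bucket_boundaries = sorted(bucket_boundaries)

    def _get_bucket(self, length):
        # Each boundary is the inclusive lower edge of the next bucket:
        # for [10, 50, 100] -> (-inf, 10), [10, 50), [50, 100), [100, inf)
        return bisect.bisect_right(self._bucket_boundaries, length)

dofn = BucketLookupSketch([10, 50, 100])
assert [dofn._get_bucket(n) for n in (5, 10, 50, 100, 999)] == [0, 1, 2, 3, 3]
```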
Force-pushed d1fa315 to cf2997a
Expands parameter documentation for clarity and replaces bisect_left with bisect_right to ensure bucket boundaries are inclusive on the lower bound. Updates util_test.py assertions accordingly.
Force-pushed cf2997a to 35a622e
[Stateful] Implement length-aware keying to minimize padding in BatchElements (Part 2/3)
Rationale
Issue: #37531 (Stateful Core - Part 2)
Part 1: #37532
This PR adds length-aware keying to BatchElements to improve batching efficiency for variable-length inputs (for example, NLP inference workloads).
Today, stateful BatchElements uses one shared key (WithSharedKey). That causes short and long sequences to be mixed in the same batch, so padding is dictated by the longest item and compute is wasted. This PR addresses that by routing elements into length buckets before stateful batching.
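The padding cost described above can be made concrete with a toy bimodal workload. The lengths and the collate logic below are illustrative assumptions, not the PR's actual test data:

```python
lengths = [8, 8, 8, 8, 256, 256, 256, 256]  # bimodal: four short, four long

def padded_tokens(batch):
    # A typical collate step pads every element to the longest in the batch.
    return max(batch) * len(batch)

real = sum(lengths)                       # useful tokens: 1056
mixed = padded_tokens(lengths)            # one shared key: 256 * 8 = 2048
bucketed = (padded_tokens(lengths[:4])    # short bucket: 8 * 4 = 32
            + padded_tokens(lengths[4:])) # long bucket: 256 * 4 = 1024

print(f"shared-key efficiency: {real / mixed:.0%}")    # 52%
print(f"bucketed efficiency:   {real / bucketed:.0%}") # 100%
```

With a single shared key, every short sequence is padded out to the longest element in the whole batch; with length buckets, padding is bounded by the longest element within each bucket.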
What changed
Testing and results
Added test_padding_efficiency_bimodal in util_test.py to represent a bimodal workload of short and long sequences.
Observed result: padding efficiency improved from 68% to 77% on the bimodal data.
Interpretation: with length buckets, each batch pads only to the longest element within its bucket rather than the longest element overall, so less compute is wasted on padding.