-
Notifications
You must be signed in to change notification settings - Fork 3
[StreamingDataLoader, 1/N] feat: implement RankAwareSampler #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com> Co-authored-by: ji-huazhong <hzji210@gmail.com> Co-authored-by: baymax591 <cbai@mail.nwpu.edu.cn> Co-authored-by: jianjunzhong <jianjunzhong@foxmail.com> Co-authored-by: LLLLxmmm <liuqianmeng@huawei.com> Co-authored-by: dpj135 <958208521@qq.com> Co-authored-by: Evelynn-V <liwenlin0223l@gmail.com> Co-authored-by: liujia7 <liujia7@xiaohongshu.com> Co-authored-by: 赵海源 <zhaohaiyuan@xiaohongshu.com> Co-authored-by: NINGBENZHE <ningbenzhe@xiaohongshu.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR introduces RankAwareSampler, the first component of a planned StreamingDataLoader feature for distributed training. The sampler ensures deterministic data consumption across distributed data parallel (DP) ranks by guaranteeing that all ranks within the same DP group receive identical sample indices.
Changes:
- Implements
RankAwareSamplerclass with state management for coordinated sampling across DP ranks - Updates base sampler type annotations and documentation to accommodate the new sampler
- Adds comprehensive test suite for
RankAwareSamplerwith edge case coverage
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 12 comments.
Show a summary per file
| File | Description |
|---|---|
| transfer_queue/sampler/rank_aware_sampler.py | New sampler implementation with caching mechanism for DP-group coordination |
| transfer_queue/sampler/base.py | Updates type hints for _states dict and adds RankAwareSampler to documentation |
| transfer_queue/sampler/init.py | Exports RankAwareSampler for public API |
| transfer_queue/init.py | Adds RankAwareSampler and StreamDataLoader exports (StreamDataLoader module missing) |
| tests/test_samplers.py | Comprehensive test suite for RankAwareSampler with multiple scenarios |
Comments suppressed due to low confidence (1)
transfer_queue/sampler/base.py:54
- Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
def sample(
self,
ready_indexes: list[int],
batch_size: int,
*args: Any,
**kwargs: Any,
) -> tuple[list[int], list[int]]:
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # Check if this was the last rank in the DP group to fetch | ||
| if self._states[dp_group]["fetch_count"] >= fetches_per_batch: | ||
| del self._states[dp_group] | ||
|
|
Copilot
AI
Jan 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If some ranks fail to call sample() (e.g., due to crashes or network issues), the state for that dp_group will never be cleaned up and will remain in self._states indefinitely, causing a memory leak. Consider adding a timeout mechanism or state cleanup strategy for orphaned entries, or document this limitation in the class docstring.
| # Setup path | ||
| parent_dir = Path(__file__).resolve().parent.parent | ||
| sys.path.append(str(parent_dir)) |
Copilot
AI
Jan 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The manual sys.path manipulation is generally an anti-pattern in Python testing. Modern test runners like pytest automatically handle module discovery. This could cause issues in CI/CD environments or when running tests from different directories. Consider removing this manual path setup and relying on proper package installation (e.g., pip install -e .) or pytest's natural module discovery instead.
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com> Co-authored-by: zhabuye <2947436155@qq.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
Comments suppressed due to low confidence (1)
transfer_queue/sampler/base.py:53
- Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
def sample(
self,
ready_indexes: list[int],
batch_size: int,
*args: Any,
**kwargs: Any,
) -> tuple[list[int], list[int]]:
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
Comments suppressed due to low confidence (1)
transfer_queue/sampler/base.py:53
- Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'n_samples_per_prompt'. Overriding method method GRPOGroupNSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_group'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'dp_world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
Overridden method signature does not match call, where it is passed an argument named 'world_size'. Overriding method method RankAwareSampler.sample matches the call.
def sample(
self,
ready_indexes: list[int],
batch_size: int,
*args: Any,
**kwargs: Any,
) -> tuple[list[int], list[int]]:
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
…fetch in RankAwareSampler (#7) ## Background In the initial implementation introduced in PR #4, `RankAwareSampler` allowed individual ranks to fetch `BatchMeta` from `TransferQueueController`, guaranteeing all ranks within the same data replica group receive the same sample indices.. However, this implementation had two main limitations: - It did not account for asynchronous calls arising from different tasks in a task-separated RL framework. - It did not support data pre-fetching when integrated with the `StreamingDataLoader` interface. ## Solution This PR enhances `RankAwareSampler` to support multi-task concurrency and data pre-fetching: - **Task & Partition Awareness**: Introduced `task_name` and `partition_id` parameters to correctly identify the current task context and apply distinct caching logic for each task. - **Pre-fetching Support**: Implemented a dynamic buffer for each rank under each task. CC: @NINGBENZHE --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
Background
This PR is the first in a series [1/N] to introduce
StreamingDataLoader, a mechanism designed to optimize data dispatch in distributed training.Specifically, this PR implements
RankAwareSampler. In distributed data parallel (DP) scenarios where ranks retrieve data independently, this sampler ensures deterministic behavior: it guarantees that all ranks within the same DP group receive identical sample indices, synchronizing the data consumption process.Leveraging
StreamingDataLoader, we can supports micro-batch level pipelining for training backends. By passing the dataloader instance directly intoforward_backward_func, we avoid the bottleneck of retrieving full mini-batches in advance. This allows for highly efficient, fine-grained streaming throughout the training process.Please refer to our roadmap for more details: [Roadmap] StreamingDataLoader for task-separated RL post-training
Note
We have added
Co-authored-bycredits to the commit messages to properly attribute the work to the early developers from https://github.com/TransferQueue/TransferQueue.