[StreamingDataLoader, 2/N] feat: support async sampling and data pre-fetch in RankAwareSampler #7
Conversation
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
Pull request overview
This pull request refactors the RankAwareSampler to support async sampling and data pre-fetch for task-separated RL training scenarios. The PR introduces a hierarchical state structure organized by partition_id, task_name, and data_replica_group to enable better coordination across distributed ranks and support prefetching behavior.
Changes:
- Renamed parameters from `dp_group`, `dp_world_size`, `world_size` to `data_replica_group`, `data_replica_rank`, `data_replica_world_size`
- Added new `task_name` and `partition_id` parameters to support multi-task scenarios
- Restructured internal state to support per-rank buffering for data prefetch (sketched below)
- Updated all tests to use the new API
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| transfer_queue/sampler/rank_aware_sampler.py | Refactored sample method signature and internal state management to support task-aware, partition-aware, and rank-aware caching with prefetch capabilities |
| tests/test_samplers.py | Updated all test cases to use new API parameters and added new tests for data prefetch and multi-task scenarios |
```diff
+if partition_id not in self._states:
+    self._states[partition_id] = {}
+
-# Calculate how many times this batch should be fetched across all ranks
-if dp_world_size <= 0 or world_size % dp_world_size != 0:
-    raise RuntimeError(f"world_size ({world_size}) is not divisible by dp_world_size ({dp_world_size})")
+if task_name not in self._states[partition_id]:
+    self._states[partition_id][task_name] = {}
+
-fetches_per_batch = world_size // dp_world_size
+if data_replica_group not in self._states[partition_id][task_name]:
+    self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)}
```
Copilot AI commented on Jan 22, 2026
The state dictionary grows unbounded as new partition_id, task_name, and data_replica_group combinations are added, but there's no cleanup mechanism. In long-running processes, this could lead to memory growth, especially if partition_ids or task_names are dynamically generated. Consider implementing a cleanup mechanism for completed groups or adding documentation about expected usage patterns to prevent memory issues.
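One possible mitigation, sketched purely as an illustration (the `clear_partition` method and its semantics are hypothetical and not part of this PR), is an explicit cleanup hook invoked once a partition or a task within it has finished:

```python
# Hypothetical cleanup hook for RankAwareSampler (not part of this PR).
# Callers would invoke it once a partition (or a single task in it) completes,
# so the nested state dict does not grow without bound in long-running jobs.
def clear_partition(self, partition_id, task_name=None):
    """Drop cached state for a finished partition, or a single task in it."""
    if task_name is None:
        self._states.pop(partition_id, None)
    else:
        partition_state = self._states.get(partition_id, {})
        partition_state.pop(task_name, None)
        if not partition_state:
            # Remove the now-empty partition entry as well.
            self._states.pop(partition_id, None)
```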
```python
if data_replica_group not in self._states[partition_id][task_name]:
    self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)}
```
Copilot AI commented on Jan 22, 2026
The implementation does not validate that data_replica_world_size is consistent across calls for the same data_replica_group within a partition/task. If different ranks call the sample method with different data_replica_world_size values for the same group, the cache dictionary at line 131 will be initialized with different sizes, potentially causing KeyError or incorrect behavior. Consider adding validation to ensure all ranks in the same group use the same data_replica_world_size value.
tests/test_samplers.py
Outdated
```python
)
assert sampled_rank1_time0 == [0, 1]
assert consumed_rank1_time0 == [0, 1]
ready_indexes = [i for i in ready_indexes if i not in consumed_rank1_time0]
```
Copilot AI commented on Jan 22, 2026
Variable `ready_indexes` is not used:

```python
ready_indexes = [i for i in ready_indexes if i not in consumed_rank1_time0]
```
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.
```diff
-fetches_per_batch = world_size // dp_world_size
+if data_replica_group not in self._states[partition_id][task_name]:
+    self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)}
```
Copilot AI commented on Jan 22, 2026
When a data_replica_group is first initialized, the state is created with a dictionary of empty lists for each rank (line 133). However, if the same data_replica_group is accessed again with a different data_replica_world_size value, the existing state is reused without validation. This could lead to KeyError if data_replica_rank is outside the originally initialized range, or silent bugs if the world size truly changed. Consider adding validation to check that data_replica_world_size matches the existing state structure when reusing a data_replica_group, or document that data_replica_world_size must remain constant for each data_replica_group throughout the sampler's lifetime.
Suggested change:

```diff
-    self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)}
+    # Initialize per-rank cache for this data replica group.
+    # World size is fixed for the lifetime of a given (partition_id, task_name, data_replica_group).
+    self._states[partition_id][task_name][data_replica_group] = {
+        i: [] for i in range(data_replica_world_size)
+    }
+else:
+    # Validate that the existing state structure matches the provided world size.
+    group_state = self._states[partition_id][task_name][data_replica_group]
+    existing_world_size = len(group_state)
+    if existing_world_size != data_replica_world_size:
+        raise ValueError(
+            "Inconsistent data_replica_world_size for data_replica_group "
+            f"{data_replica_group}: existing world size is {existing_world_size}, "
+            f"but received {data_replica_world_size}. The world size must remain "
+            "constant for each data_replica_group throughout the sampler's lifetime."
+        )
```
```python
# Cache the sampled indices for other ranks in this data replica group
for i in range(data_replica_world_size):
    if i != data_replica_rank:
        self._states[partition_id][task_name][data_replica_group][i].append(sampled_indexes)
```
Copilot AI commented on Jan 22, 2026
The new pre-fetching implementation can lead to unbounded memory growth if ranks sample at different rates. When a rank samples (line 135-147), it appends sampled indices to all other ranks' buffers (line 147). If one rank samples much more frequently than others, the slower ranks' buffers will grow without bound. Consider adding a maximum buffer size limit per rank or implementing a mechanism to detect and handle such imbalances, especially for long-running training jobs with asynchronous task execution patterns.
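One possible mitigation, assuming a configurable cap (the `MAX_PREFETCH_BATCHES` constant below is hypothetical and not part of this PR), is to bound each rank's buffer before appending to it:

```python
# Hypothetical guard against unbounded per-rank buffers (not part of this PR).
# In practice this could be a constructor argument of the sampler.
MAX_PREFETCH_BATCHES = 64  # illustrative default

for i in range(data_replica_world_size):
    if i != data_replica_rank:
        buffer = self._states[partition_id][task_name][data_replica_group][i]
        if len(buffer) >= MAX_PREFETCH_BATCHES:
            # Fail fast (or log a warning) instead of letting a slow rank's buffer grow forever.
            raise RuntimeError(
                f"Prefetch buffer for rank {i} in group {data_replica_group} exceeded "
                f"{MAX_PREFETCH_BATCHES} batches; ranks appear to be sampling at very different rates."
            )
        buffer.append(sampled_indexes)
```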
Background
In the initial implementation introduced in PR #4, `RankAwareSampler` allowed individual ranks to fetch `BatchMeta` from `TransferQueueController`, guaranteeing that all ranks within the same data replica group receive the same sample indices. However, this implementation had two main limitations: … the `StreamingDataLoader` interface.

Solution

This PR enhances `RankAwareSampler` to support multi-task concurrency and data pre-fetching: it adds `task_name` and `partition_id` parameters to correctly identify the current task context and apply distinct caching logic for each task.

CC: @NINGBENZHE
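For orientation, here is a sketch of how two ranks in one data replica group might call the refactored sampler. The exact `sample()` signature and its (sampled, consumed) return pair are assumptions inferred from the parameter names and the tests quoted above, not the actual API.

```python
# Hypothetical usage of the refactored RankAwareSampler; the real signature may differ.
from transfer_queue.sampler.rank_aware_sampler import RankAwareSampler  # path per this PR's file list

sampler = RankAwareSampler()

common = dict(
    partition_id="partition_0",   # identifies the current data partition
    task_name="rollout",          # distinct caching state per task
    data_replica_group=0,         # ranks in the same group receive identical indices
    data_replica_world_size=2,
)

ready_indexes = [0, 1, 2, 3]

# Rank 0 samples first; its result is also buffered for rank 1 (prefetch).
sampled_rank0, consumed_rank0 = sampler.sample(
    ready_indexes, batch_size=2, data_replica_rank=0, **common
)

# Rank 1 can call sample() at a different time (async) and still receive the
# same indices, served from its per-rank prefetch buffer.
sampled_rank1, consumed_rank1 = sampler.sample(
    ready_indexes, batch_size=2, data_replica_rank=1, **common
)

assert sampled_rank0 == sampled_rank1
```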