Conversation

@0oshowero0 0oshowero0 commented Jan 22, 2026

Background

In the initial implementation introduced in PR #4, RankAwareSampler allowed individual ranks to fetch BatchMeta from TransferQueueController, guaranteeing that all ranks within the same data replica group receive the same sample indices. However, this implementation had two main limitations:

  • It did not account for asynchronous calls arising from different tasks in a task-separated RL framework.
  • It did not support data pre-fetching when integrated with the StreamingDataLoader interface.

Solution

This PR enhances RankAwareSampler to support multi-task concurrency and data pre-fetching:

  • Task & Partition Awareness: Introduced task_name and partition_id parameters to correctly identify the current task context and apply distinct caching logic for each task.
  • Pre-fetching Support: Implemented a dynamic buffer for each rank under each task.
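The per-rank buffering behind these two changes can be sketched as follows. This is an illustrative mock of the caching idea only, not the actual TransferQueue implementation; the class name, method signature, and state layout are all assumptions for exposition:

```python
# Hypothetical sketch: one buffer per rank, keyed by partition, task, and
# data replica group, so asynchronous tasks get independent caches.
from collections import defaultdict


class RankAwareBufferSketch:
    def __init__(self):
        # states[partition_id][task_name][data_replica_group][rank] -> cached batches
        self._states = defaultdict(lambda: defaultdict(dict))

    def sample(self, ready, batch_size, partition_id, task_name,
               data_replica_group, data_replica_rank, data_replica_world_size):
        group = self._states[partition_id][task_name].setdefault(
            data_replica_group, {i: [] for i in range(data_replica_world_size)}
        )
        buffer = group[data_replica_rank]
        if buffer:
            # Another rank in this replica group already sampled; reuse its
            # choice so all ranks in the group see identical indices.
            return buffer.pop(0)
        sampled = ready[:batch_size]
        # Cache the same indices for every other rank in the group, which is
        # also what enables pre-fetching: a rank may sample ahead and the
        # results queue up for the slower ranks.
        for i in range(data_replica_world_size):
            if i != data_replica_rank:
                group[i].append(sampled)
        return sampled
```

With this shape, two ranks of the same replica group calling `sample` under the same task receive the same indices, while a different `task_name` starts from a fresh cache.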

CC: @NINGBENZHE

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
Copilot AI review requested due to automatic review settings January 22, 2026 11:03

Copilot AI left a comment


Pull request overview

This pull request refactors the RankAwareSampler to support async sampling and data pre-fetch for task-separated RL training scenarios. The PR introduces a hierarchical state structure organized by partition_id, task_name, and data_replica_group to enable better coordination across distributed ranks and support prefetching behavior.

Changes:

  • Renamed parameters from dp_group, dp_world_size, world_size to data_replica_group, data_replica_rank, data_replica_world_size
  • Added new task_name and partition_id parameters to support multi-task scenarios
  • Restructured internal state to support per-rank buffering for data prefetch
  • Updated all tests to use the new API

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 10 comments.

Reviewed files:

  • transfer_queue/sampler/rank_aware_sampler.py: Refactored the sample method signature and internal state management to support task-aware, partition-aware, and rank-aware caching with prefetch capabilities
  • tests/test_samplers.py: Updated all test cases to use the new API parameters and added new tests for data prefetch and multi-task scenarios


Comment on lines +124 to +131
if partition_id not in self._states:
    self._states[partition_id] = {}

# Calculate how many times this batch should be fetched across all ranks
if dp_world_size <= 0 or world_size % dp_world_size != 0:
    raise RuntimeError(f"world_size ({world_size}) is not divisible by dp_world_size ({dp_world_size})")
if task_name not in self._states[partition_id]:
    self._states[partition_id][task_name] = {}

fetches_per_batch = world_size // dp_world_size
if data_replica_group not in self._states[partition_id][task_name]:
    self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)}

Copilot AI Jan 22, 2026


The state dictionary grows unbounded as new partition_id, task_name, and data_replica_group combinations are added, but there's no cleanup mechanism. In long-running processes, this could lead to memory growth, especially if partition_ids or task_names are dynamically generated. Consider implementing a cleanup mechanism for completed groups or adding documentation about expected usage patterns to prevent memory issues.
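One way to address this review comment would be explicit release hooks invoked when a partition or task completes. The helper names and the idea of an explicit cleanup call are assumptions for illustration, not part of the merged PR; only the nested `states[partition_id][task_name][data_replica_group]` layout is taken from the quoted code:

```python
# Hypothetical cleanup helpers for the nested state dict quoted above;
# the caller would invoke these once a partition or task is known to be done.

def release_partition(states: dict, partition_id) -> None:
    """Drop all cached state for a finished partition to bound memory."""
    states.pop(partition_id, None)


def release_task(states: dict, partition_id, task_name) -> None:
    """Drop cached state for one task within a partition."""
    partition = states.get(partition_id)
    if partition is not None:
        partition.pop(task_name, None)
```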

Comment on lines +130 to +131
if data_replica_group not in self._states[partition_id][task_name]:
    self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)}

Copilot AI Jan 22, 2026


The implementation does not validate that data_replica_world_size is consistent across calls for the same data_replica_group within a partition/task. If different ranks call the sample method with different data_replica_world_size values for the same group, the cache dictionary at line 131 will be initialized with different sizes, potentially causing IndexError or incorrect behavior. Consider adding validation to ensure all ranks in the same group use the same data_replica_world_size value.

)
assert sampled_rank1_time0 == [0, 1]
assert consumed_rank1_time0 == [0, 1]
ready_indexes = [i for i in ready_indexes if i not in consumed_rank1_time0]

Copilot AI Jan 22, 2026


The variable ready_indexes is assigned here but never used afterwards.

Suggested change
ready_indexes = [i for i in ready_indexes if i not in consumed_rank1_time0]

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
@0oshowero0 0oshowero0 changed the title [StreamingDataLoader, 1/N] feat: support async sampling and data pre-fetch in RankAwareSampler [StreamingDataLoader, 2/N] feat: support async sampling and data pre-fetch in RankAwareSampler Jan 22, 2026
@0oshowero0 0oshowero0 requested a review from Copilot January 22, 2026 11:53

Copilot AI left a comment


Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.




fetches_per_batch = world_size // dp_world_size
if data_replica_group not in self._states[partition_id][task_name]:
    self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)}

Copilot AI Jan 22, 2026


When a data_replica_group is first initialized, the state is created with a dictionary of empty lists for each rank (line 133). However, if the same data_replica_group is accessed again with a different data_replica_world_size value, the existing state is reused without validation. This could lead to KeyError if data_replica_rank is outside the originally initialized range, or silent bugs if the world size truly changed. Consider adding validation to check that data_replica_world_size matches the existing state structure when reusing a data_replica_group, or document that data_replica_world_size must remain constant for each data_replica_group throughout the sampler's lifetime.

Suggested change
    self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)}
    # Initialize per-rank cache for this data replica group.
    # World size is fixed for the lifetime of a given (partition_id, task_name, data_replica_group).
    self._states[partition_id][task_name][data_replica_group] = {
        i: [] for i in range(data_replica_world_size)
    }
else:
    # Validate that the existing state structure matches the provided world size.
    group_state = self._states[partition_id][task_name][data_replica_group]
    existing_world_size = len(group_state)
    if existing_world_size != data_replica_world_size:
        raise ValueError(
            "Inconsistent data_replica_world_size for data_replica_group "
            f"{data_replica_group}: existing world size is {existing_world_size}, "
            f"but received {data_replica_world_size}. The world size must remain "
            "constant for each data_replica_group throughout the sampler's lifetime."
        )

Comment on lines +144 to +147
# Cache the sampled indices for other ranks in this data replica group
for i in range(data_replica_world_size):
    if i != data_replica_rank:
        self._states[partition_id][task_name][data_replica_group][i].append(sampled_indexes)

Copilot AI Jan 22, 2026


The new pre-fetching implementation can lead to unbounded memory growth if ranks sample at different rates. When a rank samples (lines 135-147), it appends the sampled indices to all other ranks' buffers (line 147). If one rank samples much more frequently than the others, the slower ranks' buffers will grow without bound. Consider adding a maximum buffer size per rank, or a mechanism to detect and handle such imbalances, especially for long-running training jobs with asynchronous task execution patterns.
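A bounded-buffer guard of the kind this comment suggests could look like the sketch below. The constant, function name, and choice of raising rather than blocking are all assumptions for illustration; the merged PR does not include such a limit:

```python
# Hypothetical guard: refuse to buffer further samples for a rank that has
# fallen too far behind, surfacing imbalance instead of growing memory forever.
MAX_BUFFER_SIZE = 64  # assumed cap on batches a slow rank may lag behind


def append_with_limit(group_state: dict, sampled_indexes: list,
                      data_replica_rank: int, data_replica_world_size: int) -> None:
    for i in range(data_replica_world_size):
        if i == data_replica_rank:
            continue
        if len(group_state[i]) >= MAX_BUFFER_SIZE:
            raise RuntimeError(
                f"Rank {i} is {MAX_BUFFER_SIZE} batches behind; check for a "
                "stalled consumer before buffering more samples."
            )
        group_state[i].append(sampled_indexes)
```

An alternative design choice would be to block or drop the oldest entry instead of raising, depending on whether stale samples are acceptable to the training loop.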

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
@0oshowero0 0oshowero0 merged commit 45cfafe into Ascend:main Jan 22, 2026
2 checks passed
