
[Roadmap] StreamingDataLoader for task-separated RL post-training #1

@0oshowero0

Description


Background

In the current single-controller implementation (e.g., verl), we rely on a single process to schedule the entire data flow in the cluster. This process decides which data should be computed by which device, acting as a data router throughout the training process. By integrating the current TransferQueue into the single-controller, we enable the single-controller to only dispatch BatchMeta instead of actual data, which greatly reduces the burden on the single-controller.
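
As a rough illustration of this idea (the field names below are hypothetical, not the actual TransferQueue definition), a BatchMeta only carries index and location information, never the tensors themselves:

from dataclasses import dataclass, field

@dataclass
class BatchMeta:
    # Hypothetical sketch: lightweight metadata dispatched by the single-controller
    # in place of the actual training data.
    global_indexes: list[int]   # which rows of the global batch a task should read
    fields: list[str]           # which columns the task needs, e.g. ['prompts']
    storage_locations: dict = field(default_factory=dict)  # where each row lives in TransferQueue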

However, in real-world applications, the execution time of each device may vary significantly due to unbalanced response lengths or the "hardware lottery" (random performance variations across hardware). Pre-allocating training samples via the single-controller may lead to straggler issues in large-scale post-training.

Solution

As described in AsyncFlow and verl/discussions/2662, we propose a possible solution that reduces synchronization between ranks by delegating data-retrieval authority to the workers themselves.

[Figure: task-separated data flow between the single-controller, TransferQueue, and the RL task workers]
  1. The single-controller starts all the RL tasks asynchronously. These tasks will request input data from TransferQueue.
  2. The single-controller feeds the first global batch of prompts into TransferQueue.
  3. RL tasks retrieve the required data from TransferQueue by themselves and initiate their computational workloads. The pipeline forms automatically, without any manual orchestration.
  4. The training logic detects the completion of the current iteration, records performance metrics, evicts the current global batch, and puts new prompts into TransferQueue.

In the figure above, the single-controller only starts all the tasks. Within each task, each device decides on its own whether it needs more data. Faster devices can therefore dynamically request more data, improving overall system efficiency.

More importantly, StreamingDataLoader supports micro-batch level pipelining for training backends. By passing the loader instance directly into forward_backward_func, we avoid the bottleneck of retrieving full mini-batches in advance. This allows for highly efficient, fine-grained streaming throughout the training process.

        # The loader instance is passed directly as the data iterator, so each
        # micro-batch is pulled from TransferQueue on demand.
        data_iter = StreamingDataLoader()
        losses_reduced = self.forward_backward_func(
            forward_step_func=forward_step,
            data_iterator=data_iter,
            model=self.model,
            num_microbatches=num_microbatches,
            seq_length=self.seq_length,
            micro_batch_size=self.micro_batch_size,
            forward_only=forward_only,
            collect_non_loss_data=forward_only,
        )

Another benefit of this abstraction is eliminating duplicated data retrieval within a DP group. Currently, every rank fetches its data from TransferQueue (TQ) directly. With the StreamingDataLoader abstraction, we can let rank 0 retrieve the data from TQ and broadcast it to the other ranks.
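
A minimal sketch of this idea, assuming one torch.distributed process group per DP replica; the constructor arguments and the TransferQueue client calls (get_meta/get_data) are placeholders rather than the actual API:

import torch.distributed as dist

class StreamingDataLoader:
    """Sketch: only rank 0 of a DP group talks to TransferQueue, then broadcasts."""

    def __init__(self, tq_client, data_fields, batch_size, dp_group):
        self.tq_client = tq_client
        self.data_fields = data_fields
        self.batch_size = batch_size
        self.dp_group = dp_group

    def __iter__(self):
        src_rank = dist.get_global_rank(self.dp_group, 0)
        while True:  # runs until the task is stopped externally
            if dist.get_rank(self.dp_group) == 0:
                # Only the group leader queries TransferQueue (placeholder calls).
                batch_meta = self.tq_client.get_meta(self.data_fields, self.batch_size)
                batch_data = self.tq_client.get_data(batch_meta)
                payload = [batch_data, batch_meta]
            else:
                payload = [None, None]
            # Replicate the (data, meta) pair to the other ranks in the DP group.
            dist.broadcast_object_list(payload, src=src_rank, group=self.dp_group)
            yield payload[0], payload[1]

Broadcasting pickled Python objects is not the most efficient transport; a real implementation would probably broadcast tensors directly, but the sketch shows the intended division of labor inside a DP group.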

Under this design, the single-controller will be structured as follows:

class RayGRPOTrainer():
    def fit(self, data_iters):
        iteration = 0
        # 1. Start all the RL tasks. Each task pulls data from TransferQueue by itself.
        
        # generate sequences
        self.rollout_worker.start(iteration, self.train_iters)
        # compute reference log_prob
        self.ref_worker.start(iteration, self.train_iters)
        # compute rm scores.
        rule_reward = []
        for reward_worker in self.reward_list:
            if isinstance(reward_worker, RayActorGroup):
                reward_worker.start(iteration, self.train_iters)
            else:
                rule_reward.append(reward_worker.start.remote(iteration, self.train_iters))
        # compute advantage
        for advantage in self.advantage_list:
            advantage.start.remote(iteration, self.train_iters)
        # compute old log_prob
        if self.actor_fwd_worker:
            self.actor_fwd_worker.start(iteration, self.train_iters)
        # update actor
        self.actor_worker.start(iteration, self.train_iters)
    
        # 2. Put the first global batch(es) of prompts into TransferQueue
        start_iter_time = time.time()
        data_loader_index = 0
        for _ in range(self.staleness_threshold):
            batch = next(data_iters)
            put_prompts(batch, self.n_samples_per_prompt, total_data_rows)
            data_loader_index += 1
    
        # 3. Iteration control
        while True:
            train_iteration, indexes = self.iteration_record.get()  # iteration_record can be a Queue of (iteration, sample indexes)
    
            # Do metric update
            ...
    
            # Clear the corresponding TransferQueue
            self.clear_tq_controllers(train_iteration, indexes)
    
            # Put new prompts into TransferQueue
            if data_loader_index < self.train_iters:
                batch = next(data_iters)
                put_prompts(self.metrics_tq, batch, self.n_samples_per_prompt, self.dataset_additional_keys, indexes[0])
                data_loader_index += 1
            iteration += 1
    
            # Stop training
            if iteration >= self.train_iters:
                logger.info(f"The threshold of train iteration:{self.train_iters} is reached, stop putting prompt to TQ")
                ray.shutdown()
                break
        

To simplify usage, we aim to provide a StreamingDataLoader interface that directly handles the data requests from each rank. It encapsulates nearly all interaction logic with TransferQueue and automatically ensures that each rank retrieves the correct data under different parallelism configurations.

class RolloutWorker(BaseWorker):
    def generate_sequences(self):
        data_fields = ['prompts', 'prompt_length']
        data_loader = self.create_stream_data_loader(
            task_name='actor_rollout',
            data_fields=data_fields,
            batch_size=self.config.batch_size,
        )
        data_iter = iter(data_loader)
    
        for batch_data, batch_meta in data_iter:
            prompts_data = batch_data['prompts']
    
            # Do Inference
            responses = self.rollout.generate_sequences(prompts_data)
    
            # write results back to TransferQueue
            self.tq_client.async_put(responses, batch_meta)

We should note that this implementation may incur extra debugging effort, so it is better suited to large-scale training than to algorithm development.

Please let us know if you are interested in this paradigm~

TODO

  • Optimize TransferQueueController workflow so that one get_meta request will not block others
  • Provide RankAwareSampler that can be used for distributed sampling (a rough sketch follows after this list)
  • Provide StreamingDataLoader abstraction
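
For the RankAwareSampler item, a very rough sketch of what distributed sampling could look like (the class name matches the TODO item, but the interface is an assumption, not the planned API):

class RankAwareSampler:
    """Sketch: shard ready sample indexes across DP ranks without overlap."""

    def __init__(self, dp_rank, dp_world_size):
        self.dp_rank = dp_rank
        self.dp_world_size = dp_world_size

    def sample(self, ready_indexes, batch_size):
        # Give every DP rank a disjoint, contiguous slice of the ready samples so
        # that no two ranks ever pull the same rows from TransferQueue.
        per_rank = batch_size // self.dp_world_size
        start = self.dp_rank * per_rank
        return ready_indexes[start:start + per_rank]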
