diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py
index 223ad1672..d79da4057 100644
--- a/slime/backends/fsdp_utils/actor.py
+++ b/slime/backends/fsdp_utils/actor.py
@@ -29,6 +29,7 @@
 from ...utils.profile_utils import TrainProfiler
 from . import checkpoint
 from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences
+from .lora_utils import apply_lora_to_model, is_lora_model
 from .lr_scheduler import get_lr_scheduler
 from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor
 
@@ -95,6 +96,9 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int:  # ty
             attn_implementation=self.args.attn_implementation,
         )
 
+        if self.args.lora_rank > 0 or self.args.lora_adapter_path:
+            model = apply_lora_to_model(model, self.args)
+
         model.train()
 
         full_state = model.state_dict()
@@ -108,11 +112,14 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int:  # ty
         self.model = model
 
         if args.gradient_checkpointing:
-            self.model.gradient_checkpointing_enable()
+            # Use non-reentrant checkpointing so frozen base params don't raise "does not require grad"
+            gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {}
+            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)
 
         if args.optimizer == "adam":
+            trainable_params = [p for p in self.model.parameters() if p.requires_grad]
             self.optimizer = torch.optim.AdamW(
-                self.model.parameters(),
+                trainable_params,
                 lr=args.lr,
                 betas=(args.adam_beta1, args.adam_beta2),
                 eps=args.adam_eps,
diff --git a/slime/backends/fsdp_utils/arguments.py b/slime/backends/fsdp_utils/arguments.py
index 441246071..fa8555fb1 100644
--- a/slime/backends/fsdp_utils/arguments.py
+++ b/slime/backends/fsdp_utils/arguments.py
@@ -60,6 +60,13 @@ class FSDPArgs:
     # YAML bookkeeping
     config: str | None = None
 
+    # LoRA configuration
+    lora_rank: int = 0
+    lora_alpha: int = 16
+    target_modules: str = "all-linear"
+    exclude_modules: str | None = None
+    lora_adapter_path: str | None = None
+
 
 def parse_fsdp_cli(extra_args_provider=None):
     parser = argparse.ArgumentParser("FSDP SFT Training (slime)")
diff --git a/slime/backends/fsdp_utils/checkpoint.py b/slime/backends/fsdp_utils/checkpoint.py
index 3c49a10f8..4cf46605f 100644
--- a/slime/backends/fsdp_utils/checkpoint.py
+++ b/slime/backends/fsdp_utils/checkpoint.py
@@ -12,21 +12,34 @@
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
 
+from slime.backends.fsdp_utils.lora_utils import is_lora_model
+
 logger = logging.getLogger(__name__)
 
 
 class ModelState(Stateful):
     """Wrapper for model state only."""
 
-    def __init__(self, model):
+    def __init__(self, model, lora_only: bool = False):
         self.model = model
+        self.lora_only = lora_only
+        self._key = "adapter" if lora_only else "model"
 
     def state_dict(self):
         model_state_dict, _ = get_state_dict(self.model, optimizers=[])
-        return {"model": model_state_dict}
+        if self.lora_only:
+            model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k}
+        return {self._key: model_state_dict}
 
     def load_state_dict(self, state_dict):
-        set_state_dict(self.model, optimizers=[], model_state_dict=state_dict["model"], optim_state_dict=None)
+        data = state_dict[self._key]
+
+        if self.lora_only:
+            full_state_dict, _ = get_state_dict(self.model, optimizers=[])
+            full_state_dict.update(data)
+            set_state_dict(self.model, optimizers=[], model_state_dict=full_state_dict, optim_state_dict=None)
+        else:
+            set_state_dict(self.model, optimizers=[], model_state_dict=data, optim_state_dict=None)
 
 
 class OptimizerState(Stateful):
@@ -103,20 +116,22 @@ def load(actor: Any) -> dict[str, Any] | None:
     model_dir = checkpoint_dir / "model"
     optimizer_dir = checkpoint_dir / "optimizer"
     lr_scheduler_dir = checkpoint_dir / "lr_scheduler"
+    lora_dir = checkpoint_dir / "adapter"
+
+    lora_only = lora_dir.exists() and is_lora_model(actor.model)
+    model_dir = lora_dir if lora_only else model_dir
 
     if not model_dir.exists():
-        logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.")
+        logger.info(f"[FSDP] No model checkpoint found at {model_dir} or {lora_dir}; skipping load.")
         return None
 
-    # Load model weights (always)
-    model_state = ModelState(actor.model)
+    model_state = ModelState(actor.model, lora_only=lora_only)
     state_dict = {"model_state": model_state}
-
     try:
         dcp.load(state_dict=state_dict, checkpoint_id=str(model_dir))
-        logger.info(f"[FSDP] Loaded model from {model_dir}")
+        logger.info(f"[FSDP] Loaded {'LoRA adapter' if lora_only else 'model'} from {model_dir}")
     except Exception as e:
-        logger.error(f"[FSDP] Failed to load model from {model_dir}: {e}")
+        logger.error(f"[FSDP] Failed to load {'LoRA adapter' if lora_only else 'model'} from {model_dir}: {e}")
         return None
 
     # Load optimizer state (optional)
@@ -210,9 +225,19 @@ def save(actor: Any, iteration: int) -> None:
     dist.barrier()
 
     # Save model weights
-    model_state = ModelState(actor.model)
+    lora_only = is_lora_model(actor.model)
+    if lora_only:
+        save_dir = checkpoint_dir / "adapter"
+        if dist.get_rank() == 0:
+            save_dir.mkdir(parents=True, exist_ok=True)
+        dist.barrier()
+    else:
+        save_dir = model_dir
+
+    model_state = ModelState(actor.model, lora_only=lora_only)
     state_dict = {"model_state": model_state}
-    dcp.save(state_dict, checkpoint_id=str(model_dir))
+    dcp.save(state_dict, checkpoint_id=str(save_dir))
+    logger.info(f"[FSDP] Saved {'LoRA adapter' if lora_only else 'model'} to {save_dir}")
 
     # Save optimizer state
     if hasattr(actor, "optimizer") and actor.optimizer is not None:
diff --git a/slime/backends/fsdp_utils/lora_utils.py b/slime/backends/fsdp_utils/lora_utils.py
new file mode 100644
index 000000000..ce49c75b4
--- /dev/null
+++ b/slime/backends/fsdp_utils/lora_utils.py
@@ -0,0 +1,76 @@
+import logging
+import os
+import shutil
+from pathlib import Path
+
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
+
+try:
+    from peft import LoraConfig, PeftModel, TaskType, get_peft_model
+except ImportError as err:
+    raise ImportError("peft library required for LoRA. Install with: pip install peft") from err
+
+logger = logging.getLogger(__name__)
+
+LORA_READY_MARKER = ".lora_ready"
+LORA_ADAPTER_NAME = "slime_lora"
+LORA_SUBDIR = "tmp_lora"
+
+
+def apply_lora_to_model(model: nn.Module, args) -> nn.Module:
+    if args.lora_adapter_path:
+        logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}")
+        model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True)
+        peft_config = model.peft_config["default"]
+        if isinstance(peft_config.task_type, str):
+            peft_config.task_type = TaskType.CAUSAL_LM
+        model.print_trainable_parameters()
+        return model
+
+    lora_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=args.lora_rank,
+        lora_alpha=args.lora_alpha,
+        target_modules=args.target_modules,
+        bias="none",
+    )
+
+    model = get_peft_model(model, lora_config)  # autocast_adapter_dtype=False)
+    model.print_trainable_parameters()
+    logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}")
+    return model
+
+
+def is_lora_model(module: nn.Module) -> bool:
+    unwrapped = getattr(module, "_fsdp_wrapped_module", module)
+    return hasattr(unwrapped, "peft_config")
+
+
+def save_lora_to_disk(module: nn.Module, save_dir: str) -> str:
+    """Save LoRA adapter weights to disk (only rank 0 writes the files)."""
+    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+    full_state_dict = get_model_state_dict(module, options=options)
+
+    state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name}
+
+    if dist.get_rank() == 0:
+        save_path = Path(save_dir)
+        save_path.mkdir(parents=True, exist_ok=True)
+
+        module.save_pretrained(str(save_path), state_dict=state_dict)
+
+        # TODO: check if file lock is needed or better way to do it
+        os.sync()
+
+        logger.info(f"Saved LoRA adapter to {save_path}")
+    return save_dir
+
+
+def delete_lora_from_disk(save_dir: str) -> None:
+    """Delete LoRA adapter files from disk."""
+    save_path = Path(save_dir)
+    if save_path.exists():
+        shutil.rmtree(save_path)
+        logger.info(f"Deleted LoRA adapter from {save_path}")
diff --git a/slime/backends/fsdp_utils/update_weight_utils.py b/slime/backends/fsdp_utils/update_weight_utils.py
index 4c0ce5478..78d3ba8b9 100644
--- a/slime/backends/fsdp_utils/update_weight_utils.py
+++ b/slime/backends/fsdp_utils/update_weight_utils.py
@@ -1,5 +1,6 @@
 import abc
 import logging
+import os
 import socket
 from argparse import Namespace
 from collections.abc import Sequence
@@ -19,12 +20,12 @@
 from slime.utils.distributed_utils import init_process_group
 
-
 try:
     from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket  # type: ignore[import]
 except ImportError:
     from sglang.srt.model_executor.model_runner import FlattenedTensorBucket  # type: ignore[import]
 
+from .lora_utils import LORA_ADAPTER_NAME, LORA_SUBDIR, delete_lora_from_disk, is_lora_model, save_lora_to_disk
 
 logger = logging.getLogger(__name__)
 
@@ -34,6 +35,8 @@ def __init__(self, args: Namespace, model: torch.nn.Module) -> None:
         self.args = args
         self.model = model
         self.weight_version = 0
+        self._lora_loaded = False
+        self._base_sync_done = False
 
     @abc.abstractmethod
     def connect_rollout_engines(
@@ -45,31 +48,77 @@ def update_weights(self) -> None:
         self.weight_version += 1
-        bucket = []
-        bucket_size = 0
-        for name, param in self.model.state_dict().items():
-            param_size = param.numel() * param.element_size()
-            if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size:
-                self.wait_and_update_bucket_weights(bucket)
-                del bucket
-                bucket = []
-                bucket_size = 0
-
-            param = param.cuda()
-            if isinstance(param, DTensor):
-                # async version of param.full_tensor
-                param = param.redistribute(
-                    placements=[Replicate()] * param.device_mesh.ndim,
-                    async_op=True,
-                ).to_local()
-            bucket.append((name, param))
-            bucket_size += param_size
-
-        if bucket:
-            self.wait_and_update_bucket_weights(bucket)
-            del bucket
+
+        # Update base model if needed
+        # Level 1: only sync base once for LoRA models, then just LoRA
+        # Level 2: always sync base + LoRA
+        if not (is_lora_model(self.model) and self._base_sync_done and self.args.offload_rollout_level == 1):
+            bucket = []
+            bucket_size = 0
+            for name, param in self.model.state_dict().items():
+                if any(x in name for x in ["_flat_param", "lora_"]):
+                    continue
+                name = name.replace("base_model.model.", "").replace(".base_layer", "")
+                param_size = param.numel() * param.element_size()
+                if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size:
+                    self.wait_and_update_bucket_weights(bucket)
+                    del bucket
+                    bucket = []
+                    bucket_size = 0
+
+                param = param.cuda()
+                if isinstance(param, DTensor):
+                    # async version of param.full_tensor
+                    param = param.redistribute(
+                        placements=[Replicate()] * param.device_mesh.ndim,
+                        async_op=True,
+                    ).to_local()
+                bucket.append((name, param))
+                bucket_size += param_size
+
+            if bucket:
+                self.wait_and_update_bucket_weights(bucket)
+                del bucket
+
+            self._base_sync_done = True
+
+        # Update LoRA weights if needed
+        if is_lora_model(self.model):
+            self._update_lora_via_file()
+
+    def _update_lora_via_file(self) -> None:
+        """Push LoRA weights to rollout engines using disk files."""
+        self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR)
+        if dist.get_rank() == 0:
+            if os.path.exists(self._lora_save_dir):
+                delete_lora_from_disk(self._lora_save_dir)
+
+        dist.barrier()
+
+        save_lora_to_disk(self.model, self._lora_save_dir)
+
+        dist.barrier()
+
+        if dist.get_rank() == 0:
+            if self._lora_loaded:
+                refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines]
+                ray.get(refs)
+
+                refs = [engine.flush_cache.remote() for engine in self.rollout_engines]
+                ray.get(refs)
+
+            refs = [
+                engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, self._lora_save_dir)
+                for engine in self.rollout_engines
+            ]
+            ray.get(refs)
+
+            refs = [engine.flush_cache.remote() for engine in self.rollout_engines]
+            ray.get(refs)
+
+            self._lora_loaded = True
+
+        dist.barrier()
 
     def wait_and_update_bucket_weights(self, bucket):
         bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket]
diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py
index 421dd4e83..7782bbf4d 100644
--- a/slime/backends/sglang_utils/sglang_engine.py
+++ b/slime/backends/sglang_utils/sglang_engine.py
@@ -278,9 +278,15 @@ def get_weight_version(self):
         response.raise_for_status()
         return response.json()["weight_version"]
 
-    def release_memory_occupation(self):
+    def release_memory_occupation(self, tags: list[str] = None):
+        """
+        Available tags for multi-stage release: weights, kv_cache
+        """
         self.flush_cache()
-        return self._make_request("release_memory_occupation")
+        return self._make_request(
+            "release_memory_occupation",
+            {"tags": tags},
+        )
 
     def resume_memory_occupation(self, tags: list[str] = None):
         """
@@ -336,6 +342,18 @@ def update_weights_from_distributed(
             payload,
         )
 
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        return self._make_request(
+            "load_lora_adapter",
+            {"lora_name": lora_name, "lora_path": lora_path},
+        )
+
+    def unload_lora_adapter(self, lora_name: str):
+        return self._make_request(
+            "unload_lora_adapter",
+            {"lora_name": lora_name},
+        )
+
     def pause_generation(self):
         response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={})
         response.raise_for_status()
@@ -419,6 +437,10 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work
         kwargs["enable_return_routed_experts"] = True
     if args.fp16:
         kwargs["dtype"] = "float16"
+    if args.lora_rank > 0 or args.lora_adapter_path is not None:
+        kwargs["enable_lora"] = True
+        kwargs["max_lora_rank"] = args.lora_rank
+        kwargs["lora_target_modules"] = args.target_modules
 
     external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS]
diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py
index 9d3de1641..add5dfe09 100644
--- a/slime/ray/rollout.py
+++ b/slime/ray/rollout.py
@@ -125,8 +125,8 @@ def save(self, rollout_id):
     def load(self, rollout_id=None):
         self.data_source.load(rollout_id)
 
-    def offload(self):
-        return ray.get([engine.release_memory_occupation.remote() for engine in self.rollout_engines])
+    def offload(self, tags: list[str] = None):
+        return ray.get([engine.release_memory_occupation.remote(tags=tags) for engine in self.rollout_engines])
 
     def onload(self, tags: list[str] = None):
         return ray.get([engine.resume_memory_occupation.remote(tags=tags) for engine in self.rollout_engines])
diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py
index db9700481..5265c8595 100644
--- a/slime/rollout/sglang_rollout.py
+++ b/slime/rollout/sglang_rollout.py
@@ -12,6 +12,7 @@
 from packaging.version import parse
 from tqdm import tqdm
 
+from slime.backends.fsdp_utils.lora_utils import LORA_ADAPTER_NAME
 from slime.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput
 from slime.rollout.filter_hub.base_types import DynamicFilterOutput
 from slime.utils.async_utils import run
@@ -126,6 +127,10 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
         "return_logprob": True,
     }
 
+    # Use LoRA adapter when LoRA is enabled
+    if args.lora_rank > 0 or args.lora_adapter_path is not None:
+        payload["lora_path"] = LORA_ADAPTER_NAME
+
    if args.use_rollout_routing_replay:
        payload["return_routed_experts"] = True
diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py
index 964bf00d3..213b3ac6e 100644
--- a/slime/utils/arguments.py
+++ b/slime/utils/arguments.py
@@ -104,6 +104,15 @@ def add_cluster_arguments(parser):
             "This will always be true when --colocate is set."
         ),
     )
+    parser.add_argument(
+        "--offload-rollout-level",
+        type=int,
+        default=2,
+        help=(
+            "The offload level for rollout when offload-rollout is set. "
+            "1 means only offload kv cache, 2 means offload kv cache and weights."
+        ),
+    )
 
     reset_arg(parser, "--distributed-backend", type=str, default="nccl")
     reset_arg(parser, "--distributed-timeout-minutes", type=int, default=10)
@@ -1417,6 +1426,27 @@ def slime_validate_args(args):
     if args.save_interval is not None:
         assert args.save is not None, "'--save' is required when save_interval is set."
 
+    if args.lora_rank > 0:
+        # assert args.save is not None, "'--save' is required when LoRA is enabled."
+        assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled."
+
+        if args.target_modules == "all-linear":
+            modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+        elif "," in args.target_modules:
+            modules = [m.strip() for m in args.target_modules.split(",")]
+        else:
+            modules = [args.target_modules]
+
+        if args.exclude_modules:
+            exclude_set = (
+                set(m.strip() for m in args.exclude_modules.split(","))
+                if "," in args.exclude_modules
+                else {args.exclude_modules}
+            )
+            modules = [m for m in modules if m not in exclude_set]
+
+        args.target_modules = modules
+
     assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set"
 
     if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]:
diff --git a/train.py b/train.py
index c7900f43e..85a33f59a 100644
--- a/train.py
+++ b/train.py
@@ -56,7 +56,7 @@ def offload_train():
         actor_model.clear_memory()
 
     def onload_rollout():
-        if args.offload_rollout:
+        if args.offload_rollout and args.offload_rollout_level == 2:
            ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
 
    # train loop.
@@ -68,7 +68,16 @@ def onload_rollout():
         rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
 
         if args.offload_rollout:
-            ray.get(rollout_manager.offload.remote())
+            # level 1: offload kv cache only, level 2: offload weights + kv cache
+            ray.get(
+                rollout_manager.offload.remote(
+                    tags=(
+                        [GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH]
+                        if args.offload_rollout_level == 1
+                        else [GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_CUDA_GRAPH]
+                    )
+                )
+            )
 
         if args.use_critic:
             critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref)
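
Usage note (not part of the patch): the core pattern introduced in actor.py is to wrap the Hugging Face model with a PEFT LoRA adapter and then build the optimizer over the trainable adapter parameters only. A minimal standalone sketch of that pattern follows; the model name and hyperparameters are illustrative placeholders, and only torch, peft, and transformers are assumed.

# Standalone sketch (illustrative, not from the patch): apply LoRA with peft,
# then give AdamW only the parameters that require grad, mirroring
# apply_lora_to_model() and the optimizer change in actor.py.
import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # placeholder model
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,  # corresponds to lora_rank
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # only lora_A / lora_B weights are trainable

# Frozen base weights never reach the optimizer, so its state stays small.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)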