11 changes: 9 additions & 2 deletions slime/backends/fsdp_utils/actor.py
@@ -29,6 +29,7 @@
from ...utils.profile_utils import TrainProfiler
from . import checkpoint
from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences
from .lora_utils import apply_lora_to_model, is_lora_model
from .lr_scheduler import get_lr_scheduler
from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor

@@ -95,6 +96,9 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
attn_implementation=self.args.attn_implementation,
)

if self.args.lora_rank > 0 or self.args.lora_adapter_path:
model = apply_lora_to_model(model, self.args)

model.train()

full_state = model.state_dict()
@@ -108,11 +112,14 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
self.model = model

if args.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
# Avoid "does not require grad" error
gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {}
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)

if args.optimizer == "adam":
trainable_params = [p for p in self.model.parameters() if p.requires_grad]
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
trainable_params,
lr=args.lr,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps,
7 changes: 7 additions & 0 deletions slime/backends/fsdp_utils/arguments.py
@@ -60,6 +60,13 @@ class FSDPArgs:
# YAML bookkeeping
config: str | None = None

# LoRA configuration
lora_rank: int = 0
lora_alpha: int = 16
target_modules: str = "all-linear"
exclude_modules: str | None = None
lora_adapter_path: str | None = None


def parse_fsdp_cli(extra_args_provider=None):
parser = argparse.ArgumentParser("FSDP SFT Training (slime)")
47 changes: 36 additions & 11 deletions slime/backends/fsdp_utils/checkpoint.py
@@ -12,21 +12,34 @@
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

from slime.backends.fsdp_utils.lora_utils import is_lora_model

logger = logging.getLogger(__name__)


class ModelState(Stateful):
"""Wrapper for model state only."""

def __init__(self, model):
def __init__(self, model, lora_only: bool = False):
self.model = model
self.lora_only = lora_only
self._key = "adapter" if lora_only else "model"

def state_dict(self):
model_state_dict, _ = get_state_dict(self.model, optimizers=[])
return {"model": model_state_dict}
if self.lora_only:
model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k}
return {self._key: model_state_dict}

def load_state_dict(self, state_dict):
set_state_dict(self.model, optimizers=[], model_state_dict=state_dict["model"], optim_state_dict=None)
data = state_dict[self._key]

if self.lora_only:
full_state_dict, _ = get_state_dict(self.model, optimizers=[])
full_state_dict.update(data)
set_state_dict(self.model, optimizers=[], model_state_dict=full_state_dict, optim_state_dict=None)
else:
set_state_dict(self.model, optimizers=[], model_state_dict=data, optim_state_dict=None)


class OptimizerState(Stateful):
@@ -103,20 +116,22 @@ def load(actor: Any) -> dict[str, Any] | None:
model_dir = checkpoint_dir / "model"
optimizer_dir = checkpoint_dir / "optimizer"
lr_scheduler_dir = checkpoint_dir / "lr_scheduler"
lora_dir = checkpoint_dir / "adapter"

lora_only = lora_dir.exists() and is_lora_model(actor.model)
model_dir = lora_dir if lora_only else model_dir

if not model_dir.exists():
logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.")
logger.info(f"[FSDP] No model checkpoint found at {model_dir} or {lora_dir}; skipping load.")
return None

# Load model weights (always)
model_state = ModelState(actor.model)
model_state = ModelState(actor.model, lora_only=lora_only)
state_dict = {"model_state": model_state}

try:
dcp.load(state_dict=state_dict, checkpoint_id=str(model_dir))
logger.info(f"[FSDP] Loaded model from {model_dir}")
logger.info(f"[FSDP] Loaded {'LoRA adapter' if lora_only else 'model'} from {model_dir}")
except Exception as e:
logger.error(f"[FSDP] Failed to load model from {model_dir}: {e}")
logger.error(f"[FSDP] Failed to load {'LoRA adapter' if lora_only else 'model'} from {model_dir}: {e}")
return None

# Load optimizer state (optional)
@@ -210,9 +225,19 @@ def save(actor: Any, iteration: int) -> None:
dist.barrier()

# Save model weights
model_state = ModelState(actor.model)
lora_only = is_lora_model(actor.model)
if lora_only:
save_dir = checkpoint_dir / "adapter"
if dist.get_rank() == 0:
save_dir.mkdir(parents=True, exist_ok=True)
dist.barrier()
else:
save_dir = model_dir

model_state = ModelState(actor.model, lora_only=lora_only)
state_dict = {"model_state": model_state}
dcp.save(state_dict, checkpoint_id=str(model_dir))
dcp.save(state_dict, checkpoint_id=str(save_dir))
logger.info(f"[FSDP] Saved {'LoRA adapter' if lora_only else 'model'} to {save_dir}")

# Save optimizer state
if hasattr(actor, "optimizer") and actor.optimizer is not None:
76 changes: 76 additions & 0 deletions slime/backends/fsdp_utils/lora_utils.py
@@ -0,0 +1,76 @@
import logging
import os
import shutil
from pathlib import Path

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

try:
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
except ImportError as err:
raise ImportError("peft library required for LoRA. Install with: pip install peft") from err

logger = logging.getLogger(__name__)

LORA_READY_MARKER = ".lora_ready"
LORA_ADAPTER_NAME = "slime_lora"
LORA_SUBDIR = "tmp_lora"


def apply_lora_to_model(model: nn.Module, args) -> nn.Module:
if args.lora_adapter_path:
logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}")
model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True)
peft_config = model.peft_config["default"]
if isinstance(peft_config.task_type, str):
peft_config.task_type = TaskType.CAUSAL_LM
model.print_trainable_parameters()
return model

lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=args.lora_rank,
lora_alpha=args.lora_alpha,
target_modules=args.target_modules,
bias="none",
)

model = get_peft_model(model, lora_config) # autocast_adapter_dtype=False)
model.print_trainable_parameters()
logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}")
return model


def is_lora_model(module: nn.Module) -> bool:
unwrapped = getattr(module, "_fsdp_wrapped_module", module)
return hasattr(unwrapped, "peft_config")


def save_lora_to_disk(module: nn.Module, save_dir: str) -> str:
"""Save LoRA adapter to disk with file lock mechanism."""
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
full_state_dict = get_model_state_dict(module, options=options)

state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name}

if dist.get_rank() == 0:
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)

module.save_pretrained(str(save_path), state_dict=state_dict)

# TODO: check if file lock is needed or better way to do it
os.sync()

logger.info(f"Saved LoRA adapter to {save_path}")
return save_dir


def delete_lora_from_disk(save_dir: str) -> None:
"""Delete LoRA adapter files from disk."""
save_path = Path(save_dir)
if save_path.exists():
shutil.rmtree(save_path)
logger.info(f"Deleted LoRA adapter from {save_path}")
97 changes: 73 additions & 24 deletions slime/backends/fsdp_utils/update_weight_utils.py
@@ -1,5 +1,6 @@
import abc
import logging
import os
import socket
from argparse import Namespace
from collections.abc import Sequence
@@ -19,12 +20,12 @@

from slime.utils.distributed_utils import init_process_group


try:
from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import]
except ImportError:
from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import]

from .lora_utils import LORA_ADAPTER_NAME, LORA_SUBDIR, delete_lora_from_disk, is_lora_model, save_lora_to_disk

logger = logging.getLogger(__name__)

Expand All @@ -34,6 +35,8 @@ def __init__(self, args: Namespace, model: torch.nn.Module) -> None:
self.args = args
self.model = model
self.weight_version = 0
self._lora_loaded = False
self._base_sync_done = False

@abc.abstractmethod
def connect_rollout_engines(
@@ -45,31 +48,77 @@ def connect_rollout_engines(

def update_weights(self) -> None:
self.weight_version += 1
bucket = []
bucket_size = 0
for name, param in self.model.state_dict().items():
param_size = param.numel() * param.element_size()
if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size:
self.wait_and_update_bucket_weights(bucket)
del bucket
bucket = []
bucket_size = 0

param = param.cuda()
if isinstance(param, DTensor):
# async version of param.full_tensor
param = param.redistribute(
placements=[Replicate()] * param.device_mesh.ndim,
async_op=True,
).to_local()
bucket.append((name, param))
bucket_size += param_size

if bucket:
self.wait_and_update_bucket_weights(bucket)
del bucket

# Update base model if needed
# Level 1: only sync base once for LoRA models, then just LoRA
# Level 2: always sync base + LoRA
if not (is_lora_model(self.model) and self._base_sync_done and self.args.offload_rollout_level == 1):
bucket = []
bucket_size = 0
for name, param in self.model.state_dict().items():
if any(x in name for x in ["_flat_param", "lora_"]):
continue
name = name.replace("base_model.model.", "").replace(".base_layer", "")
param_size = param.numel() * param.element_size()
if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size:
self.wait_and_update_bucket_weights(bucket)
del bucket
bucket = []
bucket_size = 0

param = param.cuda()
if isinstance(param, DTensor):
# async version of param.full_tensor
param = param.redistribute(
placements=[Replicate()] * param.device_mesh.ndim,
async_op=True,
).to_local()
bucket.append((name, param))
bucket_size += param_size

if bucket:
self.wait_and_update_bucket_weights(bucket)
del bucket

self._base_sync_done = True

# Update lora weights if needed
if is_lora_model(self.model):
self._update_lora_via_file()

def _update_lora_via_file(self) -> None:
"""Push LoRA weights to rollout engines using disk files."""
self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR)
if dist.get_rank() == 0:
if os.path.exists(self._lora_save_dir):
delete_lora_from_disk(self._lora_save_dir)

dist.barrier()

save_lora_to_disk(self.model, self._lora_save_dir)

dist.barrier()

if dist.get_rank() == 0:
if self._lora_loaded:
refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines]
ray.get(refs)

refs = [engine.flush_cache.remote() for engine in self.rollout_engines]
ray.get(refs)

refs = [
Collaborator: I am curious how long it takes for SGLang to read LoRA weights from disk. Is it possible to pass them through NCCL? I am not sure about the file size.

Collaborator: And we cannot assume that in distributed training there is a shared file system on every node, so the read-from-disk approach may not work here.

engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, self._lora_save_dir)
for engine in self.rollout_engines
]
ray.get(refs)

refs = [engine.flush_cache.remote() for engine in self.rollout_engines]
ray.get(refs)

self._lora_loaded = True

dist.barrier()

def wait_and_update_bucket_weights(self, bucket):
bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket]
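
On the question raised in the review comments above (passing LoRA weights through NCCL rather than disk), a rough sketch of what that could look like, reusing the bucketing pattern from `update_weights()`. This is an illustration only: it assumes the rollout engines could accept named adapter tensors over the distributed path (today's path-based `load_lora_adapter` cannot), the `torch.distributed.tensor` import location is a torch-version assumption, and `iter_lora_buckets` / `max_bucket_bytes` are hypothetical names (`max_bucket_bytes` would correspond to `args.update_weight_buffer_size`).

```python
# Sketch only, not part of the PR: collect LoRA-only weight buckets so they
# could be pushed through the same NCCL path as the base weights instead of
# being written to a shared filesystem. Assumes the receiving side can consume
# named adapter tensors, which the current path-based load_lora_adapter cannot.
from torch.distributed.tensor import DTensor, Replicate  # torch >= 2.4 layout assumed


def iter_lora_buckets(model, max_bucket_bytes: int):
    """Yield lists of (name, tensor) containing only LoRA parameters."""
    bucket, bucket_size = [], 0
    for name, param in model.state_dict().items():
        if "lora_" not in name:
            continue
        if isinstance(param, DTensor):
            # Gather the sharded adapter weight onto every rank (sync variant).
            param = param.redistribute(
                placements=[Replicate()] * param.device_mesh.ndim
            ).to_local()
        size = param.numel() * param.element_size()
        if bucket and bucket_size + size >= max_bucket_bytes:
            yield bucket
            bucket, bucket_size = [], 0
        bucket.append((name, param.cuda()))
        bucket_size += size
    if bucket:
        yield bucket


# Each yielded bucket could then flow through the existing
# wait_and_update_bucket_weights() machinery instead of save_lora_to_disk().
```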
26 changes: 24 additions & 2 deletions slime/backends/sglang_utils/sglang_engine.py
@@ -278,9 +278,15 @@ def get_weight_version(self):
response.raise_for_status()
return response.json()["weight_version"]

def release_memory_occupation(self):
def release_memory_occupation(self, tags: list[str] = None):
"""
Available tags for multi-stage resume: weights, kv_cache
"""
self.flush_cache()
return self._make_request("release_memory_occupation")
return self._make_request(
"release_memory_occupation",
{"tags": tags},
)

def resume_memory_occupation(self, tags: list[str] = None):
"""
@@ -336,6 +342,18 @@ def update_weights_from_distributed(
payload,
)

def load_lora_adapter(self, lora_name: str, lora_path: str):
return self._make_request(
"load_lora_adapter",
{"lora_name": lora_name, "lora_path": lora_path},
)

def unload_lora_adapter(self, lora_name: str):
return self._make_request(
"unload_lora_adapter",
{"lora_name": lora_name},
)

def pause_generation(self):
response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={})
response.raise_for_status()
@@ -419,6 +437,10 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work
kwargs["enable_return_routed_experts"] = True
if args.fp16:
kwargs["dtype"] = "float16"
if args.lora_rank > 0 or args.lora_adapter_path is not None:
kwargs["enable_lora"] = True
kwargs["max_lora_rank"] = args.lora_rank
kwargs["lora_target_modules"] = args.target_modules

external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS]

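
For orientation (not part of the diff): the per-engine call order that `_update_lora_via_file` in `update_weight_utils.py` drives through these new wrappers, written out as a small helper. The `engine` handle, the `already_loaded` flag, and the helper name are illustrative; the method names come from the diff above.

```python
# Sketch, not part of the PR: the unload -> flush -> load -> flush sequence that
# _update_lora_via_file() performs against each rollout engine. `engine` is
# assumed to be a Ray actor wrapping SglangEngine; adapter_dir is wherever
# save_lora_to_disk() wrote the adapter.
import ray

from slime.backends.fsdp_utils.lora_utils import LORA_ADAPTER_NAME


def hot_swap_adapter(engine, adapter_dir: str, already_loaded: bool) -> None:
    if already_loaded:
        # Drop the stale adapter before reloading under the same name.
        ray.get(engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME))
    ray.get(engine.flush_cache.remote())
    ray.get(engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, adapter_dir))
    # Flush again so subsequent requests are served with the fresh adapter.
    ray.get(engine.flush_cache.remote())
```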