138 changes: 138 additions & 0 deletions llmfoundry/command_utils/eval.py
@@ -566,3 +566,141 @@ def eval_from_yaml(
yaml_cfg = om.merge(yaml_cfg, cli_cfg)
assert isinstance(yaml_cfg, DictConfig)
return evaluate(yaml_cfg)


def convert_peft_adapter_format(model_dir: str) -> None:
"""Convert PEFT adapter from safetensors to bin format to avoid device metadata issues.

This function performs three operations:
1. Converts the adapter weights from safetensors to PyTorch .bin format
2. Renames the original safetensors file to .safetensors.bak
3. Updates the adapter_config.json to reference .bin files instead of .safetensors

Args:
model_dir: Full path to the model directory containing PEFT adapter files.
This should be the directory containing:
- adapter_config.json
- adapter_model.safetensors
Example: '/model-checkpoints/llama3-1b-lora-20250420_180800'

Returns:
None

Side Effects:
- Creates adapter_model.bin in model_dir
- Renames adapter_model.safetensors to adapter_model.safetensors.bak
- Modifies adapter_config.json to reference .bin files
"""
    import json
    import os

    import torch

# Paths for the adapter files
adapter_path = os.path.join(model_dir, "adapter_model.safetensors")
bin_adapter_path = os.path.join(model_dir, "adapter_model.bin")
config_path = os.path.join(model_dir, "adapter_config.json")

try:
# Load and convert if needed
if os.path.exists(adapter_path) and not os.path.exists(bin_adapter_path):
# Load safetensors adapter with explicit CPU device
from safetensors.torch import load_file
weights = load_file(adapter_path, device="cpu")

# Save as PyTorch bin format
torch.save(weights, bin_adapter_path)
print(f"Converted adapter to .bin format: {bin_adapter_path}")

# Rename/move safetensors file to force bin usage
if os.path.exists(adapter_path):
backup_path = os.path.join(model_dir, "adapter_model.safetensors.bak")
os.rename(adapter_path, backup_path)
print(f"Moved safetensors file to {backup_path} to force bin usage")

# Update config to reference .bin file
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config = json.load(f)

# Update config to use bin file
weight_map = config.get("weight_map", {})
            for key, value in weight_map.items():
                if "safetensors" in value:
                    weight_map[key] = value.replace("safetensors", "bin")

# Also update model_type if needed
if "safetensors" in config.get("model_type", ""):
config["model_type"] = config["model_type"].replace("safetensors", "bin")

with open(config_path, 'w') as f:
json.dump(config, f, indent=2)

print(f"Updated adapter config to use .bin format")
except Exception as e:
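        # Best-effort: swallow the error so evaluation can still be attempted
        # with whatever adapter files remain on disk.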
print(f"Failed to convert adapter format: {e}")

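# A quick round-trip check could confirm the conversion preserves tensor
# values. This is a sketch for review purposes, not part of the PR:
# the fixture weights and the `tmp_dir` argument are hypothetical.
def _smoke_test_convert_peft_adapter_format(tmp_dir: str) -> None:
    import os

    import torch
    from safetensors.torch import save_file

    # Write a tiny fake adapter in safetensors format.
    weights = {"lora_A.weight": torch.randn(4, 4)}
    save_file(weights, os.path.join(tmp_dir, "adapter_model.safetensors"))

    convert_peft_adapter_format(tmp_dir)

    # The .bin file should exist and match the original tensors, and the
    # safetensors file should have been renamed to .bak.
    restored = torch.load(os.path.join(tmp_dir, "adapter_model.bin"))
    assert torch.equal(restored["lora_A.weight"], weights["lora_A.weight"])
    assert os.path.exists(os.path.join(tmp_dir, "adapter_model.safetensors.bak"))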

def restore_safetensors_after_eval(model_dir: str) -> None:
"""Restore safetensor files to their original state after evaluation.

This function reverses the changes made by convert_peft_adapter_format():
    1. Restores the original adapter_model.safetensors from the .bak file if it exists
2. Updates the adapter_config.json to reference .safetensors again
3. Keeps the .bin file in place for potential future use

Args:
model_dir: Full path to the model directory containing PEFT adapter files.
This should be the directory containing:
- adapter_config.json
- adapter_model.bin
- adapter_model.safetensors.bak (created by convert_peft_adapter_format)
Example: '/model-checkpoints/llama3-1b-lora-20250420_180800'

Returns:
None

Side Effects:
- Restores adapter_model.safetensors from the .bak file if it exists
- Modifies adapter_config.json to reference .safetensors files
- Keeps adapter_model.bin for potential future use
"""
    import json
    import os

# Paths for the adapter files
backup_path = os.path.join(model_dir, "adapter_model.safetensors.bak")
adapter_path = os.path.join(model_dir, "adapter_model.safetensors")
config_path = os.path.join(model_dir, "adapter_config.json")

# Only restore if backup exists
if os.path.exists(backup_path):
if os.path.exists(adapter_path):
print(f"Safetensors file already exists at {adapter_path}, skipping restore")
else:
os.rename(backup_path, adapter_path)
print(f"Restored safetensors file from backup")

# Update config only if needed
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config = json.load(f)

# Check if config needs updating
needs_update = False
weight_map = config.get("weight_map", {})

            for key, value in weight_map.items():
                if "bin" in value:
                    weight_map[key] = value.replace("bin", "safetensors")
                    needs_update = True

if "bin" in config.get("model_type", ""):
config["model_type"] = config["model_type"].replace("bin", "safetensors")
needs_update = True

if needs_update:
with open(config_path, 'w') as f:
json.dump(config, f, indent=2)
print(f"Updated adapter config to use safetensors format")
else:
print(f"No backup found at {backup_path}, nothing to restore")