This document is a detailed reference for the LEMA (Layer-wise Efficient Memory Abstraction) library API.
The primary entry point for the framework. It orchestrates memory management, adapters, and LoRA parameters.
Initializes the model using a LemaConfig object.
Returns a LemaTrainer instance pre-configured with this model's components and memory manager.
Pre-initializes all LoRA adapters. Must be called before get_trainable_parameters() for new models.
Returns a list of all trainable parameters (LoRA weights) managed by the model.
Saves the configuration and LoRA adapter weights.
Loads a LEMA model from a directory containing lema_config.json and adapter_model.bin.
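As a sketch of the on-disk checkpoint layout described above (a hypothetical illustration, assuming the configuration is stored as plain JSON; the field values below are made up for the example):

```python
import json
import tempfile
from pathlib import Path

# Sketch: write the two files a LEMA checkpoint directory is described
# as containing, then read the config back. This mimics the layout only;
# it does not use the LEMA library itself.
with tempfile.TemporaryDirectory() as d:
    out = Path(d) / "checkpoint"
    out.mkdir()
    # lema_config.json holds the LemaConfig fields (illustrative subset).
    (out / "lema_config.json").write_text(json.dumps({
        "model_name_or_path": "my-base-model",  # hypothetical value
        "lora_rank": 16,
        "lora_alpha": 32,
    }))
    # adapter_model.bin would hold the LoRA adapter weights (empty stub here).
    (out / "adapter_model.bin").write_bytes(b"")

    cfg = json.loads((out / "lema_config.json").read_text())
    print(cfg["lora_rank"])  # → 16
```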
Configuration dataclass for LEMA.
| Parameter | Type | Default | Description |
|---|---|---|---|
| model_name_or_path | str | Required | HuggingFace ID or path to model directory. |
| model_type | str | None | llama or gpt2. Auto-detected if None. |
| gbi_path | str | None | Path to the .safetensors file. |
| device | str | "cuda" | Execution device. |
| strategy | MemoryStrategy | STREAMING | STREAMING or RESIDENT. |
| save_steps | int | 500 | Steps between automatic checkpoints. |
| output_dir | str | "output" | Directory for automatic checkpoints. |
| lora_rank | int | 16 | LoRA rank (r). |
| lora_alpha | int | 32 | LoRA alpha. |
| learning_rate | float | 1e-4 | Learning rate. |
| gradient_checkpointing | bool | False | Enable to save activation VRAM. |
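The table above can be summarized as a dataclass sketch (a minimal approximation, not the library's actual definition; MemoryStrategy is stubbed as a plain Enum here):

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

class MemoryStrategy(Enum):
    STREAMING = auto()
    RESIDENT = auto()

@dataclass
class LemaConfig:
    model_name_or_path: str                  # required: HF ID or local path
    model_type: Optional[str] = None         # "llama" or "gpt2"; auto-detected if None
    gbi_path: Optional[str] = None           # path to the .safetensors file
    device: str = "cuda"
    strategy: MemoryStrategy = MemoryStrategy.STREAMING
    save_steps: int = 500                    # steps between automatic checkpoints
    output_dir: str = "output"
    lora_rank: int = 16
    lora_alpha: int = 32
    learning_rate: float = 1e-4
    gradient_checkpointing: bool = False     # trade recompute for activation VRAM

cfg = LemaConfig(model_name_or_path="my-model")
print(cfg.lora_rank, cfg.strategy.name)  # → 16 STREAMING
```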
Orchestrates the training loop with layer-swapping logic.
Low-level constructor. Preferred usage is via LemaModel.get_trainer().
Executes one forward and backward pass. Tracks global_step and triggers auto-checkpointing.
- Returns: a tuple (logits, loss_value).
Saves the model state, configuration, and optimizer state.
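The auto-checkpoint trigger described for the training step can be sketched in isolation (a simplified, hypothetical helper rather than LEMA's trainer; the real loop also swaps layers and runs the forward/backward pass):

```python
class CheckpointClock:
    """Minimal sketch of the trainer's auto-checkpoint bookkeeping.

    Hypothetical helper, not part of the LEMA API: it only tracks
    global_step and records when a checkpoint would be written.
    """

    def __init__(self, save_steps: int = 500):
        self.save_steps = save_steps
        self.global_step = 0
        self.saved_at: list = []

    def step(self) -> None:
        # One optimizer step; checkpoint every save_steps steps.
        self.global_step += 1
        if self.global_step % self.save_steps == 0:
            self.saved_at.append(self.global_step)

clock = CheckpointClock(save_steps=500)
for _ in range(1200):
    clock.step()
print(clock.saved_at)  # → [500, 1000]
```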