13 changes: 8 additions & 5 deletions common/common.cpp
@@ -1344,12 +1344,15 @@ std::string get_model_endpoint() {
 }
 
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
-    llama_clear_adapter_lora(ctx);
-    for (auto & la : lora) {
-        if (la.scale != 0.0f) {
-            llama_set_adapter_lora(ctx, la.ptr, la.scale);
-        }
+    std::vector<llama_adapter_lora*> loras;
+    std::vector<float> scales;
+
+    for (auto & la: lora) {
+        loras.push_back(la.ptr);
+        scales.push_back(la.scale);
     }
+
+    llama_put_adapter_loras(ctx, loras.size(), loras.data(), scales.data());
 }
 
 struct llama_model_params common_model_params_to_llama(common_params & params) {
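For reference, a minimal caller-side sketch of the updated helper (not part of this PR; it assumes common_adapter_lora_info exposes path/scale/ptr as declared in common.h, that model and ctx already exist, and that the adapter path is a placeholder):

// Build the adapter list once, then re-apply it as often as needed.
// common_set_adapter_lora now forwards to llama_put_adapter_loras, which
// returns early when the context already has the same adapters and scales.
std::vector<common_adapter_lora_info> lora;

common_adapter_lora_info info;
info.path  = "adapter_a.gguf";  // placeholder path
info.scale = 1.0f;
info.ptr   = llama_adapter_lora_init(model, info.path.c_str());
lora.push_back(info);

common_set_adapter_lora(ctx, lora);  // applies the adapter
common_set_adapter_lora(ctx, lora);  // cheap: the context detects that nothing changed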
3 changes: 3 additions & 0 deletions include/llama.h
@@ -672,6 +672,9 @@ extern "C" {
     // Remove all LoRA adapters from given context
     LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx);
 
+    // Set LoRA adapters on the context. Only modifies the context if the requested adapters and scales differ from those currently applied.
+    LLAMA_API void llama_put_adapter_loras(struct llama_context * ctx, size_t num_adapters, struct llama_adapter_lora ** adapters, float * scales);
+
     // Apply a loaded control vector to a llama_context, or if data is NULL, clear
     // the currently loaded vector.
     // n_embd should be the size of a single layer's control, and data should point
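A minimal usage sketch of the new public API (illustrative only; assumes the adapters were loaded with llama_adapter_lora_init, that model and ctx exist, and that the .gguf paths are placeholders):

llama_adapter_lora * lora_a = llama_adapter_lora_init(model, "adapter_a.gguf");
llama_adapter_lora * lora_b = llama_adapter_lora_init(model, "adapter_b.gguf");

llama_adapter_lora * adapters[] = { lora_a, lora_b };
float                scales[]   = { 1.0f, 0.5f };

// Replaces whatever LoRA adapters are currently set on the context.
llama_put_adapter_loras(ctx, 2, adapters, scales);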
38 changes: 38 additions & 0 deletions src/llama-context.cpp
@@ -1093,6 +1093,40 @@ bool llama_context::rm_adapter_lora(
     return false;
 }
 
+void llama_context::put_adapter_loras(size_t num_adapters, llama_adapter_lora ** adapters, float * scales) {
+    LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
+
+    if (are_adapter_loras_same(num_adapters, adapters, scales)) {
+        return;
+    }
+
+    clear_adapter_lora();
+
+    for (size_t i = 0; i < num_adapters; i ++) {
+        if (scales[i] != 0.0f) {
+            set_adapter_lora(adapters[i], scales[i]);
+        }
+    }
+}
+
+bool llama_context::are_adapter_loras_same(size_t num_adapters, llama_adapter_lora ** adapters, float * scales) {
+    LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
+
+    if (num_adapters != loras.size()) {
+        return false;
+    }
+
+    for (size_t i = 0; i < num_adapters; i ++) {
+        auto it = loras.find(adapters[i]);
+
+        if (it == loras.end() || it->second != scales[i]) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 void llama_context::clear_adapter_lora() {
     LLAMA_LOG_DEBUG("%s: call\n", __func__);
 
@@ -3243,6 +3277,10 @@ void llama_clear_adapter_lora(llama_context * ctx) {
     ctx->clear_adapter_lora();
 }
 
+void llama_put_adapter_loras(llama_context * ctx, size_t num_adapters, llama_adapter_lora ** adapters, float * scales) {
+    ctx->put_adapter_loras(num_adapters, adapters, scales);
+}
+
 int32_t llama_apply_adapter_cvec(
         llama_context * ctx,
         const float * data,
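Two behavioural notes that follow directly from the implementation above, continuing the hedged sketch shown after the llama.h declaration (variable names are illustrative):

// 1) Repeating a call with identical adapters and scales is a no-op:
//    are_adapter_loras_same() matches, so the context is not cleared or re-set.
llama_put_adapter_loras(ctx, 2, adapters, scales);
llama_put_adapter_loras(ctx, 2, adapters, scales);  // early return, no work done

// 2) A zero scale disables an adapter: it is skipped when applying,
//    mirroring the old behaviour of common_set_adapter_lora.
float muted[] = { 1.0f, 0.0f };
llama_put_adapter_loras(ctx, 2, adapters, muted);   // only adapters[0] is set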
4 changes: 4 additions & 0 deletions src/llama-context.h
@@ -111,6 +111,10 @@ struct llama_context {
     bool rm_adapter_lora(
             llama_adapter_lora * adapter);
 
+    void put_adapter_loras(size_t num_adapters, llama_adapter_lora ** adapters, float * scales);
+
+    bool are_adapter_loras_same(size_t num_adapters, llama_adapter_lora ** adapters, float * scales);
+
     void clear_adapter_lora();
 
     bool apply_adapter_cvec(