26 changes: 23 additions & 3 deletions examples/speculative-eagle/speculative-eagle.cpp
@@ -41,10 +41,29 @@ static bool cb_get_hidden(struct ggml_tensor * tensor, bool ask, void * user_dat
return true;
}

int64_t start_time;

static bool cb_get_latency(struct ggml_tensor * tensor, bool ask, [[maybe_unused]] void * user_data) { //latency profiling callback function -ym-
if (ask) {
start_time = ggml_time_us();
return true;
}

int64_t end_time = ggml_time_us();
int64_t latency = end_time - start_time;
LOG_DBG("[[Latency for tensor]] '%s' (%s): %ld us ==> (%d)\n", tensor->name, ggml_op_name(tensor->op), latency, (int)ggml_backend_buffer_is_host(tensor->buffer));
ggml_tensor * src_tensor = tensor->src[0];
LOG_DBG("[[Latency for tensor]] [%d, %d, %d, %d]\n", (int)src_tensor->ne[0], (int)src_tensor->ne[1], (int)src_tensor->ne[2], (int)src_tensor->ne[3]);
LOG_DBG("[[Latency for tensor]] [%d, %d, %d, %d]\n", (int)tensor->ne[0], (int)tensor->ne[1], (int)tensor->ne[2], (int)tensor->ne[3]);


return true;
}
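
The callback above keeps its timing state in a single global start_time, which works only if the scheduler invokes the callback for each node first with ask == true (just before the node is evaluated) and then with ask == false (once the node's data is available). As an illustration only, not part of this PR, here is a minimal sketch that carries the timing state through the user_data pointer instead of a global; the struct and function names are hypothetical, and the callback signature is assumed to match the one used above:

struct latency_state {
    int64_t start_time = 0;
};

static bool cb_get_latency_v2(struct ggml_tensor * tensor, bool ask, void * user_data) {
    latency_state * st = (latency_state *) user_data;
    if (ask) {
        st->start_time = ggml_time_us(); // node is about to be evaluated
        return true;                     // yes, observe this node
    }
    // node data is now available; report elapsed time since the ask phase
    const int64_t latency = ggml_time_us() - st->start_time;
    LOG_DBG("[[Latency for tensor]] '%s' (%s): %lld us\n", tensor->name, ggml_op_name(tensor->op), (long long) latency);
    return true;
}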

struct seq_draft { // holds the state of each draft sequence (a branch of the tree) -ym-
bool active = false;
bool drafting = false;
bool skip = false;
bool active = false; // whether this sequence is active in the verification stage -ym-
bool drafting = false; // whether this sequence is active in the drafting stage -ym-
bool skip = false; // whether to skip this sequence during the drafting stage -ym-

int i_batch_dft = 0; // index of this sequence's last token in the draft model's batch -ym-
std::vector<int> i_batch_tgt; // indices of this sequence's tokens in the target model's batch -ym-
@@ -115,6 +134,7 @@ int main(int argc, char ** argv) {
}

params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
//params.cb_eval = cb_get_latency;
common_init_result llama_init_dft = common_init_from_params(params);

model_dft = llama_init_dft.model.get();
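
The registration itself is left commented out above. A minimal sketch of what enabling it could look like before common_init_from_params(params) is called, assuming common_params also exposes a cb_eval_user_data field alongside cb_eval (only cb_eval is visible in this diff, so the second field name is an assumption):

// Hypothetical: enable per-tensor latency profiling for the draft model.
params.cb_eval           = cb_get_latency;
params.cb_eval_user_data = nullptr; // cb_get_latency ignores user_data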
2 changes: 1 addition & 1 deletion src/llama-arch.cpp
@@ -1651,7 +1651,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_EMBD_FC, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_EMBD_FC, {LLM_TENSOR_LAYER_INPUT_EAGLE, GGML_OP_MUL_MAT}},
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -375,6 +375,7 @@ enum llm_tensor_layer {
LLM_TENSOR_LAYER_INPUT,
LLM_TENSOR_LAYER_REPEATING,
LLM_TENSOR_LAYER_OUTPUT,
LLM_TENSOR_LAYER_INPUT_EAGLE,
};

struct LLM_KV {
5 changes: 4 additions & 1 deletion src/llama-model.cpp
@@ -1697,7 +1697,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}

// sanity checks
if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) {
if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT || info.layer == LLM_TENSOR_LAYER_INPUT_EAGLE) {
if (tn.bid != -1) {
GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());
}
@@ -1719,6 +1719,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
case LLM_TENSOR_LAYER_REPEATING:
buft_list = pimpl->dev_layer.at(tn.bid).buft_list;
break;
case LLM_TENSOR_LAYER_INPUT_EAGLE:
buft_list = pimpl->dev_output.buft_list; // place the EAGLE input-layer tensors on the same device as the output layer
break;
default:
GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
}