2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/fattn-common.cuh
@@ -782,7 +782,7 @@ void launch_fattn(
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];

const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);

const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
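Read on its own, the new condition says: V counts as a view of K when it is a zero-offset view whose underlying tensor is either K itself or the same tensor that K is a view of (the latter covers the case where K is itself a view into the cache buffer). A minimal restatement as a standalone helper, purely for illustration - the function name is made up and does not exist in the PR:

```cpp
// Hypothetical helper mirroring the updated V_is_K_view check.
static bool fattn_v_is_k_view(const ggml_tensor * V, const ggml_tensor * K) {
    return V->view_src != nullptr                              // V is a view ...
        && V->view_offs == 0                                   // ... starting at offset 0 ...
        && (V->view_src == K || V->view_src == K->view_src);   // ... of K or of K's own source
}
```

Unlike the previous `V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data` check, this also matches a V tensor that was not created directly from K but shares K's underlying cache tensor, which appears to be what the new V-less KV-cache path in src/llama-graph.cpp produces.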
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/fattn.cu
@@ -247,7 +247,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
}

const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);

const int cc = ggml_cuda_info().devices[device].cc;

8 changes: 4 additions & 4 deletions src/llama-context.cpp
@@ -793,7 +793,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}

const uint32_t n_embd_out = model.hparams.get_n_embd_out();
const uint32_t n_embd_out = model.hparams.n_embd_out();
return embd + j*n_embd_out;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
@@ -1279,7 +1279,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);
const uint32_t n_embd_out = hparams.get_n_embd_out();
const uint32_t n_embd_out = hparams.n_embd_out();

GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
@@ -1688,7 +1688,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);
const uint32_t n_embd_out = hparams.get_n_embd_out();
const uint32_t n_embd_out = hparams.n_embd_out();
float * embd_out = embd + n_outputs_prev*n_embd_out;

if (n_outputs) {
@@ -1821,7 +1821,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba

const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd_out = hparams.get_n_embd_out();
const auto n_embd_out = hparams.n_embd_out();

bool has_logits = true;
bool has_embd = cparams.embeddings;
117 changes: 111 additions & 6 deletions src/llama-graph.cpp
@@ -407,6 +407,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
return res;
}

void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);

mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);

this->mctx = mctx;

bool res = true;

res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;

res &= self_kq_mask->ne[0] == mctx->get_n_kv();
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;

return res;
}

void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@@ -1596,11 +1617,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
v = ggml_transpose(ctx0, v);
}

// TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
if (v_mla) {
v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
}

// this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
if (k->type == GGML_TYPE_F32) {
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
@@ -1823,9 +1839,11 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla,
ggml_tensor * v_mla, // TODO: remove
float kq_scale,
int il) const {
GGML_ASSERT(v_mla == nullptr);

// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
@@ -1868,6 +1886,93 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}

static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
ggml_context * ctx0,
const llama_ubatch & ubatch,
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_context * mctx_cur) {

auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);

{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");

const auto n_kv = mctx_cur->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);

inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp->self_kq_mask);

inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}

return inp;
}

llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);

auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);

return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
}

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_k * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, v_cur);
ggml_build_forward_expand(gf, k_cur);

const auto * mctx_cur = inp->mctx;

// store to KV cache
{
const auto & k_idxs = inp->get_k_idxs();

ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
}

const auto & kq_mask = inp->get_kq_mask();

ggml_tensor * q = q_cur;
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);

ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);

if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
Comment on lines +1961 to +1967

Member Author:

We might need to add LLM_ARCH_DEEPSEEK2 here in case we suspect similar numerical issues with GLM 4.7 Flash - something to keep in mind. cc @jeffbolznv @JohannesGaessler

Contributor:

IIRC @ngxson tried this during our PR and it made no difference in his testing.

Member Author:

At that time, the wrong gating function was used, so we can't conclude based on this. Plus this is somewhat backend-specific - e.g. it's not a problem for Metal since we always accumulate in F32.
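For reference, the change being floated in this thread would amount to extending the architecture check a few lines above - a hypothetical sketch, not something this PR does:

```cpp
if (wo) {
    cur = build_lora_mm(wo, cur);
    if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_DEEPSEEK2) {
        // force F32 accumulation for the attention output projection on backends
        // whose default half-precision accumulators may not be accurate enough
        ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
    }
}
```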


if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}

return cur;
}

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
48 changes: 48 additions & 0 deletions src/llama-graph.h
@@ -317,6 +317,39 @@ class llm_graph_input_attn_kv : public llm_graph_input_i {
const llama_kv_cache_context * mctx;
};

// V-less input for the KV cache
// ref: https://github.com/ggml-org/llama.cpp/pull/19067
class llm_graph_input_attn_k : public llm_graph_input_i {
public:
llm_graph_input_attn_k(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_k() = default;

void set_input(const llama_ubatch * ubatch) override;

bool can_reuse(const llm_graph_params & params) override;

ggml_tensor * get_k_idxs() const { return self_k_idxs; }

ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]

ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

const llama_hparams hparams;
const llama_cparams cparams;

const llama_kv_cache_context * mctx;
};

class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_iswa(
@@ -833,6 +866,21 @@ struct llm_graph_context {
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
float kq_scale,
int il) const;

llm_graph_input_attn_k * build_attn_inp_k() const;

ggml_tensor * build_attn(
llm_graph_input_attn_k * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
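To sketch how this new V-less overload is meant to be used, a hypothetical fragment from a model's graph-build function is shown below; `Qcur`, `Kcur`, `Vcur`, `wv_b` and `kq_scale` are placeholder names for this illustration, not identifiers taken from the PR:

```cpp
// Hypothetical MLA-style attention block using the V-less KV cache path.
// Only the K cache is written; the V operand is recovered inside build_attn()
// as a zero-offset view of the K cache.
auto * inp_attn = build_attn_inp_k();   // built once per graph

// per-layer attention call
ggml_tensor * cur = build_attn(inp_attn,
        model.layers[il].wo, /*wo_b =*/ nullptr,
        Qcur, Kcur, Vcur,
        /*kq_b  =*/ nullptr,
        /*sinks =*/ nullptr,
        /*v_mla =*/ model.layers[il].wv_b,   // per-head decompression applied after the KQ*V product
        kq_scale, il);
```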
19 changes: 17 additions & 2 deletions src/llama-hparams.cpp
@@ -72,8 +72,8 @@ uint32_t llama_hparams::n_embd_inp() const {
return n_embd_inp;
}

uint32_t llama_hparams::get_n_embd_out() const {
return n_embd_out > 0 ? n_embd_out : n_embd;
uint32_t llama_hparams::n_embd_out() const {
return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
}

uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
@@ -175,6 +175,21 @@ bool llama_hparams::is_swa(uint32_t il) const {
GGML_ABORT("fatal error");
}

bool llama_hparams::is_mla() const {
assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) ||
(n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0));

return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0;
}

uint32_t llama_hparams::n_embd_head_k_mla() const {
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
}

uint32_t llama_hparams::n_embd_head_v_mla() const {
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v;
}

bool llama_hparams::has_kv(uint32_t il) const {
if (n_layer_kv_from_start >= 0) {
if (il < (uint32_t) n_layer_kv_from_start) {
14 changes: 10 additions & 4 deletions src/llama-hparams.h
@@ -53,8 +53,8 @@ struct llama_hparams {
uint32_t n_rel_attn_bkts = 0;

// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
uint32_t n_embd_head_k_mla = 0;
uint32_t n_embd_head_v_mla = 0;
uint32_t n_embd_head_k_mla_impl = 0;
uint32_t n_embd_head_v_mla_impl = 0;

// for WavTokenizer
struct llama_hparams_posnet posnet;
@@ -164,7 +164,7 @@ struct llama_hparams {
uint32_t n_cls_out = 1;

// output embedding dimension (0 = use n_embd)
uint32_t n_embd_out = 0;
uint32_t n_embd_out_impl = 0;

// llama4 smallthinker
uint32_t n_moe_layer_step = 0;
@@ -239,7 +239,7 @@ struct llama_hparams {
uint32_t n_embd_inp() const;

// dimension of output embeddings
uint32_t get_n_embd_out() const;
uint32_t n_embd_out() const;

// dimension of key embeddings across all k-v heads
uint32_t n_embd_k_gqa(uint32_t il = 0) const;
@@ -269,6 +269,12 @@ struct llama_hparams {

bool is_swa(uint32_t il) const;

// note: currently only supported when either all or none of the layers use MLA
bool is_mla() const;

uint32_t n_embd_head_k_mla() const;
uint32_t n_embd_head_v_mla() const;

bool has_kv(uint32_t il) const;

// number of layers for which has_kv() returns true
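Taken together with the `_impl` fields above, the fallback behaviour of the new accessors can be illustrated with the following standalone sketch; the values are made up for illustration and this is not code from the PR:

```cpp
// Sketch of the MLA accessor fallback behaviour (illustrative values only).
llama_hparams hp = {};
hp.n_embd_head_k = 128;
hp.n_embd_head_v = 128;

// no MLA head sizes set: the *_mla() accessors fall back to the regular head sizes
GGML_ASSERT(!hp.is_mla());
GGML_ASSERT(hp.n_embd_head_k_mla() == 128);
GGML_ASSERT(hp.n_embd_head_v_mla() == 128);

// both MLA head sizes set: the accessors return them unchanged
hp.n_embd_head_k_mla_impl = 192;
hp.n_embd_head_v_mla_impl = 128;
GGML_ASSERT(hp.is_mla());
GGML_ASSERT(hp.n_embd_head_k_mla() == 192);
GGML_ASSERT(hp.n_embd_head_v_mla() == 128);
```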