
Conversation

@ggerganov
Member

@ggerganov ggerganov commented Jan 24, 2026

cont #18986

Support V-less KV cache. This is useful for MLA models such as DeepSeek and GLM 4.7 Flash, where the combined latent data is stored in the K cache. This almost halves the memory used by the KV cache.

Also:

  • Add llama_hparams::is_mla()
  • Add llama_hparams::n_embd_head_k_mla()
  • Add llama_hparams::n_embd_head_v_mla()
  • Rename llama_hparams::get_n_embd_out() -> llama_hparams::n_embd_out()
  • Add class llm_graph_input_attn_k - similar to class llm_graph_input_attn_kv, but only K data
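
For illustration, a minimal sketch of the gating idea, assuming hypothetical names (llama_hparams_sketch, plan_kv_alloc) that only stand in for the real llama.cpp code:

#include <cstdint>

// Hypothetical sketch - the real llama_hparams and KV-cache allocation code differ.
struct llama_hparams_sketch {
    uint32_t n_embd_head_k_mla = 0;
    uint32_t n_embd_head_v_mla = 0;

    // MLA models store the combined latent data in the K cache only.
    bool is_mla() const {
        return n_embd_head_k_mla != 0 && n_embd_head_v_mla != 0;
    }
};

// Decide per layer whether a V tensor is needed at all.
static void plan_kv_alloc(const llama_hparams_sketch & hp, bool & alloc_k, bool & alloc_v) {
    alloc_k = true;
    alloc_v = !hp.is_mla(); // V-less cache: skip the V tensors entirely for MLA models
}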

@github-actions github-actions bot added model Model specific Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Jan 24, 2026
@John-Dekka

John-Dekka commented Jan 24, 2026

The commit computes is_mla from the model hparams (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0). When that is true, the KV cache skips allocating the V tensors and the graph uses the new K-only input path. The decision is driven by the hparams at model load time.

GLM-4.7-Flash config.json does not include the n_embd_head_k_mla / n_embd_head_v_mla fields.

Was hoping to squeeze my GLM model a bit more. 😿


Never mind - the values are exported as GGUF keys:

deepseek2.attention.key_length_mla
deepseek2.attention.value_length_mla
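
As an aside, these values can be inspected straight from a GGUF file with the public gguf C API. The helper below is hypothetical and assumes the keys are stored as u32; it is only a sketch, not the llama.cpp loader code:

#include <stdint.h>
#include "gguf.h"

// Hypothetical helper: read a u32 metadata key, returning 0 when the key is absent.
static uint32_t get_u32_or_zero(const struct gguf_context * ctx, const char * key) {
    const int64_t id = gguf_find_key(ctx, key);
    return id < 0 ? 0 : gguf_get_val_u32(ctx, id);
}

int main(void) {
    struct gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ NULL };
    struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
    const uint32_t k_mla = get_u32_or_zero(gctx, "deepseek2.attention.key_length_mla");
    const uint32_t v_mla = get_u32_or_zero(gctx, "deepseek2.attention.value_length_mla");
    gguf_free(gctx);
    return (k_mla != 0 && v_mla != 0) ? 0 : 1; // both present => MLA path
}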

@ggerganov
Member Author

GLM-4.7-Flash config.json does not include the n_embd_head_k_mla / n_embd_head_v_mla fields.

@John-Dekka These are llama.cpp-specific parameters, they don't have to be present in the config.json. This patch applies to GLM-4.7-Flash.

Collaborator

@JohannesGaessler JohannesGaessler left a comment

The CUDA changes are correct. The changes in the llama.cpp user code also seem correct to me, though I am not as familiar with that part of the codebase.

@eapache

eapache commented Jan 24, 2026

Will the --fit algorithm pick up the changes in memory requirements here, or does it need to be adjusted as well to expect the smaller KV cache for these models?

@JohannesGaessler
Collaborator

llama_kv_cache has a member std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;. As long as the KV cache is allocated in ctxs_bufs it should work correctly with -fit regardless of the specific tensors.
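
To illustrate the point, the memory accounting just sums whatever backend buffers were actually allocated, so a cache without V tensors reports smaller buffers automatically. A hedged sketch (the helper itself is hypothetical; it assumes the ggml-cpp.h smart-pointer typedefs used by llama.cpp):

#include <vector>
#include <utility>
#include "ggml-backend.h"
#include "ggml-cpp.h" // ggml_context_ptr, ggml_backend_buffer_ptr

// Hypothetical helper: total bytes held by the KV cache's backend buffers.
static size_t kv_cache_bytes(const std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> & ctxs_bufs) {
    size_t total = 0;
    for (const auto & [ctx, buf] : ctxs_bufs) {
        total += ggml_backend_buffer_get_size(buf.get()); // no V tensors => smaller buffer
    }
    return total;
}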

@jacekpoplawski
Contributor

very cool

master:

llama_params_fit_impl: projected to use 57339 MiB of device memory vs. 71537 MiB of free device memory
llama_kv_cache: size = 19525.56 MiB (200192 cells,  47 layers,  1/1 seqs), K (f16): 10337.06 MiB, V (f16): 9188.50 MiB

PR:

llama_params_fit_impl: projected to use 48150 MiB of device memory vs. 71537 MiB of free device memory
llama_kv_cache: size = 10337.06 MiB (200192 cells,  47 layers,  1/1 seqs), K (f16): 10337.06 MiB, V (f16):    0.00 MiB
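
For reference, the drop in projected usage matches the removed V tensors almost exactly: 57339 - 48150 = 9189 MiB vs. the 9188.50 MiB V (f16) size, and the KV cache total shrinks from 19525.56 MiB to the K-only 10337.06 MiB.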

@ggerganov ggerganov force-pushed the gg/kv-cache-support-no-v branch from c843f3a to 6d7ce2e Compare January 25, 2026 08:10
Comment on lines +1961 to +1967
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
Member Author

We might need to add LLM_ARCH_DEEPSEEK2 here if we see similar numerical issues with GLM 4.7 Flash - something to keep in mind. cc @jeffbolznv @JohannesGaessler
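
For concreteness, the suggested (not yet applied) change would simply extend the arch check above - a sketch:

if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_DEEPSEEK2) {
    // also force F32 accumulation for the DeepSeek2-style (MLA) output projection,
    // pending evidence of similar precision issues
    ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}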

Contributor

IIRC @ngxson tried this during our PR and it made no difference in his testing

Member Author

At that time, the wrong gating function was used, so we can't draw a conclusion from that. Also, this is somewhat backend-specific - e.g. it's not a problem for Metal, since we always accumulate in F32 there.

@ggerganov ggerganov merged commit d9c6ce4 into master Jan 25, 2026
78 of 81 checks passed
@ggerganov ggerganov deleted the gg/kv-cache-support-no-v branch January 25, 2026 18:02
maxious added a commit to maxious/llama.cpp that referenced this pull request Jan 27, 2026
When V is a view of K but with different head dimensions (e.g., GLM-4.7-Flash
with K=576, V=512), we cannot simply reuse K's data pointer for V.

For MLA models, the K tensor layout is [kv_lora_scaled (DV), pe (DQK-DV)],
so V data is the first DV elements of each K row.

This fix extracts the correct V data from K when DQK != DV in:
- ggml_sycl_op_flash_attn_1 (basic FA path)
- ggml_sycl_op_flash_attn_coopmat (XMX path)
- ggml_sycl_op_flash_attn_mkl (oneMKL path)

Fixes GPU memory faults and incorrect results in backend tests for
hsk=576,hsv=512 configurations.

Aligns with upstream PRs ggml-org#18953, ggml-org#18986, ggml-org#19067 that implement V-less KV cache
for MLA models like DeepSeek and GLM-4.7-Flash.

Amp-Thread-ID: https://ampcode.com/threads/T-019bf97a-9105-718e-84fb-320913c5f0c6
Co-authored-by: Amp <amp@ampcode.com>
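
To make the described layout concrete, here is a hedged sketch (a hypothetical helper, not the SYCL patch itself) of carving a V view out of the K tensor by keeping only the first DV elements of each row:

#include "ggml.h"

// Hypothetical helper: when DQK != DV (e.g. K = 576, V = 512 for GLM-4.7-Flash),
// V is the leading DV elements of every K row, so a strided view over K is
// enough - no copy needed. Assumes k has shape [DQK, n_kv, n_head_kv].
static struct ggml_tensor * v_view_from_k(struct ggml_context * ctx, struct ggml_tensor * k, int64_t DV) {
    return ggml_view_3d(ctx, k,
            DV, k->ne[1], k->ne[2],  // truncate each row to its first DV elements
            k->nb[1], k->nb[2],      // keep K's original row/plane strides
            0);                      // V data starts at offset 0 of each K row
}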
maxious added a commit to maxious/llama.cpp that referenced this pull request Jan 31, 2026 (same commit message as above)
maxious added a commit to maxious/llama.cpp that referenced this pull request Feb 1, 2026 (same commit message as above)