Conversation

@ggerganov (Member) commented Jan 21, 2026

cont #18953

This is based on top of the changes in #18953.

Currently, the CUDA FA implementation has certain hardcoded assumptions about the layout of the K and V tensors when MLA is involved (see #13435). The goal of this change is to make things more generic and avoid these assumptions:

  • Change the concat layout of the MLA cache. The old layout was [pe, kv]; the new one is [kv, pe]. This has certain implications for backends such as Vulkan, and overall it is a better layout in terms of memory alignment
  • Update the graph build code to pass the V tensor as a view of K (see the sketch after this list). This can be used as a signal for the CUDA (and later other) backends to avoid loading extra V data during compute. (The elimination of the V component from the llama_kv_cache will be done in a follow-up PR - for now it is just redundant data.)
  • Add tests to exercise the "V is a view of K" path of the FA - currently these tests will still fail (for CUDA only) because we changed the layout (see the point above). This needs to be taken into account in the CUDA implementation. For more info, see the comments in test-backend-ops.cpp (CUDA impl fixed in f07c65b)
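
Editor's sketch of the second bullet above (the helper name and parameter names are made up for illustration; ggml_view_4d is the actual ggml call used later in this PR to construct the view):

#include "ggml.h"

// With the new [kv, pe] layout, the V tensor passed to the FA op can be a
// zero-offset view of the K cache covering only the first hsv elements of
// each row. Backends can detect this case (V->op == GGML_OP_VIEW and
// V->src[0] == K) and skip loading separate V data.
static ggml_tensor * build_v_as_view_of_k(ggml_context * ctx, ggml_tensor * k,
        int64_t hsv, int64_t n_kv, int64_t n_head_kv, int64_t n_stream) {
    return ggml_view_4d(ctx, k,
            hsv, n_kv, n_head_kv, n_stream,  // ne: V head size, KV cells, KV heads, streams
            k->nb[1], k->nb[2], k->nb[3],    // reuse K's strides
            0);                              // zero view offset thanks to the [kv, pe] layout
}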

Next PRs:

  • Do not allocate V cache when MLA is used
  • Add regular "V is view of K" FA tests in test-backend-ops and avoid hardcoded logic
  • Expand CUDA kernel coverage for quantized MLA cache and general improvement of the MLA-related logic (resolved in CUDA: fix alignment check for FA #19023)
  • Try to take advantage of "V is view of K" in other backends

@github-actions bot added labels on Jan 21, 2026: model (Model specific), testing (Everything test related), Nvidia GPU (Issues specific to Nvidia GPUs), python (python script changes), ggml (changes relating to the ggml tensor library for machine learning)
Comment on lines +785 to +789
// TODO: make this more generic by removing the notion of "MLA".
// for example "is V a view of K?" so we can skip loading it.
// V strides should be driven by V itself and avoid assumption of the data layout
const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K;

@ggerganov (Member Author) commented:

@JohannesGaessler Let me know if this is clear. The proposal is that when V = ggml_view(K), the implementation can use this information, for example to avoid extra data loads. But technically this is completely optional, and the implementation can also just read from V directly.

This way, if the user code insists on passing different V data, that should also work. In that case V won't be a view of K, so it will be treated as a regular V tensor.

Collaborator commented:

You should adjust the kernel selection logic in fattn.cu to only use the MMA kernel if this boolean is true. The MMA kernel has the MLA-specific optimization of re-using the K data that was previously loaded for calculation of KQ as V data when calculating VKQ. But the tile kernel simply loads the data from the K and V pointers with no re-use. So for now I would suggest that we condition the use of the MMA kernel on this boolean and use the tile kernel as a fallback.

We could in principle compile multiple template specializations but currently they would be unused for real models.
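
Editor's sketch of that suggestion (not the actual fattn.cu dispatch; the enum and function names are placeholders):

#include "ggml.h"

enum fattn_kernel_type { FATTN_KERNEL_MMA, FATTN_KERNEL_TILE };

// Gate the MLA-optimized MMA kernel on the "V is a view of K" signal and fall
// back to the tile kernel otherwise: the MMA kernel re-uses K data already in
// SRAM as V data, which is only valid when V really aliases K, while the tile
// kernel loads K and V independently.
static fattn_kernel_type select_fattn_kernel(const ggml_tensor * K, const ggml_tensor * V) {
    const bool v_is_view_of_k = V->op == GGML_OP_VIEW && V->src[0] == K;
    return v_is_view_of_k ? FATTN_KERNEL_MMA : FATTN_KERNEL_TILE;
}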

// Therefore, iterate over V in reverse and re-use the data if possible.
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
-constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV;
@ggerganov (Member Author) commented Jan 21, 2026:

I am not sure I completely understand the logic for reusing the K data, but to the best of my understanding this is the correct version now. An extra look is needed to confirm.

To clarify, on master the K/V data is laid out like this in memory:

# K - 16 cells of K-only data
# S - 16 cells of shared data between both K and V

# a single latent head of 576 cells 
kkKKssssssssssssssssSSSSSSSSSSSSSSSS

In this PR, we change the layout so the K data is at the end of the row:

# K - 16 cells of K-only data
# S - 16 cells of shared data between both K and V

# a single latent head of 576 cells 
ssssssssssssssssSSSSSSSSSSSSSSSSkkKK
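
As an editorial illustration of how the two cutoff expressions above evaluate, using the hsk=576 / hsv=512 MLA head sizes discussed in this PR and a hypothetical nbatch_K2 = 64 (chosen only for the arithmetic):

constexpr int DKQ = 576, DV = 512, nbatch_K2 = 64;

// master ([pe, kv] layout):
constexpr int cutoff_old = (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV); // 575 - 63 - 64 = 448
// this PR ([kv, pe] layout):
constexpr int cutoff_new = (DV - 1) - (DV - 1) % (2*nbatch_K2);                // 511 - 127 = 384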

Collaborator commented:

If we change the memory layout on master the loop structure in the MMA kernel needs to be done differently. On master the loop for K is going forward, the one for V is going backward. But if the reusable data is at the beginning the loop for K should be backward and the loop for V should be forward.

@ggerganov (Member Author) commented:

Hm, I am not sure I understand. Is this just for performance reasons? The current implementation works correctly and passes the tests.

@ggerganov (Member Author) commented:

Ah, I think I know what you are thinking. I updated the comment to better reflect the layout. So to clarify - the order of the shared data is the same, it's just shifted forward. And the K-only data is also ordered the same - it's just moved to the end of the buffer.

The same transform is also being applied to the Q data. So at the end, the dot products of QK and QKV are the same. Does that clarify?
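
(Editor's note: this holds because moving the pe block to the other end of the row applies the same fixed permutation $P$ to both the Q and K rows, and permutations preserve dot products: $(Pq)^\top (Pk) = q^\top P^\top P k = q^\top k$, since $P^\top P = I$.)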

@JohannesGaessler (Collaborator) commented Jan 21, 2026:

The MMA kernel uses only the K pointer, which was meant as an assist for a more general implementation of deduplicating the KV cache for DeepSeek that never materialized. But it is also re-using the K data that is already in SRAM to avoid having to load it again. For that, the loop over V is done in reverse, because on consumer GPUs there simply isn't enough SRAM to hold the entire tile of K at once. So yes, that part is only for better performance.

// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
// note: rope must go first for in-place context shifting in build_rope_shift()
-ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
@ngxson (Collaborator) commented Jan 21, 2026:

A side effect of this change is that build_rope_shift will no longer work correctly (it may need to use a view when shifting).

@ggerganov (Member Author) commented:

This should fix it:

diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index fd9f97d52..2cc9efed8 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -1614,10 +1614,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
 
         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, get_size()*n_stream,
+                n_embd_head_k - hparams.n_lora_kv, n_head_kv, get_size()*n_stream,
                 ggml_row_size(layer.k->type, n_embd_head_k),
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
-                0);
+                ggml_row_size(layer.k->type, hparams.n_lora_kv));
 
         ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
 

@ggerganov (Member Author) commented:

The rope shift logic for PLM and minicpm3 was broken because it didn't take into account the "nope" portion of the K embeddings. This is now fixed with 69d4fd7.

@ggerganov marked this pull request as ready for review January 22, 2026 13:21
@ggerganov requested a review from CISC as a code owner January 22, 2026 13:21
@ggerganov (Member Author) commented:

These should be the minimum changes needed to stabilize the CI and allow further work on the MLA code paths in libllama and the corresponding backend FA implementations. I've added some notes about the next steps.

v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);

// note: this is the currently assumed layout by the CUDA FA implementation
// however, this layout is problematic, because the offset can become very inconvenient for quantized KV types (i.e. not a multiple of 16)
Collaborator commented:

I'm not sure what you mean here. Moving the K-only data to the front does not fix any issues with memory alignment. On master the check for memory alignment causes quantized KV cache to fall back to CPU but that is a bug with the check, see #19023 .

@ggerganov (Member Author) commented:

The comment is no longer relevant - I removed it.

For reference, the problem is that the Vulkan backend does not support views with offset % 16 != 0. At least when I tried, it asserted here:

static vk_subbuffer ggml_vk_tensor_subbuffer(
        const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {
    vk_buffer buffer = nullptr;
    size_t offset = 0;
    if (ctx->device->uma) {
        ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
    }
    if (!buffer) {
        auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
        buffer = buf_ctx->dev_buffer;
        offset = vk_tensor_offset(tensor) + tensor->view_offs;
    }
    GGML_ASSERT(buffer != nullptr);
    size_t size = ggml_nbytes(tensor);
    size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
    // The shader must support misaligned offsets when indexing into the buffer
    GGML_ASSERT(allow_misalign || misalign_bytes == 0);
    offset &= ~misalign_bytes;
    size += misalign_bytes;
    return vk_subbuffer{buffer, offset, size};
}

But this is no longer relevant, because we change the memory layout in this PR and we no longer need to offset the view.
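
As an editorial illustration of the alignment concern with the old layout, assuming a Q4_0 K cache (32 elements per 18-byte block) and the 64 rope elements of the DeepSeek-style MLA head: placing the rope part first means the V view starts 2 × 18 = 36 bytes into each row, which is not a multiple of 16, while with the new [kv, pe] layout the V view starts at byte offset 0.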

ggerganov and others added 2 commits January 22, 2026 15:52
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
@ggerganov merged commit a5eaa1d into master Jan 22, 2026
76 of 78 checks passed
@ggerganov deleted the gg/mla-improve branch January 22, 2026 20:09
ronaldmannak pushed a commit to PicoMLX/llama.cpp that referenced this pull request Jan 24, 2026
* mla : pass V as a view of K to the FA op

* cuda : adjust mla logic to new layout

* kv-cache : fix rope shift

* tests : remove comment

* cuda : fix reusable_cutoff

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
maxious added a commit to maxious/llama.cpp that referenced this pull request Jan 27, 2026
When V is a view of K but with different head dimensions (e.g., GLM-4.7-Flash
with K=576, V=512), we cannot simply reuse K's data pointer for V.

For MLA models, the K tensor layout is [kv_lora_scaled (DV), pe (DQK-DV)],
so V data is the first DV elements of each K row.

This fix extracts the correct V data from K when DQK != DV in:
- ggml_sycl_op_flash_attn_1 (basic FA path)
- ggml_sycl_op_flash_attn_coopmat (XMX path)
- ggml_sycl_op_flash_attn_mkl (oneMKL path)

Fixes GPU memory faults and incorrect results in backend tests for
hsk=576,hsv=512 configurations.

Aligns with upstream PRs ggml-org#18953, ggml-org#18986, ggml-org#19067 that implement V-less KV cache
for MLA models like DeepSeek and GLM-4.7-Flash.

Amp-Thread-ID: https://ampcode.com/threads/T-019bf97a-9105-718e-84fb-320913c5f0c6
Co-authored-by: Amp <amp@ampcode.com>
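
Editor's sketch of the addressing implied by that layout (the helper is hypothetical and assumes a non-quantized K cache):

#include "ggml.h"

// With the [kv, pe] layout, the V data of cache cell `ic` is simply the first
// DV elements of row `ic` of K, so it can be addressed through K's row stride
// without copying. Quantized K types would need block-aware handling instead.
static const void * mla_v_row_from_k(const ggml_tensor * K, int64_t ic) {
    return (const char *) K->data + ic*K->nb[1];
}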
@LostRuins (Collaborator) commented Feb 3, 2026

I have noticed a coherency regression with GLM-4-32B-0414 when using context shifting (llama_memory_seq_rm + llama_memory_seq_add) after this PR was merged. Details at #19292

@ggerganov (Member Author) commented:

Could you provide logs from loading the model, showing the model parameters (head size, n_rot, etc.)? It's not obvious why this model would break from the changes here.

@LostRuins (Collaborator) commented:

Sure, here's a log of loading the model with Vulkan:

ggml_vulkan: Found 2 Vulkan devices:
ggml_vulkan: 0 = Intel(R) RaptorLake-S Mobile Graphics Controller (Intel Corporation) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 32768 | int dot: 1 | matrix cores: none
ggml_vulkan: 1 = NVIDIA GeForce RTX 4090 Laptop GPU (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
load_backend: loaded Vulkan backend from C:\Users\user\Desktop\llama-b7921-bin-win-vulkan-x64\ggml-vulkan.dll
load_backend: loaded CPU backend from C:\Users\user\Desktop\llama-b7921-bin-win-vulkan-x64\ggml-cpu-haswell.dll
build: 7921 (e9a859db3) with Clang 19.1.5 for Windows x86_64
main: llama backend init
main: load the model and apply lora adapter, if any
common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
llama_params_fit_impl: projected to use 19583 MiB of device memory vs. 15278 MiB of free device memory
llama_params_fit_impl: cannot meet free memory target of 1024 MiB, need to reduce device memory by 5328 MiB
llama_params_fit_impl: context size reduced from 32768 to 4096 -> need 1708 MiB less memory in total
llama_params_fit_impl: filling dense layers back-to-front:
llama_params_fit_impl:   - Vulkan1 (NVIDIA GeForce RTX 4090 Laptop GPU): 49 layers,  14204 MiB used,   1073 MiB free
llama_params_fit: successfully fit params to free device memory
llama_params_fit: fitting params to free memory took 0.25 seconds
llama_model_load_from_file_impl: using device Vulkan1 (NVIDIA GeForce RTX 4090 Laptop GPU) (0000:01:00.0) - 15278 MiB free
llama_model_loader: loaded meta data with 44 key-value pairs and 613 tensors from D:\ExtDrive\KoboldCpp\models\GLM-4-32B-0414-Q4_K_S.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = glm4
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Glm-4-32B-0414
llama_model_loader: - kv   3:                            general.version str              = 0414
llama_model_loader: - kv   4:                           general.basename str              = Glm-4-32B-0414
llama_model_loader: - kv   5:                       general.quantized_by str              = Unsloth
llama_model_loader: - kv   6:                         general.size_label str              = 32B
llama_model_loader: - kv   7:                            general.license str              = mit
llama_model_loader: - kv   8:                           general.repo_url str              = https://huggingface.co/unsloth
llama_model_loader: - kv   9:                   general.base_model.count u32              = 1
llama_model_loader: - kv  10:                  general.base_model.0.name str              = GLM 4 32B 0414
llama_model_loader: - kv  11:               general.base_model.0.version str              = 0414
llama_model_loader: - kv  12:          general.base_model.0.organization str              = THUDM
llama_model_loader: - kv  13:              general.base_model.0.repo_url str              = https://huggingface.co/THUDM/GLM-4-32...
llama_model_loader: - kv  14:                               general.tags arr[str,2]       = ["unsloth", "text-generation"]
llama_model_loader: - kv  15:                          general.languages arr[str,2]       = ["zh", "en"]
llama_model_loader: - kv  16:                           glm4.block_count u32              = 61
llama_model_loader: - kv  17:                        glm4.context_length u32              = 32768
llama_model_loader: - kv  18:                      glm4.embedding_length u32              = 6144
llama_model_loader: - kv  19:                   glm4.feed_forward_length u32              = 23040
llama_model_loader: - kv  20:                  glm4.attention.head_count u32              = 48
llama_model_loader: - kv  21:               glm4.attention.head_count_kv u32              = 2
llama_model_loader: - kv  22:                        glm4.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  23:      glm4.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  24:                  glm4.attention.key_length u32              = 128
llama_model_loader: - kv  25:                glm4.attention.value_length u32              = 128
llama_model_loader: - kv  26:                  glm4.rope.dimension_count u32              = 64
llama_model_loader: - kv  27:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  28:                         tokenizer.ggml.pre str              = glm4
llama_model_loader: - kv  29:                      tokenizer.ggml.tokens arr[str,151552]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  30:                  tokenizer.ggml.token_type arr[i32,151552]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  31:                      tokenizer.ggml.merges arr[str,318088]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  32:                tokenizer.ggml.eos_token_id u32              = 151336
llama_model_loader: - kv  33:            tokenizer.ggml.padding_token_id u32              = 151330
llama_model_loader: - kv  34:                tokenizer.ggml.eot_token_id u32              = 151336
llama_model_loader: - kv  35:            tokenizer.ggml.unknown_token_id u32              = 151329
llama_model_loader: - kv  36:                tokenizer.ggml.bos_token_id u32              = 151329
llama_model_loader: - kv  37:                    tokenizer.chat_template str              = [gMASK]<sop>\n{%- if tools -%}\n<|syste...
llama_model_loader: - kv  38:               general.quantization_version u32              = 2
llama_model_loader: - kv  39:                          general.file_type u32              = 14
llama_model_loader: - kv  40:                      quantize.imatrix.file str              = GLM-4-32B-0414-GGUF/imatrix_unsloth.dat
llama_model_loader: - kv  41:                   quantize.imatrix.dataset str              = unsloth_calibration_GLM-4-32B-0414.txt
llama_model_loader: - kv  42:             quantize.imatrix.entries_count u32              = 366
llama_model_loader: - kv  43:              quantize.imatrix.chunks_count u32              = 724
llama_model_loader: - type  f32:  245 tensors
llama_model_loader: - type q4_K:  356 tensors
llama_model_loader: - type q5_K:   11 tensors
llama_model_loader: - type q6_K:    1 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q4_K - Small
print_info: file size   = 17.40 GiB (4.59 BPW)
load: 0 unused tokens
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: printing all EOG tokens:
load:   - 151329 ('<|endoftext|>')
load:   - 151336 ('<|user|>')
load: special tokens cache size = 14
load: token to piece cache size = 0.9710 MB
print_info: arch                  = glm4
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 32768
print_info: n_embd                = 6144
print_info: n_embd_inp            = 6144
print_info: n_layer               = 61
print_info: n_head                = 48
print_info: n_head_kv             = 2
print_info: n_rot                 = 64
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 128
print_info: n_embd_head_v         = 128
print_info: n_gqa                 = 24
print_info: n_embd_k_gqa          = 256
print_info: n_embd_v_gqa          = 256
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-05
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 0.0e+00
print_info: f_attn_scale          = 0.0e+00
print_info: n_ff                  = 23040
print_info: n_expert              = 0
print_info: n_expert_used         = 0
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = 0
print_info: rope type             = 0
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 32768
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = unknown
print_info: model type            = 32B
print_info: model params          = 32.57 B
print_info: general.name          = Glm-4-32B-0414
print_info: vocab type            = BPE
print_info: n_vocab               = 151552
print_info: n_merges              = 318088
print_info: BOS token             = 151329 '<|endoftext|>'
print_info: EOS token             = 151336 '<|user|>'
print_info: EOT token             = 151336 '<|user|>'
print_info: UNK token             = 151329 '<|endoftext|>'
print_info: PAD token             = 151330 '[MASK]'
print_info: LF token              = 198 'Ċ'
print_info: EOG token             = 151329 '<|endoftext|>'
print_info: EOG token             = 151336 '<|user|>'
print_info: max token length      = 1024
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
load_tensors: offloading output layer to GPU
load_tensors: offloading 48 repeating layers to GPU
load_tensors: offloaded 49/62 layers to GPU
load_tensors:   CPU_Mapped model buffer size =  4129.59 MiB
load_tensors:      Vulkan1 model buffer size = 13692.96 MiB
................................................................................................
common_init_result: added <|endoftext|> logit bias = -inf
common_init_result: added <|user|> logit bias = -inf
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4096
llama_context: n_ctx_seq     = 4096
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 10000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (4096) < n_ctx_train (32768) -- the full capacity of the model will not be utilized
llama_context: Vulkan_Host  output buffer size =     0.58 MiB
llama_kv_cache:        CPU KV buffer size =    52.00 MiB
llama_kv_cache:    Vulkan1 KV buffer size =   192.00 MiB
llama_kv_cache: size =  244.00 MiB (  4096 cells,  61 layers,  1/1 seqs), K (f16):  122.00 MiB, V (f16):  122.00 MiB
sched_reserve: reserving ...
sched_reserve: Flash Attention was auto, set to enabled
sched_reserve:    Vulkan1 compute buffer size =   320.03 MiB
sched_reserve: Vulkan_Host compute buffer size =    32.01 MiB
sched_reserve: graph nodes  = 2081
sched_reserve: graph splits = 158 (with bs=512), 2 (with bs=1)
sched_reserve: reserve took 289.36 ms, sched copies = 1

To clarify: the inference itself is fine. The regression only happens after a KV shift is applied, like so:

// remove the unwanted tokens [trimstart, trimstart + diff) from the KV cache,
// then shift the positions of the remaining tokens back by diff

int diff = found - trimstart;
llama_memory_seq_rm(llama_get_memory(ctx), 0, trimstart, trimstart + diff);
llama_memory_seq_add(llama_get_memory(ctx), 0, trimstart + diff, -1, -diff);

If no shifting is performed, the inference both before and after this PR is normal. If shifting is performed, there is a large degradation observed after this specific PR.

@LostRuins (Collaborator) commented:

Also, I suspect it's not just this old model that is affected but also GLM-4.7-Flash and possibly others; however, since GLM-4.7-Flash was going through many other separate, unrelated fixes, I didn't want to use it for testing.

@ggerganov (Member Author) commented:

Does it work correctly with a CPU-only run?

@LostRuins (Collaborator) commented Feb 3, 2026

Surprisingly, pure CPU takes forever, but it works correctly!

Shifting produces the expected results both before and after this PR with pure CPU, which feels even stranger.

I did a few more tests:

Pure CPU = Shifting works fine. Super, super slow PP. Fully coherent.
Vulkan with FA and 0 layers offloaded = Shifting works fine, output is fully coherent. Total VRAM usage at 0 layers is 1.1GB.
Vulkan with FA and 10/62 layers offloaded = Mild incoherence. Example of the mild incoherence:

It sparkled with a sense of fear and fear. He noticed a large crowd gathered near the edge of the forest, huddled together, casting long, dancing shadows that danced away from the flickering torches that lined the forest's edge. 

Vulkan with FA and 50/62 layers offloaded = Very incoherent. Example of extreme incoherence:

When he reached the middle of the clearing field that stretched out in front of him began to look like it was part of the town map had been pushed slightly to the east now, and now he would see the figure-shapedI-shaped man walked to the ground where the lights were coming out of the ground, which was a short spursky, but he could not make himself

Trying without FA gives the same result - coherence is related to how many layers are offloaded. At 50 layers:

the smell of the town turned into the main road ended up to the west of the town for a long time, making sure that he was no longer behind the other man in the town's watch, which was inky, checking the time to check if he could get some water from the fountain near the spiggetreenged the way he had been there for a while

@ggerganov (Member Author) commented:

Seems like an issue in the Vulkan backend - track #19296 for more info.
