diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c09a54e6614..51f3fca55ef 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5285,6 +5285,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E // scan the blocks of the mask that are not masked // 0 - masked (i.e. full of -INF, skip) // 1 - not masked (i.e. at least one element of the mask is not -INF) +// 2 - all zero kernel void kernel_flash_attn_ext_blk( constant ggml_metal_kargs_flash_attn_ext_blk & args, device const char * mask, @@ -5306,27 +5307,29 @@ kernel void kernel_flash_attn_ext_blk( device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; - // fast route - if (res == 0) { - if (simd_max(*mask_src) > -MAXHALF/2) { - res = 1; - } - } - // detailed check of the elements of the block if ((C > NW || Q > 1) && res == 0) { - half m = -MAXHALF; + half mmin = MAXHALF; + half mmax = -MAXHALF; FOR_UNROLL (short j = 0; j < Q; ++j) { FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { - m = max(m, mask_src[ii*NW]); + mmin = min(mmin, mask_src[ii*NW]); + mmax = max(mmax, mask_src[ii*NW]); } mask_src += args.nb31/2; } - if (simd_max(m) > -MAXHALF/2) { - res = 1; + mmin = simd_min(mmin); + mmax = simd_max(mmax); + + if (mmax > -MAXHALF) { + if (mmin == 0.0 && mmax == 0.0) { + res = 2; + } else { + res = 1; + } } } @@ -5568,9 +5571,13 @@ void kernel_flash_attn_ext_impl( ic = 0; } + char blk_cur = 1; + // read the mask into shared mem if (FC_flash_attn_ext_has_mask) { - if (blk[ic0] == 0) { + blk_cur = blk[ic0]; + + if (blk_cur == 0) { FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { pm2[jj] += NW; } @@ -5578,16 +5585,22 @@ void kernel_flash_attn_ext_impl( continue; } - FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { - const short j = jj*NSG + sgitg; + if (blk_cur == 1) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - if (FC_flash_attn_ext_bc_mask) { - sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); - } else { - sm2[j*SH + tiisg] = pm2[jj][tiisg]; - } + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } - pm2[jj] += NW; + pm2[jj] += NW; + } + } else if (blk_cur == 2) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + pm2[jj] += NW; + } } #if 0 @@ -5752,10 +5765,12 @@ void kernel_flash_attn_ext_impl( } // mqk = mqk + slope*mask - if (FC_flash_attn_ext_has_bias) { - s2 += s2_t(sm2[j*SH + tiisg])*slope; - } else { - s2 += s2_t(sm2[j*SH + tiisg]); + if (blk_cur != 2) { + if (FC_flash_attn_ext_has_bias) { + s2 += s2_t(sm2[j*SH + tiisg])*slope; + } else { + s2 += s2_t(sm2[j*SH + tiisg]); + } } M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));