Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 39 additions & 24 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -5285,6 +5285,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
// scan the blocks of the mask that are not masked
// 0 - masked (i.e. full of -INF, skip)
// 1 - not masked (i.e. at least one element of the mask is not -INF)
// 2 - all zero
kernel void kernel_flash_attn_ext_blk(
constant ggml_metal_kargs_flash_attn_ext_blk & args,
device const char * mask,
Expand All @@ -5306,27 +5307,29 @@ kernel void kernel_flash_attn_ext_blk(

device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;

// fast route
if (res == 0) {
if (simd_max(*mask_src) > -MAXHALF/2) {
res = 1;
}
}

// detailed check of the elements of the block
if ((C > NW || Q > 1) && res == 0) {
half m = -MAXHALF;
half mmin = MAXHALF;
half mmax = -MAXHALF;

FOR_UNROLL (short j = 0; j < Q; ++j) {
FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
m = max(m, mask_src[ii*NW]);
mmin = min(mmin, mask_src[ii*NW]);
mmax = max(mmax, mask_src[ii*NW]);
}

mask_src += args.nb31/2;
}

if (simd_max(m) > -MAXHALF/2) {
res = 1;
mmin = simd_min(mmin);
mmax = simd_max(mmax);

if (mmax > -MAXHALF) {
if (mmin == 0.0 && mmax == 0.0) {
res = 2;
} else {
res = 1;
}
}
}

Expand Down Expand Up @@ -5568,26 +5571,36 @@ void kernel_flash_attn_ext_impl(
ic = 0;
}

char blk_cur = 1;

// read the mask into shared mem
if (FC_flash_attn_ext_has_mask) {
if (blk[ic0] == 0) {
blk_cur = blk[ic0];

if (blk_cur == 0) {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
pm2[jj] += NW;
}

continue;
}

FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;
if (blk_cur == 1) {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;

if (FC_flash_attn_ext_bc_mask) {
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
} else {
sm2[j*SH + tiisg] = pm2[jj][tiisg];
}
if (FC_flash_attn_ext_bc_mask) {
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
} else {
sm2[j*SH + tiisg] = pm2[jj][tiisg];
}

pm2[jj] += NW;
pm2[jj] += NW;
}
} else if (blk_cur == 2) {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
pm2[jj] += NW;
}
}

#if 0
Expand Down Expand Up @@ -5752,10 +5765,12 @@ void kernel_flash_attn_ext_impl(
}

// mqk = mqk + slope*mask
if (FC_flash_attn_ext_has_bias) {
s2 += s2_t(sm2[j*SH + tiisg])*slope;
} else {
s2 += s2_t(sm2[j*SH + tiisg]);
if (blk_cur != 2) {
if (FC_flash_attn_ext_has_bias) {
s2 += s2_t(sm2[j*SH + tiisg])*slope;
} else {
s2 += s2_t(sm2[j*SH + tiisg]);
}
}

M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
Expand Down
Loading