Commit a0a78da

Merge branch 'upstream' into concedo_experimental
# Conflicts:
#   .github/workflows/build.yml
#   docs/ops.md
#   docs/ops/SYCL.csv
#   ggml/src/ggml-sycl/element_wise.cpp
#   ggml/src/ggml-sycl/ggml-sycl.cpp
#   ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
#   ggml/src/ggml-webgpu/ggml-webgpu.cpp
#   pyproject.toml
#   requirements/requirements-convert_legacy_llama.txt
#   src/CMakeLists.txt
#   src/llama-vocab.cpp
#   tests/test-backend-ops.cpp
2 parents 9cf2119 + 34ba7b5 commit a0a78da

34 files changed: +2368, -177 lines

common/speculative.cpp

Lines changed: 36 additions & 0 deletions
@@ -805,6 +805,42 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
     return it->second;
 }
 
+bool common_speculative_is_compat(llama_context * ctx_tgt) {
+    auto * mem = llama_get_memory(ctx_tgt);
+    if (mem == nullptr) {
+        return false;
+    }
+
+    bool res = true;
+
+    llama_memory_clear(mem, true);
+
+    // eval 2 tokens to check if the context is compatible
+    std::vector<llama_token> tmp;
+    tmp.push_back(0);
+    tmp.push_back(0);
+
+    int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
+    if (ret != 0) {
+        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
+        res = false;
+        goto done;
+    }
+
+    // try to remove the last tokens
+    if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
+        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
+        res = false;
+        goto done;
+    }
+
+done:
+    llama_memory_clear(mem, true);
+    llama_synchronize(ctx_tgt);
+
+    return res;
+}
+
 // initialization of the speculative decoding system
 //
 common_speculative * common_speculative_init(

common/speculative.h

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,10 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
 // convert type to string
 std::string common_speculative_type_to_str(enum common_speculative_type type);
 
+// check if the llama_context is compatible for speculative decoding
+// note: clears the memory of the context
+bool common_speculative_is_compat(llama_context * ctx_tgt);
+
 common_speculative * common_speculative_init(
         common_params_speculative & params,
         llama_context * ctx_tgt);
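
The new helper is meant to run before any prompt is decoded, since it clears the target context's memory as part of the probe. A minimal sketch of how a caller might gate speculative decoding on it (the wrapper function, its name, and the include path are illustrative assumptions, not part of the diff):

    #include "speculative.h"   // assumed include path for common/speculative.h

    // Hypothetical helper: returns nullptr when the target context cannot support
    // speculative decoding, so the caller can fall back to plain decoding.
    static common_speculative * try_init_speculative(common_params_speculative & params,
                                                     llama_context * ctx_tgt) {
        // the probe clears the memory of ctx_tgt, so call it before decoding the prompt
        if (!common_speculative_is_compat(ctx_tgt)) {
            return nullptr;
        }

        return common_speculative_init(params, ctx_tgt);
    }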

convert_hf_to_gguf.py

Lines changed: 351 additions & 3 deletions
Large diffs are not rendered by default.

ggml/src/ggml-metal/ggml-metal-context.m

Lines changed: 1 addition & 1 deletion
@@ -394,7 +394,7 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con
     [encoder endEncoding];
 
     ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src);
-    ggml_metal_event_record(ctx_src, ev_cpy);
+    ggml_metal_event_encode_signal(ev_cpy, cmd_buf);
 
     [cmd_buf commit];

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 39 additions & 24 deletions
@@ -5285,6 +5285,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
 // scan the blocks of the mask that are not masked
 // 0 - masked (i.e. full of -INF, skip)
 // 1 - not masked (i.e. at least one element of the mask is not -INF)
+// 2 - all zero
 kernel void kernel_flash_attn_ext_blk(
         constant ggml_metal_kargs_flash_attn_ext_blk & args,
         device const char * mask,
@@ -5306,27 +5307,29 @@ kernel void kernel_flash_attn_ext_blk(
 
     device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
 
-    // fast route
-    if (res == 0) {
-        if (simd_max(*mask_src) > -MAXHALF/2) {
-            res = 1;
-        }
-    }
-
     // detailed check of the elements of the block
     if ((C > NW || Q > 1) && res == 0) {
-        half m = -MAXHALF;
+        half mmin = MAXHALF;
+        half mmax = -MAXHALF;
 
         FOR_UNROLL (short j = 0; j < Q; ++j) {
            FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
-                m = max(m, mask_src[ii*NW]);
+                mmin = min(mmin, mask_src[ii*NW]);
+                mmax = max(mmax, mask_src[ii*NW]);
            }
 
            mask_src += args.nb31/2;
        }
 
-        if (simd_max(m) > -MAXHALF/2) {
-            res = 1;
+        mmin = simd_min(mmin);
+        mmax = simd_max(mmax);
+
+        if (mmax > -MAXHALF) {
+            if (mmin == 0.0 && mmax == 0.0) {
+                res = 2;
+            } else {
+                res = 1;
+            }
        }
    }
 
@@ -5568,26 +5571,36 @@ void kernel_flash_attn_ext_impl(
             ic = 0;
         }
 
+        char blk_cur = 1;
+
         // read the mask into shared mem
         if (FC_flash_attn_ext_has_mask) {
-            if (blk[ic0] == 0) {
+            blk_cur = blk[ic0];
+
+            if (blk_cur == 0) {
                 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                     pm2[jj] += NW;
                 }
 
                 continue;
             }
 
-            FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
-                const short j = jj*NSG + sgitg;
+            if (blk_cur == 1) {
+                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                    const short j = jj*NSG + sgitg;
 
-                if (FC_flash_attn_ext_bc_mask) {
-                    sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
-                } else {
-                    sm2[j*SH + tiisg] = pm2[jj][tiisg];
-                }
+                    if (FC_flash_attn_ext_bc_mask) {
+                        sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
+                    } else {
+                        sm2[j*SH + tiisg] = pm2[jj][tiisg];
+                    }
 
-                pm2[jj] += NW;
+                    pm2[jj] += NW;
+                }
+            } else if (blk_cur == 2) {
+                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                    pm2[jj] += NW;
+                }
            }
        }
 
 #if 0
@@ -5752,10 +5765,12 @@ void kernel_flash_attn_ext_impl(
                 }
 
                 // mqk = mqk + slope*mask
-                if (FC_flash_attn_ext_has_bias) {
-                    s2 += s2_t(sm2[j*SH + tiisg])*slope;
-                } else {
-                    s2 += s2_t(sm2[j*SH + tiisg]);
+                if (blk_cur != 2) {
+                    if (FC_flash_attn_ext_has_bias) {
+                        s2 += s2_t(sm2[j*SH + tiisg])*slope;
+                    } else {
+                        s2 += s2_t(sm2[j*SH + tiisg]);
+                    }
                 }
 
                 M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
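
For reference, the classification above distinguishes three cases: a block that is entirely -INF is skipped outright, a block that is entirely zero still has to be attended to but contributes nothing to the scores (hence the later blk_cur != 2 guard around the mask add), and everything else needs the full mask load and add. A host-side C++ sketch of the same three-way test, using float and a -FLT_MAX/2 threshold in place of the kernel's half/-MAXHALF (illustrative only; the kernel performs this per simdgroup):

    #include <algorithm>
    #include <cfloat>
    #include <cstddef>
    #include <cstdint>

    // 0 - masked (all -INF, the block can be skipped)
    // 1 - not masked (mixed values, the mask must be loaded and added)
    // 2 - all zero (attend, but skip adding the mask term)
    static uint8_t classify_mask_block(const float * m, size_t n) {
        float mmin =  FLT_MAX;
        float mmax = -FLT_MAX;

        for (size_t i = 0; i < n; ++i) {
            mmin = std::min(mmin, m[i]);
            mmax = std::max(mmax, m[i]);
        }

        if (mmax <= -FLT_MAX/2) {
            return 0;
        }

        return (mmin == 0.0f && mmax == 0.0f) ? 2 : 1;
    }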

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 30 additions & 26 deletions
@@ -418,19 +418,19 @@ enum FaCodePath {
 };
 
 struct vk_fa_pipeline_state {
-    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt)
-        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {}
+    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags)
+        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {}
 
     uint32_t HSK, HSV;
     bool small_rows, small_cache;
     FaCodePath path;
     bool aligned;
     bool f32acc;
-    bool use_mask_opt;
+    uint32_t flags;
 
     bool operator<(const vk_fa_pipeline_state &b) const {
-        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) <
-               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt);
+        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) <
+               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags);
     }
 };
 
@@ -3209,7 +3209,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
     };
 
-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
@@ -3241,7 +3241,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // AMD prefers loading K directly from global memory
         const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
 
-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt};
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
     };
 
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3253,19 +3253,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
         FaCodePath path = fa.first.path; \
         bool aligned = fa.first.aligned; \
         bool f32acc = fa.first.f32acc; \
-        bool use_mask_opt = fa.first.use_mask_opt; \
+        uint32_t flags = fa.first.flags; \
        if (path == FAPATH) { \
            if (aligned) { \
                if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
                } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
                } \
            } else { \
                if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
                } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
                } \
            } \
        } \
@@ -8633,10 +8633,26 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
 
+    float scale = 1.0f;
+    float max_bias = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
     // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
     bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
 
-    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt);
+    uint32_t flags = (use_mask_opt ? 1 : 0) |
+                     (mask != nullptr ? 2 : 0) |
+                     (logit_softcap != 0 ? 4 : 0);
+
+    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags);
 
     vk_pipeline pipeline = nullptr;
 
@@ -8716,18 +8732,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    float scale = 1.0f;
-    float max_bias = 0.0f;
-    float logit_softcap = 0.0f;
-
-    memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
-
-    if (logit_softcap != 0) {
-        scale /= logit_softcap;
-    }
-
     const uint32_t n_head_kv = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -8741,7 +8745,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
     vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
 
-    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
+    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
 
     if (use_mask_opt)
     {
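
The single use_mask_opt boolean becomes a small bitfield, so the mask-opt, mask-present, and logit-softcap decisions are all baked into the pipeline as one specialization constant. A sketch of the packing, mirroring the decode in flash_attn_base.glsl further down (the enum names are illustrative; the diff itself uses the bare 1/2/4 literals):

    #include <cstdint>

    enum : uint32_t {
        FA_FLAG_USE_MASK_OPT  = 1u << 0,  // Flags & 1 -> USE_MASK_OPT
        FA_FLAG_MASK_ENABLE   = 1u << 1,  // Flags & 2 -> MASK_ENABLE
        FA_FLAG_LOGIT_SOFTCAP = 1u << 2,  // Flags & 4 -> LOGIT_SOFTCAP
    };

    static uint32_t fa_pack_flags(bool use_mask_opt, bool has_mask, float logit_softcap) {
        return (use_mask_opt          ? FA_FLAG_USE_MASK_OPT  : 0u) |
               (has_mask              ? FA_FLAG_MASK_ENABLE   : 0u) |
               (logit_softcap != 0.0f ? FA_FLAG_LOGIT_SOFTCAP : 0u);
    }

Because the mask bit now lives in the specialization constant, MASK_ENABLE_BIT no longer needs to be packed into the mask_n_head_log2 push constant, which is why the last hunk above drops it.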

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 3 additions & 3 deletions
@@ -127,7 +127,7 @@ void main() {
             continue;
         }
         // Only load if the block is not all zeros
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
+        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
@@ -181,15 +181,15 @@ void main() {
            }
        }
 
-        if (p.logit_softcap != 0.0f) {
+        if (LOGIT_SOFTCAP) {
            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                    Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
                }
            }
        }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
+        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    float mvf = masksh[c * cols_per_iter + col_tid][r];

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl

Lines changed: 9 additions & 2 deletions
@@ -10,7 +10,11 @@ layout (constant_id = 5) const uint32_t Clamp = 0;
 layout (constant_id = 6) const uint32_t D_split = 16;
 layout (constant_id = 7) const uint32_t SubGroupSize = 32;
 layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
-layout (constant_id = 9) const bool USE_MASK_OPT = false;
+layout (constant_id = 9) const uint32_t Flags = 0;
+
+const bool USE_MASK_OPT = (Flags & 1) != 0;
+const bool MASK_ENABLE = (Flags & 2) != 0;
+const bool LOGIT_SOFTCAP = (Flags & 4) != 0;
 
 // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
 const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -60,7 +64,6 @@ layout (push_constant) uniform parameter {
 } p;
 
 #define SINK_ENABLE_BIT (1<<24)
-#define MASK_ENABLE_BIT (1<<16)
 #define N_LOG2_MASK 0xFFFF
 
 layout (binding = 4) readonly buffer S {float data_s[];};
@@ -237,3 +240,7 @@
     // and breaking the alignment detection.
     m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
 }
+
+// Bias applied to softmax to stay in fp16 range.
+// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
+const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
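
FATTN_KQ_MAX_OFFSET is roughly 3*ln(2), so adding it inside the exponent multiplies every softmax numerator by 2^3 = 8 while leaving the normalised weights unchanged (the constant cancels in the ratio); per the comment and the linked issue, the intent appears to be keeping the fp16 accumulators further from underflow. A tiny standalone check of that identity (illustrative C++, not shader code):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;   // ~= 3*ln(2)
        const float s = -4.0f, M = 0.0f;                   // a score and its row max

        // exp(s - M + 3*ln2) == 8 * exp(s - M), up to the truncation of 0.6931
        std::printf("%f vs %f\n",
                    std::exp(s - M + FATTN_KQ_MAX_OFFSET),
                    8.0f * std::exp(s - M));

        return 0;
    }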

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 3 additions & 3 deletions
@@ -160,7 +160,7 @@ void main() {
             mask_cache[idx] = f16vec4(0);
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (MASK_ENABLE) {
 
            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
                mask_opt_idx = j / 16;
@@ -303,7 +303,7 @@ void main() {
         coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
         barrier();
 
-        if (p.logit_softcap != 0.0f) {
+        if (LOGIT_SOFTCAP) {
            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
                uint32_t c = (idx + tid) / (Br / 4);
                uint32_t r = (idx + tid) % (Br / 4);
@@ -314,7 +314,7 @@
             barrier();
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (MASK_ENABLE) {
            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
                uint32_t c = (idx + tid) / (Br / 4);
                uint32_t r = (idx + tid) % (Br / 4);
