@@ -418,19 +418,19 @@ enum FaCodePath {
418418};
419419
// Identifies one flash-attention pipeline variant by its head sizes (HSK, HSV),
// the small_rows / small_cache switches, the code path, and the aligned /
// f32acc / flags selectors.
// NOTE(review): CREATE_FA reads this as `fa.first` with the pipeline as
// `fa.second`, so it is presumably the key of an ordered container — confirm.
420420struct vk_fa_pipeline_state {
    // This change replaces the trailing `bool use_mask_opt` constructor
    // argument with a `uint32_t flags` word.
421- vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt )
422- : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt ) {}
421+ vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags )
422+ : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags ) {}
423423
424424 uint32_t HSK, HSV;
425425 bool small_rows, small_cache;
426426 FaCodePath path;
    // In CREATE_FA, `aligned` selects the "_aligned" pipeline names and
    // `f32acc` selects the f32acc shader over the f16acc one.
427427 bool aligned;
428428 bool f32acc;
    // `flags` is forwarded as the last entry of the vector returned by
    // fa_spec_constants. ggml_vk_flash_attn builds it as:
    // bit 0 (1) = use_mask_opt, bit 1 (2) = mask present,
    // bit 2 (4) = logit_softcap != 0.
429- bool use_mask_opt ;
429+ uint32_t flags ;
430430
    // Lexicographic ordering over every member via std::tie, in declaration
    // order, so two states compare equivalent only when all fields match.
431431 bool operator<(const vk_fa_pipeline_state &b) const {
432- return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt ) <
433- std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt );
432+ return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags ) <
433+ std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags );
434434 }
435435};
436436
@@ -3209,7 +3209,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
32093209 return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
32103210 };
32113211
3212- auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt ) -> std::vector<uint32_t> {
3212+ auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags ) -> std::vector<uint32_t> {
32133213 // For large number of rows, 128 invocations seems to work best.
32143214 // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
32153215 // can't use 256 for D==80.
@@ -3241,7 +3241,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
32413241 // AMD prefers loading K directly from global memory
32423242 const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
32433243
3244- return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt };
3244+ return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags };
32453245 };
32463246
32473247#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3253,19 +3253,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
32533253 FaCodePath path = fa.first.path; \
32543254 bool aligned = fa.first.aligned; \
32553255 bool f32acc = fa.first.f32acc; \
3256- bool use_mask_opt = fa.first.use_mask_opt ; \
3256+ uint32_t flags = fa.first.flags ; \
32573257 if (path == FAPATH) { \
32583258 if (aligned) { \
32593259 if (f32acc) { \
3260- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt ), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3260+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags ), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
32613261 } else { \
3262- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt ), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3262+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags ), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
32633263 } \
32643264 } else { \
32653265 if (f32acc) { \
3266- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt ), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3266+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags ), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
32673267 } else { \
3268- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt ), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3268+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags ), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
32693269 } \
32703270 } \
32713271 } \
@@ -8633,10 +8633,26 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
86338633
86348634 bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
86358635
8636+ float scale = 1.0f;
8637+ float max_bias = 0.0f;
8638+ float logit_softcap = 0.0f;
8639+
8640+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
8641+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
8642+ memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
8643+
8644+ if (logit_softcap != 0) {
8645+ scale /= logit_softcap;
8646+ }
8647+
86368648 // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
86378649 bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
86388650
8639- vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt);
8651+ uint32_t flags = (use_mask_opt ? 1 : 0) |
8652+ (mask != nullptr ? 2 : 0) |
8653+ (logit_softcap != 0 ? 4 : 0);
8654+
8655+ vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags);
86408656
86418657 vk_pipeline pipeline = nullptr;
86428658
@@ -8716,18 +8732,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
87168732 }
87178733 }
87188734
8719- float scale = 1.0f;
8720- float max_bias = 0.0f;
8721- float logit_softcap = 0.0f;
8722-
8723- memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
8724- memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
8725- memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
8726-
8727- if (logit_softcap != 0) {
8728- scale /= logit_softcap;
8729- }
8730-
87318735 const uint32_t n_head_kv = neq2;
87328736 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
87338737 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -8741,7 +8745,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
87418745 vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
87428746 vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
87438747
8744- uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
8748+ uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
87458749
87468750 if (use_mask_opt)
87478751 {
0 commit comments