Skip to content

Commit 73911be

Browse files
leejetSkutteOleg
authored andcommitted
add support for flux2 klein 4b
1 parent 4e29a3b commit 73911be

File tree

5 files changed

+33
-13
lines changed

5 files changed

+33
-13
lines changed

conditioner.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,9 +1614,9 @@ struct LLMEmbedder : public Conditioner {
16141614
bool enable_vision = false)
16151615
: version(version) {
16161616
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
1617-
if (sd_version_is_flux2(version)) {
1617+
if (version == VERSION_FLUX2) {
16181618
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
1619-
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
1619+
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
16201620
arch = LLM::LLMArch::QWEN3;
16211621
}
16221622
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
@@ -1771,7 +1771,7 @@ struct LLMEmbedder : public Conditioner {
17711771
prompt_attn_range.second = static_cast<int>(prompt.size());
17721772

17731773
prompt += "<|im_end|>\n<|im_start|>assistant\n";
1774-
} else if (sd_version_is_flux2(version)) {
1774+
} else if (version == VERSION_FLUX2) {
17751775
prompt_template_encode_start_idx = 0;
17761776
out_layers = {10, 20, 30};
17771777

@@ -1793,17 +1793,17 @@ struct LLMEmbedder : public Conditioner {
17931793
prompt_attn_range.second = static_cast<int>(prompt.size());
17941794

17951795
prompt += "<|im_end|>\n<|im_start|>assistant\n";
1796-
} else if (sd_version_is_flux2(version)) {
1796+
} else if (version == VERSION_FLUX2_KLEIN) {
17971797
prompt_template_encode_start_idx = 0;
1798-
out_layers = {10, 20, 30};
1798+
out_layers = {9, 18, 27};
17991799

1800-
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
1800+
prompt = "<|im_start|>user\n";
18011801

18021802
prompt_attn_range.first = static_cast<int>(prompt.size());
18031803
prompt += conditioner_params.text;
18041804
prompt_attn_range.second = static_cast<int>(prompt.size());
18051805

1806-
prompt += "[/INST]";
1806+
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
18071807
} else if (version == VERSION_OVIS_IMAGE) {
18081808
prompt_template_encode_start_idx = 28;
18091809
max_length = prompt_template_encode_start_idx + 256;

flux.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,10 +1291,16 @@ namespace Flux {
12911291
flux_params.context_in_dim = 2048;
12921292
flux_params.vec_in_dim = 0;
12931293
} else if (sd_version_is_flux2(version)) {
1294-
flux_params.context_in_dim = 15360;
1294+
if (version == VERSION_FLUX2_KLEIN) {
1295+
flux_params.context_in_dim = 7680;
1296+
flux_params.hidden_size = 3072;
1297+
flux_params.num_heads = 24;
1298+
} else {
1299+
flux_params.context_in_dim = 15360;
1300+
flux_params.hidden_size = 6144;
1301+
flux_params.num_heads = 48;
1302+
}
12951303
flux_params.in_channels = 128;
1296-
flux_params.hidden_size = 6144;
1297-
flux_params.num_heads = 48;
12981304
flux_params.patch_size = 1;
12991305
flux_params.out_channels = 128;
13001306
flux_params.mlp_ratio = 3.f;

model.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,8 @@ SDVersion ModelLoader::get_sd_version() {
10341034

10351035
bool is_xl = false;
10361036
bool is_flux = false;
1037+
bool is_flux2 = false;
1038+
bool has_single_block_47 = false;
10371039
bool is_wan = false;
10381040
int64_t patch_embedding_channels = 0;
10391041
bool has_img_emb = false;
@@ -1055,7 +1057,10 @@ SDVersion ModelLoader::get_sd_version() {
10551057
return VERSION_QWEN_IMAGE;
10561058
}
10571059
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
1058-
return VERSION_FLUX2;
1060+
is_flux2 = true;
1061+
}
1062+
if (tensor_storage.name.find("single_blocks.47.linear1.weight") != std::string::npos) {
1063+
has_single_block_47 = true;
10591064
}
10601065
if (tensor_storage.name.find("model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight") != std::string::npos) {
10611066
return VERSION_OVIS_IMAGE;
@@ -1138,7 +1143,7 @@ SDVersion ModelLoader::get_sd_version() {
11381143
return VERSION_SDXL;
11391144
}
11401145

1141-
if (is_flux) {
1146+
if (is_flux && !is_flux2) {
11421147
if (input_block_weight.ne[0] == 384) {
11431148
return VERSION_FLUX_FILL;
11441149
}
@@ -1151,6 +1156,13 @@ SDVersion ModelLoader::get_sd_version() {
11511156
return VERSION_FLUX;
11521157
}
11531158

1159+
if (is_flux2) {
1160+
if (has_single_block_47) {
1161+
return VERSION_FLUX2;
1162+
}
1163+
return VERSION_FLUX2_KLEIN;
1164+
}
1165+
11541166
if (token_embedding_weight.ne[0] == 768) {
11551167
if (is_inpaint) {
11561168
return VERSION_SD1_INPAINT;

model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ enum SDVersion {
4545
VERSION_WAN2_2_TI2V,
4646
VERSION_QWEN_IMAGE,
4747
VERSION_FLUX2,
48+
VERSION_FLUX2_KLEIN,
4849
VERSION_Z_IMAGE,
4950
VERSION_OVIS_IMAGE,
5051
VERSION_COUNT,
@@ -100,7 +101,7 @@ static inline bool sd_version_is_flux(SDVersion version) {
100101
}
101102

102103
static inline bool sd_version_is_flux2(SDVersion version) {
103-
if (version == VERSION_FLUX2) {
104+
if (version == VERSION_FLUX2 || version == VERSION_FLUX2_KLEIN) {
104105
return true;
105106
}
106107
return false;

stable-diffusion.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ const char* model_version_to_str[] = {
7373
"Wan 2.2 TI2V",
7474
"Qwen Image",
7575
"Flux.2",
76+
"Flux.2 klein"
7677
"Z-Image",
7778
"Ovis Image",
7879
};

0 commit comments

Comments
 (0)