Skip to content

Commit 01148d5

Browse files
committed
Enable CFG for Qwen-Image
1 parent 29191b0 commit 01148d5

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

src/exo/worker/engines/image/models/qwen/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
default_steps={"low": 10, "medium": 25, "high": 50},
2525
num_sync_steps_factor=0.125, # ~3 sync steps for medium (30 steps)
2626
uses_attention_mask=True, # Qwen uses encoder_hidden_states_mask
27-
guidance_scale=None, # Set to None or < 1.0 to disable CFG
27+
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
2828
)
2929

3030
# Qwen-Image-Edit uses the same architecture but different processing pipeline
@@ -45,5 +45,5 @@
4545
default_steps={"low": 10, "medium": 25, "high": 50},
4646
num_sync_steps_factor=0.125,
4747
uses_attention_mask=True,
48-
guidance_scale=None,
48+
guidance_scale=3.5,
4949
)

src/exo/worker/engines/image/pipeline/block_wrapper.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ def reset_cache(self) -> None:
9797
"""Reset the KV cache. Call at the start of a new generation."""
9898
self._kv_cache = None
9999

100+
def set_encoder_mask(self, mask: mx.array | None) -> None:  # noqa: B027
    """Set the encoder hidden states mask used during attention.

    This is an overridable hook: the base implementation intentionally
    does nothing, so mask-free models (e.g., Flux) can ignore the call,
    while mask-aware models (e.g., Qwen) override it to store the mask.
    """
    # Explicitly discard the argument — the default wrapper keeps no mask state.
    del mask
107+
100108
def __call__(
101109
self,
102110
hidden_states: mx.array,

src/exo/worker/engines/image/pipeline/runner.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,13 @@ def _forward_pass(
420420
# Ensure wrappers are initialized (lazy - needs text_seq_len)
421421
self._ensure_wrappers(text_seq_len, encoder_hidden_states_mask)
422422

423+
# Update masks on all joint block wrappers for this pass.
424+
# This is necessary for CFG where we run positive and negative passes
425+
# with different masks. Qwen uses masks; Flux doesn't.
426+
if self.joint_block_wrappers and encoder_hidden_states_mask is not None:
427+
for wrapper in self.joint_block_wrappers:
428+
wrapper.set_encoder_mask(encoder_hidden_states_mask)
429+
423430
scaled_latents = config.scheduler.scale_model_input(latents, t)
424431

425432
# For edit mode: concatenate with conditioning latents

0 commit comments

Comments (0)