File tree — 3 files changed: +17 −2 lines changed
src/exo/worker/engines/image — 3 files changed: +17 −2 lines changed

Original file line number | Diff line number | Diff line change
24  24      default_steps = {"low": 10, "medium": 25, "high": 50},
25  25      num_sync_steps_factor = 0.125,  # ~3 sync steps for medium (30 steps)
26  26      uses_attention_mask = True,  # Qwen uses encoder_hidden_states_mask
27      -   guidance_scale = None,  # Set to None or < 1.0 to disable CFG
    27  +   guidance_scale = 3.5,  # Set to None or < 1.0 to disable CFG
28  28  )
29  29
30  30  # Qwen-Image-Edit uses the same architecture but different processing pipeline
45  45      default_steps = {"low": 10, "medium": 25, "high": 50},
46  46      num_sync_steps_factor = 0.125,
47  47      uses_attention_mask = True,
48      -   guidance_scale = None,
    48  +   guidance_scale = 3.5,
49  49  )
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,14 @@ def reset_cache(self) -> None:
97   97      """Reset the KV cache. Call at the start of a new generation."""
98   98      self._kv_cache = None
99   99
    100  +   def set_encoder_mask(self, mask: mx.array | None) -> None:  # noqa: B027
    101  +       """Set the encoder hidden states mask for attention.
    102  +
    103  +       Override in subclasses that use attention masks (e.g., Qwen).
    104  +       Default is a no-op for models that don't use masks (e.g., Flux).
    105  +       """
    106  +       del mask  # Unused in base class
    107  +
100  108      def __call__(
101  109          self,
102  110          hidden_states: mx.array,
420420 # Ensure wrappers are initialized (lazy - needs text_seq_len)
421421 self ._ensure_wrappers (text_seq_len , encoder_hidden_states_mask )
422422
423+ # Update masks on all joint block wrappers for this pass.
424+ # This is necessary for CFG where we run positive and negative passes
425+ # with different masks. Qwen uses masks; Flux doesn't.
426+ if self .joint_block_wrappers and encoder_hidden_states_mask is not None :
427+ for wrapper in self .joint_block_wrappers :
428+ wrapper .set_encoder_mask (encoder_hidden_states_mask )
429+
423430 scaled_latents = config .scheduler .scale_model_input (latents , t )
424431
425432 # For edit mode: concatenate with conditioning latents
You can’t perform that action at this time.
0 commit comments