Skip to content

Commit ad5703f

Browse files
committed
Use transformer block wrapper classes
1 parent ad6f2bf commit ad5703f

File tree

10 files changed

+1196
-1303
lines changed

10 files changed

+1196
-1303
lines changed

src/exo/worker/engines/image/models/base.py

Lines changed: 29 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,10 @@
1010
from exo.worker.engines.image.config import ImageModelConfig
1111

1212
if TYPE_CHECKING:
13-
from exo.worker.engines.image.pipeline.adapter import (
14-
BlockWrapperMode,
15-
JointBlockInterface,
16-
SingleBlockInterface,
13+
from exo.worker.engines.image.pipeline.block_wrapper import (
14+
JointBlockWrapper,
15+
SingleBlockWrapper,
1716
)
18-
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
1917

2018

2119
class PromptData(ABC):
@@ -128,13 +126,35 @@ def _get_latent_creator(self) -> type:
128126
...
129127

130128
@abstractmethod
131-
def get_joint_blocks(self) -> list["JointBlockInterface"]:
132-
"""Get the list of joint transformer blocks from the model."""
129+
def get_joint_block_wrappers(
130+
self,
131+
text_seq_len: int,
132+
encoder_hidden_states_mask: mx.array | None = None,
133+
) -> list["JointBlockWrapper"]:
134+
"""Create wrapped joint transformer blocks with pipefusion support.
135+
136+
Args:
137+
text_seq_len: Number of text tokens (constant for generation)
138+
encoder_hidden_states_mask: Attention mask for text (Qwen only)
139+
140+
Returns:
141+
List of wrapped joint blocks ready for pipefusion
142+
"""
133143
...
134144

135145
@abstractmethod
136-
def get_single_blocks(self) -> list["SingleBlockInterface"]:
137-
"""Get the list of single transformer blocks from the model."""
146+
def get_single_block_wrappers(
147+
self,
148+
text_seq_len: int,
149+
) -> list["SingleBlockWrapper"]:
150+
"""Create wrapped single transformer blocks with pipefusion support.
151+
152+
Args:
153+
text_seq_len: Number of text tokens (constant for generation)
154+
155+
Returns:
156+
List of wrapped single blocks ready for pipefusion
157+
"""
138158
...
139159

140160
@abstractmethod
@@ -285,81 +305,13 @@ def compute_rotary_embeddings(
285305
"""
286306
...
287307

288-
@abstractmethod
289-
def apply_joint_block(
290-
self,
291-
block: "JointBlockInterface",
292-
hidden_states: mx.array,
293-
encoder_hidden_states: mx.array,
294-
text_embeddings: mx.array,
295-
rotary_embeddings: Any,
296-
kv_cache: "ImagePatchKVCache | None",
297-
mode: "BlockWrapperMode",
298-
text_seq_len: int,
299-
patch_start: int | None = None,
300-
patch_end: int | None = None,
301-
encoder_hidden_states_mask: mx.array | None = None,
302-
block_idx: int | None = None,
303-
) -> tuple[mx.array, mx.array]:
304-
"""Apply a joint transformer block.
305-
306-
Args:
307-
block: The joint transformer block
308-
hidden_states: Image hidden states
309-
encoder_hidden_states: Text hidden states
310-
text_embeddings: Conditioning embeddings
311-
rotary_embeddings: Rotary position embeddings (format varies by model)
312-
kv_cache: KV cache (None if not using cache)
313-
mode: CACHING or PATCHED mode
314-
text_seq_len: Text sequence length
315-
patch_start: Start index for patched mode
316-
patch_end: End index for patched mode
317-
encoder_hidden_states_mask: Attention mask for text (Qwen)
318-
block_idx: Block index for debugging (Qwen)
319-
320-
Returns:
321-
Tuple of (encoder_hidden_states, hidden_states)
322-
"""
323-
...
324-
325308
def merge_streams(
326309
self,
327310
hidden_states: mx.array,
328311
encoder_hidden_states: mx.array,
329312
) -> mx.array:
330313
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
331314

332-
@abstractmethod
333-
def apply_single_block(
334-
self,
335-
block: "SingleBlockInterface",
336-
hidden_states: mx.array,
337-
text_embeddings: mx.array,
338-
rotary_embeddings: mx.array,
339-
kv_cache: "ImagePatchKVCache | None",
340-
mode: "BlockWrapperMode",
341-
text_seq_len: int,
342-
patch_start: int | None = None,
343-
patch_end: int | None = None,
344-
) -> mx.array:
345-
"""Apply a single transformer block.
346-
347-
Args:
348-
block: The single transformer block
349-
hidden_states: Concatenated [text + image] hidden states
350-
text_embeddings: Conditioning embeddings
351-
rotary_embeddings: Rotary position embeddings
352-
kv_cache: KV cache (None if not using cache)
353-
mode: CACHING or PATCHED mode
354-
text_seq_len: Text sequence length
355-
patch_start: Start index for patched mode
356-
patch_end: End index for patched mode
357-
358-
Returns:
359-
Output hidden states
360-
"""
361-
...
362-
363315
@abstractmethod
364316
def apply_guidance(
365317
self,

0 commit comments

Comments
 (0)