10 | 10 | from exo.worker.engines.image.config import ImageModelConfig |
11 | 11 |
12 | 12 | if TYPE_CHECKING: |
13 | | - from exo.worker.engines.image.pipeline.adapter import ( |
14 | | - BlockWrapperMode, |
15 | | - JointBlockInterface, |
16 | | - SingleBlockInterface, |
| 13 | + from exo.worker.engines.image.pipeline.block_wrapper import ( |
| 14 | + JointBlockWrapper, |
| 15 | + SingleBlockWrapper, |
17 | 16 | ) |
18 | | - from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache |
19 | 17 |
20 | 18 |
21 | 19 | class PromptData(ABC): |
@@ -128,13 +126,35 @@ def _get_latent_creator(self) -> type: |
128 | 126 | ... |
129 | 127 |
130 | 128 | @abstractmethod |
131 | | - def get_joint_blocks(self) -> list["JointBlockInterface"]: |
132 | | - """Get the list of joint transformer blocks from the model.""" |
| 129 | + def get_joint_block_wrappers( |
| 130 | + self, |
| 131 | + text_seq_len: int, |
| 132 | + encoder_hidden_states_mask: mx.array | None = None, |
| 133 | + ) -> list["JointBlockWrapper"]: |
| 134 | + """Create wrapped joint transformer blocks with pipefusion support. |
| 135 | +
| 136 | + Args: |
| 137 | + text_seq_len: Number of text tokens (constant for generation) |
| 138 | + encoder_hidden_states_mask: Attention mask for text (Qwen only) |
| 139 | +
| 140 | + Returns: |
| 141 | + List of wrapped joint blocks ready for pipefusion |
| 142 | + """ |
133 | 143 | ... |
134 | 144 |
135 | 145 | @abstractmethod |
136 | | - def get_single_blocks(self) -> list["SingleBlockInterface"]: |
137 | | - """Get the list of single transformer blocks from the model.""" |
| 146 | + def get_single_block_wrappers( |
| 147 | + self, |
| 148 | + text_seq_len: int, |
| 149 | + ) -> list["SingleBlockWrapper"]: |
| 150 | + """Create wrapped single transformer blocks with pipefusion support. |
| 151 | +
| 152 | + Args: |
| 153 | + text_seq_len: Number of text tokens (constant for generation) |
| 154 | +
| 155 | + Returns: |
| 156 | + List of wrapped single blocks ready for pipefusion |
| 157 | + """ |
138 | 158 | ... |
139 | 159 |
140 | 160 | @abstractmethod |
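The new hooks bind the per-generation constants (`text_seq_len`, and Qwen's `encoder_hidden_states_mask`) once at wrapper-creation time rather than threading them through every block call. A minimal sketch of what a concrete adapter might do with these hooks — the `JointBlockWrapper`/`SingleBlockWrapper` constructor arguments and the `self.model.transformer_blocks`/`self.model.single_transformer_blocks` attributes are assumptions for illustration only, since `block_wrapper.py` is not part of this diff:

```python
import mlx.core as mx

from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)


class ExampleAdapter:  # would subclass the adapter ABC defined in this file
    def get_joint_block_wrappers(
        self,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ) -> list[JointBlockWrapper]:
        # Bind generation-constant state once, per block; constructor
        # arguments here are assumed, not taken from this diff.
        return [
            JointBlockWrapper(
                block=block,
                text_seq_len=text_seq_len,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
            )
            for block in self.model.transformer_blocks
        ]

    def get_single_block_wrappers(
        self,
        text_seq_len: int,
    ) -> list[SingleBlockWrapper]:
        return [
            SingleBlockWrapper(block=block, text_seq_len=text_seq_len)
            for block in self.model.single_transformer_blocks
        ]
```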
@@ -285,81 +305,13 @@ def compute_rotary_embeddings( |
285 | 305 | """ |
286 | 306 | ... |
287 | 307 |
288 | | - @abstractmethod |
289 | | - def apply_joint_block( |
290 | | - self, |
291 | | - block: "JointBlockInterface", |
292 | | - hidden_states: mx.array, |
293 | | - encoder_hidden_states: mx.array, |
294 | | - text_embeddings: mx.array, |
295 | | - rotary_embeddings: Any, |
296 | | - kv_cache: "ImagePatchKVCache | None", |
297 | | - mode: "BlockWrapperMode", |
298 | | - text_seq_len: int, |
299 | | - patch_start: int | None = None, |
300 | | - patch_end: int | None = None, |
301 | | - encoder_hidden_states_mask: mx.array | None = None, |
302 | | - block_idx: int | None = None, |
303 | | - ) -> tuple[mx.array, mx.array]: |
304 | | - """Apply a joint transformer block. |
305 | | -
306 | | - Args: |
307 | | - block: The joint transformer block |
308 | | - hidden_states: Image hidden states |
309 | | - encoder_hidden_states: Text hidden states |
310 | | - text_embeddings: Conditioning embeddings |
311 | | - rotary_embeddings: Rotary position embeddings (format varies by model) |
312 | | - kv_cache: KV cache (None if not using cache) |
313 | | - mode: CACHING or PATCHED mode |
314 | | - text_seq_len: Text sequence length |
315 | | - patch_start: Start index for patched mode |
316 | | - patch_end: End index for patched mode |
317 | | - encoder_hidden_states_mask: Attention mask for text (Qwen) |
318 | | - block_idx: Block index for debugging (Qwen) |
319 | | -
320 | | - Returns: |
321 | | - Tuple of (encoder_hidden_states, hidden_states) |
322 | | - """ |
323 | | - ... |
324 | | - |
325 | 308 | def merge_streams( |
326 | 309 | self, |
327 | 310 | hidden_states: mx.array, |
328 | 311 | encoder_hidden_states: mx.array, |
329 | 312 | ) -> mx.array: |
330 | 313 | return mx.concatenate([encoder_hidden_states, hidden_states], axis=1) |
331 | 314 |
332 | | - @abstractmethod |
333 | | - def apply_single_block( |
334 | | - self, |
335 | | - block: "SingleBlockInterface", |
336 | | - hidden_states: mx.array, |
337 | | - text_embeddings: mx.array, |
338 | | - rotary_embeddings: mx.array, |
339 | | - kv_cache: "ImagePatchKVCache | None", |
340 | | - mode: "BlockWrapperMode", |
341 | | - text_seq_len: int, |
342 | | - patch_start: int | None = None, |
343 | | - patch_end: int | None = None, |
344 | | - ) -> mx.array: |
345 | | - """Apply a single transformer block. |
346 | | -
347 | | - Args: |
348 | | - block: The single transformer block |
349 | | - hidden_states: Concatenated [text + image] hidden states |
350 | | - text_embeddings: Conditioning embeddings |
351 | | - rotary_embeddings: Rotary position embeddings |
352 | | - kv_cache: KV cache (None if not using cache) |
353 | | - mode: CACHING or PATCHED mode |
354 | | - text_seq_len: Text sequence length |
355 | | - patch_start: Start index for patched mode |
356 | | - patch_end: End index for patched mode |
357 | | -
358 | | - Returns: |
359 | | - Output hidden states |
360 | | - """ |
361 | | - ... |
362 | | - |
363 | 315 | @abstractmethod |
364 | 316 | def apply_guidance( |
365 | 317 | self, |
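`merge_streams` stays concrete because its [text | image] ordering is what makes a bare `text_seq_len` sufficient for the single-block wrappers to locate the image patch region once the streams are fused. A runnable illustration of that layout, assuming MLX is installed; all shape constants below are made up:

```python
import mlx.core as mx

# merge_streams concatenates [text | image] along the sequence axis (axis=1),
# exactly as in the method above.
batch, text_seq_len, num_image_tokens, dim = 1, 7, 16, 8
encoder_hidden_states = mx.zeros((batch, text_seq_len, dim))  # text stream
hidden_states = mx.ones((batch, num_image_tokens, dim))       # image stream

merged = mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
assert merged.shape == (batch, text_seq_len + num_image_tokens, dim)

# A patch [patch_start, patch_end) over the image tokens therefore maps to
# merged[:, text_seq_len + patch_start : text_seq_len + patch_end, :],
# which is why the single-block wrappers only need text_seq_len bound once.
patch_start, patch_end = 0, 4
patch = merged[:, text_seq_len + patch_start : text_seq_len + patch_end, :]
assert patch.shape == (batch, patch_end - patch_start, dim)
```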