99from tqdm import tqdm
1010
1111from exo .shared .types .worker .shards import PipelineShardMetadata
12+ from exo .utils .memory import log_memory
1213from exo .worker .engines .image .config import ImageModelConfig
1314from exo .worker .engines .image .pipeline .adapter import (
1415 BlockWrapperMode ,
@@ -194,8 +195,14 @@ def generate_image(
194195 GeneratedImage result
195196 """
196197 runtime_config = RuntimeConfig (settings , self .adapter .model .model_config )
198+
197199 latents = self .adapter .create_latents (seed , runtime_config )
200+ mx .eval (latents )
201+ log_memory ("generate_image: after create_latents" )
202+
198203 prompt_data = self .adapter .encode_prompt (prompt )
204+ mx .eval (prompt_data )
205+ log_memory ("generate_image: after encode_prompt" )
199206
200207 latents = self ._run_diffusion_loop (
201208 latents = latents ,
@@ -204,9 +211,14 @@ def generate_image(
204211 seed = seed ,
205212 prompt = prompt ,
206213 )
214+ mx .eval (latents )
215+ log_memory ("generate_image: after diffusion_loop" )
207216
208217 if self .is_last_stage :
209- return self .adapter .decode_latents (latents , runtime_config , seed , prompt )
218+ result = self .adapter .decode_latents (latents , runtime_config , seed , prompt )
219+ mx .eval (result )
220+ log_memory ("generate_image: after decode_latents" )
221+ return result
210222
211223 def _run_diffusion_loop (
212224 self ,
@@ -260,6 +272,7 @@ def _run_diffusion_loop(
260272 )
261273
262274 mx .eval (latents )
275+ log_memory ("after diffusion step" )
263276
264277 except KeyboardInterrupt : # noqa: PERF203
265278 Callbacks .interruption (
@@ -310,16 +323,24 @@ def _forward_pass(
310323 if config is None :
311324 raise ValueError ("config must be provided in kwargs" )
312325 scaled_latents = config .scheduler .scale_model_input (latents , t )
326+ mx .eval (scaled_latents )
327+ log_memory ("after scaling model input" )
313328
314329 hidden_states , encoder_hidden_states = self .adapter .compute_embeddings (
315330 scaled_latents , prompt_embeds
316331 )
332+ mx .eval (hidden_states , encoder_hidden_states )
333+ log_memory ("after computing embeddings" )
317334 text_embeddings = self .adapter .compute_text_embeddings (
318335 t , config , pooled_prompt_embeds , hidden_states = hidden_states
319336 )
337+ mx .eval (text_embeddings )
338+ log_memory ("after computing text embeddings" )
320339 rotary_embeddings = self .adapter .compute_rotary_embeddings (
321340 prompt_embeds , config , ** kwargs
322341 )
342+ mx .eval (rotary_embeddings )
343+ log_memory ("after computing rotary embeddings" )
323344
324345 text_seq_len = prompt_embeds .shape [1 ]
325346
@@ -337,6 +358,8 @@ def _forward_pass(
337358 ** kwargs ,
338359 )
339360
361+ mx .eval (hidden_states , encoder_hidden_states )
362+ log_memory ("after joint blocks" )
340363 # Merge streams
341364 if self .joint_block_wrappers :
342365 hidden_states = self .adapter .merge_streams (
@@ -354,9 +377,15 @@ def _forward_pass(
354377 mode = BlockWrapperMode .CACHING ,
355378 )
356379
380+ mx .eval (hidden_states )
381+ log_memory ("after single blocks" )
382+
357383 # Extract image portion and project
358384 hidden_states = hidden_states [:, text_seq_len :, ...]
359- return self .adapter .final_projection (hidden_states , text_embeddings )
385+ hidden_states = self .adapter .final_projection (hidden_states , text_embeddings )
386+ mx .eval (hidden_states )
387+ log_memory ("after final projection" )
388+ return hidden_states
360389
361390 def _diffusion_step (
362391 self ,
@@ -371,7 +400,10 @@ def _diffusion_step(
371400 configuration and current timestep.
372401 """
373402 if self .group is None :
374- return self ._single_node_step (t , config , latents , prompt_data )
403+ latents = self ._single_node_step (t , config , latents , prompt_data )
404+ mx .eval (latents )
405+ log_memory ("single node step complete" )
406+ return latents
375407 elif t < self .num_sync_steps :
376408 return self ._sync_pipeline (
377409 t ,
@@ -435,6 +467,7 @@ def _single_node_step(
435467 kwargs ,
436468 )
437469
470+ log_memory ("scheduler step" )
438471 return config .scheduler .step (model_output = noise , timestep = t , sample = latents )
439472
440473 def _initialize_kv_caches (
0 commit comments