Skip to content

Commit 0783aaf

Browse files
committed
Measure memory
1 parent a86bb97 commit 0783aaf

File tree

5 files changed

+75
-6
lines changed

5 files changed

+75
-6
lines changed

src/exo/utils/memory.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from time import sleep
2+
3+
import mlx.core as mx
4+
import psutil
5+
6+
from exo.worker.runner.bootstrap import logger
7+
8+
9+
def log_memory(label: str, sleep_seconds: float = 3.0) -> None:
    """Log current/peak MLX GPU memory and the process's resident memory.

    Emits a single INFO line tagged with *label* containing:
    - active: current MLX active memory (``mx.get_active_memory``), in MB
    - peak: MLX peak memory high-water mark (``mx.get_peak_memory``), in MB
    - res: this process's resident set size via ``psutil``, in MB

    Args:
        label: Tag identifying the measurement point in the log line.
        sleep_seconds: Pause after logging; defaults to 3.0 to preserve the
            original behavior. NOTE(review): presumably this gives an external
            sampler (e.g. Activity Monitor) time to observe memory between
            checkpoints — confirm before removing.
    """
    mb = 1024 * 1024  # bytes per mebibyte, computed once
    active_mb = mx.get_active_memory() / mb
    peak_mb = mx.get_peak_memory() / mb
    res_mem = psutil.Process().memory_info().rss / mb
    logger.info(
        f"[MEMORY] {label}: active={active_mb:.1f}MB, peak={peak_mb:.1f}MB, res={res_mem:.1f}MB"
    )
    sleep(sleep_seconds)

src/exo/worker/engines/image/distributed_model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from exo.shared.types.worker.instances import BoundInstance
99
from exo.shared.types.worker.shards import PipelineShardMetadata
10+
from exo.utils.memory import log_memory
1011
from exo.worker.download.download_utils import build_model_path
1112
from exo.worker.engines.image.config import ImageModelConfig
1213
from exo.worker.engines.image.models import (
@@ -43,8 +44,11 @@ def __init__(
4344
quantize: int | None = None,
4445
):
4546
# Get model config and create adapter (adapter owns the model)
47+
mx.metal.reset_peak_memory()
48+
4649
config = get_config_for_model(model_id)
4750
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
51+
log_memory("After model load (adapter created)")
4852

4953
if group is not None:
5054
adapter.slice_transformer_blocks(
@@ -53,6 +57,7 @@ def __init__(
5357
total_joint_blocks=config.joint_block_count,
5458
total_single_blocks=config.single_block_count,
5559
)
60+
log_memory("After block slicing")
5661

5762
# Create diffusion runner (handles both single-node and distributed modes)
5863
num_sync_steps = config.get_num_sync_steps("medium") if group else 0
@@ -63,20 +68,23 @@ def __init__(
6368
shard_metadata=shard_metadata,
6469
num_sync_steps=num_sync_steps,
6570
)
71+
log_memory("After DiffusionRunner creation")
6672

6773
if group is not None:
6874
logger.info("Initialized distributed diffusion runner")
6975

7076
mx.eval(adapter.model.parameters())
77+
log_memory("After mx.eval(parameters)")
7178

72-
# TODO(ciaran): Do we need this?
7379
mx.eval(adapter.model)
80+
log_memory("After mx.eval(model)")
7481

7582
# Synchronize processes before generation to avoid timeout
7683
mx_barrier(group)
7784
logger.info(f"Transformer sharded for rank {group.rank()}")
7885
else:
7986
logger.info("Single-node initialization")
87+
log_memory("Single-node init complete")
8088

8189
object.__setattr__(self, "_config", config)
8290
object.__setattr__(self, "_adapter", adapter)

src/exo/worker/engines/image/generate.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import io
22
from typing import Generator, Literal
33

4+
import mlx.core as mx
45
from PIL import Image
56

67
from exo.shared.types.api import ImageGenerationTaskParams
78
from exo.shared.types.worker.runner_response import ImageGenerationResponse
9+
from exo.utils.memory import log_memory
810
from exo.worker.engines.image.base import ImageGenerator
11+
from exo.worker.runner.bootstrap import logger
912

1013

1114
def parse_size(size_str: str | None) -> tuple[int, int]:
@@ -26,13 +29,16 @@ def parse_size(size_str: str | None) -> tuple[int, int]:
2629

2730

2831
def warmup_image_generator(model: ImageGenerator) -> Image.Image | None:
    """Run one small fixed-seed generation to warm up *model*.

    Memory usage is logged immediately before and after the warmup pass so
    the cost of the first (compile/allocation-heavy) generation is visible.
    Returns the generator's output; NOTE(review): presumably None on
    non-final pipeline stages — confirm against ImageGenerator.generate.
    """
    log_memory("Before warmup generation")
    warmup_image = model.generate(
        prompt="Warmup",
        height=256,
        width=256,
        quality="low",
        seed=2,
    )
    log_memory("After warmup generation")
    return warmup_image
3642

3743

3844
def generate_image(

src/exo/worker/engines/image/models/qwen/adapter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
1717

18+
from exo.utils.memory import log_memory
1819
from exo.worker.engines.image.config import ImageModelConfig
1920
from exo.worker.engines.image.models.base import BaseModelAdapter
2021
from exo.worker.engines.image.pipeline.adapter import (
@@ -307,7 +308,11 @@ def final_projection(
307308
) -> mx.array:
308309
"""Apply final normalization and projection."""
309310
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
310-
return self._transformer.proj_out(hidden_states)
311+
mx.eval(hidden_states)
312+
log_memory("after norm out")
313+
hidden_states = self._transformer.proj_out(hidden_states)
314+
mx.eval(hidden_states)
315+
log_memory("after proj_out")
311316

312317
def get_joint_blocks(self) -> list[JointBlockInterface]:
313318
"""Return all 60 transformer blocks."""

src/exo/worker/engines/image/pipeline/runner.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tqdm import tqdm
1010

1111
from exo.shared.types.worker.shards import PipelineShardMetadata
12+
from exo.utils.memory import log_memory
1213
from exo.worker.engines.image.config import ImageModelConfig
1314
from exo.worker.engines.image.pipeline.adapter import (
1415
BlockWrapperMode,
@@ -194,8 +195,14 @@ def generate_image(
194195
GeneratedImage result
195196
"""
196197
runtime_config = RuntimeConfig(settings, self.adapter.model.model_config)
198+
197199
latents = self.adapter.create_latents(seed, runtime_config)
200+
mx.eval(latents)
201+
log_memory("generate_image: after create_latents")
202+
198203
prompt_data = self.adapter.encode_prompt(prompt)
204+
mx.eval(prompt_data)
205+
log_memory("generate_image: after encode_prompt")
199206

200207
latents = self._run_diffusion_loop(
201208
latents=latents,
@@ -204,9 +211,14 @@ def generate_image(
204211
seed=seed,
205212
prompt=prompt,
206213
)
214+
mx.eval(latents)
215+
log_memory("generate_image: after diffusion_loop")
207216

208217
if self.is_last_stage:
209-
return self.adapter.decode_latents(latents, runtime_config, seed, prompt)
218+
result = self.adapter.decode_latents(latents, runtime_config, seed, prompt)
219+
mx.eval(result)
220+
log_memory("generate_image: after decode_latents")
221+
return result
210222

211223
def _run_diffusion_loop(
212224
self,
@@ -260,6 +272,7 @@ def _run_diffusion_loop(
260272
)
261273

262274
mx.eval(latents)
275+
log_memory("after diffusion step")
263276

264277
except KeyboardInterrupt: # noqa: PERF203
265278
Callbacks.interruption(
@@ -310,16 +323,24 @@ def _forward_pass(
310323
if config is None:
311324
raise ValueError("config must be provided in kwargs")
312325
scaled_latents = config.scheduler.scale_model_input(latents, t)
326+
mx.eval(scaled_latents)
327+
log_memory("after scaling model input")
313328

314329
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
315330
scaled_latents, prompt_embeds
316331
)
332+
mx.eval(hidden_states, encoder_hidden_states)
333+
log_memory("after computing embeddings")
317334
text_embeddings = self.adapter.compute_text_embeddings(
318335
t, config, pooled_prompt_embeds, hidden_states=hidden_states
319336
)
337+
mx.eval(text_embeddings)
338+
log_memory("after computing text embeddings")
320339
rotary_embeddings = self.adapter.compute_rotary_embeddings(
321340
prompt_embeds, config, **kwargs
322341
)
342+
mx.eval(rotary_embeddings)
343+
log_memory("after computing rotary embeddings")
323344

324345
text_seq_len = prompt_embeds.shape[1]
325346

@@ -337,6 +358,8 @@ def _forward_pass(
337358
**kwargs,
338359
)
339360

361+
mx.eval(hidden_states, encoder_hidden_states)
362+
log_memory("after joint blocks")
340363
# Merge streams
341364
if self.joint_block_wrappers:
342365
hidden_states = self.adapter.merge_streams(
@@ -354,9 +377,15 @@ def _forward_pass(
354377
mode=BlockWrapperMode.CACHING,
355378
)
356379

380+
mx.eval(hidden_states)
381+
log_memory("after single blocks")
382+
357383
# Extract image portion and project
358384
hidden_states = hidden_states[:, text_seq_len:, ...]
359-
return self.adapter.final_projection(hidden_states, text_embeddings)
385+
hidden_states = self.adapter.final_projection(hidden_states, text_embeddings)
386+
mx.eval(hidden_states)
387+
log_memory("after final projection")
388+
return hidden_states
360389

361390
def _diffusion_step(
362391
self,
@@ -371,7 +400,10 @@ def _diffusion_step(
371400
configuration and current timestep.
372401
"""
373402
if self.group is None:
374-
return self._single_node_step(t, config, latents, prompt_data)
403+
latents = self._single_node_step(t, config, latents, prompt_data)
404+
mx.eval(latents)
405+
log_memory("single node step complete")
406+
return latents
375407
elif t < self.num_sync_steps:
376408
return self._sync_pipeline(
377409
t,
@@ -435,6 +467,7 @@ def _single_node_step(
435467
kwargs,
436468
)
437469

470+
log_memory("scheduler step")
438471
return config.scheduler.step(model_output=noise, timestep=t, sample=latents)
439472

440473
def _initialize_kv_caches(

0 commit comments

Comments (0)