Conversation
photoroman
left a comment
I did a brief initial pass. This looks really great! Thank you!
The main missing part for me is more docs. I think we should add docstrings to all public classes and methods. We should probably also add some docs and examples in docs/source/en/api/pipelines/mirage.md.
```python
if vae_type == "flux":
    config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated/transformer/config.json"
elif vae_type == "dc-ae":
    config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated/transformer/config.json"
```
Should we change these hardcoded paths?
```python
#!/usr/bin/env python3
"""
Script to convert Mirage checkpoint from original codebase to diffusers format.
"""
```
Should we release this script or put it into the computer vision repo and release the converted checkpoints?
I think the idea is to keep this only internally, right?
Looking at other scripts in this folder, it seems that most companies actually put these kinds of scripts in diffusers.
```python
mapping = {}

# RMSNorm: scale -> weight
for i in range(16):  # 16 layers
```
Use a `MIRAGE_NUM_LAYERS: int = 16` constant at the top?
I changed it to come from a config instead.
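A minimal sketch of that change, with a hypothetical `num_layers` config key and an illustrative key-name pattern (neither is confirmed by the diff):

```python
def build_norm_mapping(num_layers: int) -> dict:
    # RMSNorm rename, scale -> weight, one entry per transformer layer.
    # The "layers.{i}.norm" pattern is illustrative, not the real key layout.
    return {f"layers.{i}.norm.scale": f"layers.{i}.norm.weight" for i in range(num_layers)}

# num_layers would then come from the transformer config rather than a hardcoded 16,
# e.g. something like: num_layers = config["num_layers"]  (assumed key name)
```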
```python
if vae_type == "flux":
    ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated"
else:  # dc-ae
    ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated"
```
Should we change these hardcoded paths?
I removed all dependencies on the previous ref pipeline; this was a mistake.
```python
return mapping


def convert_checkpoint_parameters(old_state_dict: dict) -> dict:
```
Probably good to use more specific types, like `Dict[str, str]`. Import `Dict` from `typing`, because the diffusers library supports Python 3.8, and built-in types with generics (e.g. `dict[str, str]`) are only supported since Python 3.9.
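A small illustration of the suggested annotation; the function body here is a made-up stand-in, only the `typing.Dict` usage is the point:

```python
from typing import Dict


def rename_keys(old_keys: Dict[str, str]) -> Dict[str, str]:
    # Dict[str, str] from typing parses on Python 3.8; the built-in generic
    # dict[str, str] is only valid as an annotation from Python 3.9 onward.
    return {k.replace("scale", "weight"): v for k, v in old_keys.items()}
```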
```python
logger = logging.get_logger(__name__)


def get_image_ids(bs: int, h: int, w: int, patch_size: int, device: torch.device) -> Tensor:
```
- Use more readable variable names: `batch_size`, `height`, `width`.
- Add docstrings for all public methods and classes.
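For example, the renamed signature with a docstring could look like this; the body is a hypothetical sketch of per-patch position ids, not the pipeline's actual implementation:

```python
import torch
from torch import Tensor


def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> Tensor:
    """Return per-patch (row, col) position ids with shape (batch_size, num_patches, 2).

    Illustrative only: the real helper may encode positions differently.
    """
    rows = torch.arange(height // patch_size, device=device)
    cols = torch.arange(width // patch_size, device=device)
    # (H', W', 2) grid of coordinates, flattened and repeated per batch element.
    grid = torch.stack(torch.meshgrid(rows, cols, indexing="ij"), dim=-1)
    return grid.reshape(1, -1, 2).expand(batch_size, -1, -1).float()
```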
```python
assert attention_mask.dim() == 2, f"Unsupported attention_mask shape: {attention_mask.shape}"
assert attention_mask.shape[-1] == l_txt, (
    f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}"
)
```
Should these be checked as `if` conditions and raised as `ValueError`? Usually asserts are there to catch programming errors, not for input validation.
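The asserts rewritten as explicit validation, wrapped in a standalone helper for illustration (the helper name is made up; the messages match the quoted snippet). Unlike asserts, these checks also survive `python -O`:

```python
import torch
from torch import Tensor


def validate_attention_mask(attention_mask: Tensor, l_txt: int) -> None:
    # Input validation via ValueError instead of assert statements.
    if attention_mask.dim() != 2:
        raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
    if attention_mask.shape[-1] != l_txt:
        raise ValueError(
            f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}"
        )
```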
```python
    vae ([`AutoencoderKL`] or [`AutoencoderDC`]):
        Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression).
    """
```
I think in addition to this, we need to add docs about Mirage in docs/source/en/api/pipelines/mirage.md. Have a look at the section "Adding a new pipeline/scheduler" in docs/README.md.
```python
# 0. Default height and width to transformer config
height = height or 256
width = width or 256
```
nit: put 256 into a constant.
I specify it in the checkpoint config now, and fall back to a constant when it's not specified.
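A sketch of that fallback logic, with assumed names (`DEFAULT_SAMPLE_SIZE`, `config_sample_size` are illustrative, not taken from the PR):

```python
DEFAULT_SAMPLE_SIZE = 256  # assumed constant name; used only when the config has no value


def resolve_size(height, width, config_sample_size=None):
    # Prefer explicit arguments, then the checkpoint config, then the constant.
    default = config_sample_size or DEFAULT_SAMPLE_SIZE
    return height or default, width or default
```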
```python
)

# Convert back to image format
from ...models.transformers.transformer_mirage import seq2img
```
Any reason not to import this at the top?
I've been becoming quite keen on in-code imports for cases like this one, where seq2img is only used once in the whole file :D
I moved it to the top.
```python
"""


class MiragePipeline(
```
I remember Eliot started a thread about maybe changing the name of Mirage to something else, since there's already an ML model with that name. Did that thread land anywhere?
```python
pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint")
pipe.to("cuda")

prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
```
We should use a more viral example prompt. Once we choose a viral name for the model, we can match the prompt to the name ;) Maybe we have to activate the Photoroom marketing team for this one.
```python
from diffusers import MiragePipeline

# Load pipeline - VAE and text encoder will be loaded from HuggingFace
pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint")
```
I guess we'll be able to store the checkpoint on Hugging Face as well, right? If yes, we should not forget to update the paths here to the official one, to make this truly copy-paste and run.
photoroman
left a comment
Had a quick look. Great! I trust Claude renamed everything correctly.
```python
pipe.to("cuda")

prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
prompt = "A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “Photon” in bright, sparkling light"
```
Mirage pipeline
Fixes # (issue)
Before submitting
- See the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.