Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions gemma/multimodal/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

from __future__ import annotations
from collections.abc import Sequence
import io
import einops
from etils import epath
import jax
from jax import numpy as jnp
from kauldron import typing
import numpy as np
from PIL import Image
import tensorflow as tf

_IMAGE_MEAN = (127.5,) * 3
_IMAGE_STD = (127.5,) * 3
Expand Down Expand Up @@ -69,11 +69,17 @@ def pre_process_image(
Returns:
The pre-processed image.
"""
# all inputs are expected to have been jpeg compressed.
# TODO(eyvinec): we should remove tf dependency.
image = jnp.asarray(
tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3)
)
# Normalize image to uint8 range for JPEG encoding
image_uint8 = jnp.clip(image, 0, 255).astype(jnp.uint8)

# Encode and decode with JPEG via PIL for standardization
pil_image = Image.fromarray(np.array(image_uint8), mode='RGB')
jpeg_buffer = io.BytesIO()
pil_image.save(jpeg_buffer, format='JPEG')
jpeg_buffer.seek(0)
image = np.array(Image.open(jpeg_buffer).convert('RGB'))
image = jnp.asarray(image, dtype=jnp.float32)

image = jax.image.resize(
image,
shape=(image_height, image_width, 3),
Expand Down