diff --git a/gemma/multimodal/image.py b/gemma/multimodal/image.py index ce4182e0..3c338e16 100644 --- a/gemma/multimodal/image.py +++ b/gemma/multimodal/image.py @@ -16,6 +16,7 @@ from __future__ import annotations from collections.abc import Sequence +import io import einops from etils import epath import jax @@ -23,7 +24,6 @@ from kauldron import typing import numpy as np from PIL import Image -import tensorflow as tf _IMAGE_MEAN = (127.5,) * 3 _IMAGE_STD = (127.5,) * 3 @@ -69,11 +69,17 @@ def pre_process_image( Returns: The pre-processed image. """ - # all inputs are expected to have been jpeg compressed. - # TODO(eyvinec): we should remove tf dependency. - image = jnp.asarray( - tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3) - ) + # Normalize image to uint8 range for JPEG encoding + image_uint8 = jnp.clip(image, 0, 255).astype(jnp.uint8) + + # Encode and decode with JPEG via PIL for standardization + pil_image = Image.fromarray(np.array(image_uint8), mode='RGB') + jpeg_buffer = io.BytesIO() + pil_image.save(jpeg_buffer, format='JPEG') + jpeg_buffer.seek(0) + image = np.array(Image.open(jpeg_buffer).convert('RGB')) + image = jnp.asarray(image, dtype=jnp.float32) + image = jax.image.resize( image, shape=(image_height, image_width, 3),