Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/models/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ fbcode_target(_kind = python_library,
"//executorch/examples/models/toy_model:toy_model", # @manual
"//executorch/examples/models/wav2letter:w2l_model", # @manual
"//executorch/examples/models/llama3_2_vision:multimodal_lib", # @manual
"//executorch/examples/models/gemma4:gemma4", # @manual
"//executorch/examples/models/gemma3:gemma3", # @manual
"//executorch/examples/models/qwen2_5:qwen2_5", # @manual
"//executorch/examples/models/qwen3:qwen3", # @manual
Expand Down
24 changes: 24 additions & 0 deletions examples/models/gemma4/BUCK
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Buck build definition for the Gemma4 example model package.
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

# Python library bundling the Gemma4 export helpers (__init__.py,
# convert_weights.py) together with the JSON architecture configs.
fbcode_target(_kind = runtime.python_library,
name = "gemma4",
srcs = [
"__init__.py",
"convert_weights.py",
],
base_module = "executorch.examples.models.gemma4",
# Model-architecture JSON configs shipped as package resources.
resources = {
"config/e2b_config.json": "config/e2b_config.json",
"config/e4b_config.json": "config/e4b_config.json",
},
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama:llama2_model",
"//executorch/examples/models:checkpoint",
"fbsource//third-party/pypi/safetensors:safetensors",
],
visibility = ["PUBLIC"],
)
52 changes: 52 additions & 0 deletions examples/models/gemma4/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Summary

This example adds native ExecuTorch text-only export support for Google's Gemma 4 `E2B` and `E4B` models through the existing Llama-style export path.

The current scope is the decoder-only text model. It does not include the multimodal image or audio towers from the full Gemma 4 release.

# Supported models

- `google/gemma-4-E2B`
- `google/gemma-4-E4B`

# Exporting the model

The exporter can download and convert the Hugging Face checkpoint automatically, or you can point it at a pre-converted local checkpoint.

## Export Gemma 4 E2B

```bash
PYTHONPATH=.:.. python examples/models/llama/export_llama.py \
--model gemma4_e2b \
--params examples/models/gemma4/config/e2b_config.json \
--dtype-override bf16 \
--output-dir ./gemma4_e2b_out
```

## Export Gemma 4 E4B

```bash
PYTHONPATH=.:.. python examples/models/llama/export_llama.py \
--model gemma4_e4b \
--params examples/models/gemma4/config/e4b_config.json \
--dtype-override bf16 \
--output-dir ./gemma4_e4b_out
```

## Export with KV cache and custom SDPA

```bash
PYTHONPATH=.:.. python examples/models/llama/export_llama.py \
--model gemma4_e4b \
--params examples/models/gemma4/config/e4b_config.json \
--dtype-override bf16 \
--use_kv_cache \
--use_sdpa_with_kv_cache \
--disable_dynamic_shape \
--output-dir ./gemma4_e4b_kv_out
```

# Notes

- The Gemma 4 exporter uses the native ExecuTorch text runtime and the local `convert_weights.py` checkpoint conversion flow.
- In local source-tree workflows, `PYTHONPATH=.:..` makes both `examples.*` and `executorch.*` imports work consistently.
19 changes: 19 additions & 0 deletions examples/models/gemma4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.gemma4.convert_weights import convert_weights

# Public API of this package. Gemma4Model is not imported eagerly; it is
# materialized lazily on first attribute access (PEP 562 module __getattr__).
__all__ = ["Gemma4Model", "convert_weights"]


def __getattr__(name):
if name == "Gemma4Model":
from executorch.examples.models.llama.model import Llama2Model

class Gemma4Model(Llama2Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)

globals()["Gemma4Model"] = Gemma4Model
return Gemma4Model
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
40 changes: 40 additions & 0 deletions examples/models/gemma4/config/e2b_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"dim": 1536,
"hidden_dim": 6144,
"n_layers": 35,
"n_heads": 8,
"n_kv_heads": 1,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"vocab_size_per_layer_input": 262144,
"hidden_size_per_layer_input": 256,
"num_kv_shared_layers": 20,
"use_double_wide_mlp": true,
"act_fn": "gelu_pytorch_tanh",
"norm_eps": 1e-06,
"post_attention_norm": true,
"post_ffn_norm": true,
"apply_embedding": true,
"embedding_scale_factor": 39.191835884530846,
"use_hf_rope": true,
"attention_qkv_bias": false,
"attention_type": "gemma4_mha",
"attention_multiplier": 1.0,
"final_logit_softcapping": 30.0,
"use_qk_norm": true,
"qk_norm_before_rope": true,
"sliding_window": 512,
"layer_types": ["sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention"],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
40 changes: 40 additions & 0 deletions examples/models/gemma4/config/e4b_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"dim": 2560,
"hidden_dim": 10240,
"n_layers": 42,
"n_heads": 8,
"n_kv_heads": 2,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"vocab_size_per_layer_input": 262144,
"hidden_size_per_layer_input": 256,
"num_kv_shared_layers": 18,
"use_double_wide_mlp": false,
"act_fn": "gelu_pytorch_tanh",
"norm_eps": 1e-06,
"post_attention_norm": true,
"post_ffn_norm": true,
"apply_embedding": true,
"embedding_scale_factor": 50.59644256269407,
"use_hf_rope": true,
"attention_qkv_bias": false,
"attention_type": "gemma4_mha",
"attention_multiplier": 1.0,
"final_logit_softcapping": 30.0,
"use_qk_norm": true,
"qk_norm_before_rope": true,
"sliding_window": 512,
"layer_types": ["sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention"],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
157 changes: 157 additions & 0 deletions examples/models/gemma4/convert_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import argparse
import json
import os
from typing import Dict

import torch
from executorch.examples.models.checkpoint import (
get_mapped_key,
load_checkpoint_from_pytorch_model,
)


# Mapping from Hugging Face Gemma4 checkpoint key patterns to the
# llama-style key layout used by the ExecuTorch export path. The "{}"
# placeholder stands for the layer index and is resolved by
# get_mapped_key() during conversion (see gemma4_to_meta below).
_GEMMA4_TO_EXECUTORCH = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.embed_tokens_per_layer.weight": "embed_tokens_per_layer.weight",
"model.per_layer_model_projection.weight": "per_layer_model_projection.weight",
"model.per_layer_projection_norm.weight": "per_layer_projection_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm_fn.weight",
"model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm_fn.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attention_norm.weight",
"model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.post_ffn_norm.weight",
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.layer_scalar": "layers.{}.layer_scalar",
"model.layers.{}.per_layer_input_gate.weight": "layers.{}.per_layer_input_gate.weight",
"model.layers.{}.per_layer_projection.weight": "layers.{}.per_layer_projection.weight",
"model.layers.{}.post_per_layer_input_norm.weight": "layers.{}.post_per_layer_input_norm.weight",
}


# Checkpoint key suffixes that are deliberately dropped when they have no
# entry in _GEMMA4_TO_EXECUTORCH: keys ending in these are skipped instead
# of raising an error during conversion (see gemma4_to_meta).
_IGNORED_UNMAPPED_SUFFIXES = (
"rotary_emb.inv_freq",
"self_attn.v_norm.weight",
)


def _load_checkpoint_from_safetensors(input_dir: str) -> Dict:
    """Load a (possibly sharded) safetensors checkpoint from ``input_dir``.

    Prefers the sharded layout described by ``model.safetensors.index.json``;
    falls back to a single ``model.safetensors`` file.

    Raises:
        FileNotFoundError: if neither layout exists in ``input_dir``.
    """
    from safetensors.torch import load_file

    index_path = os.path.join(input_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path, "r") as f:
            weight_map = json.load(f)["weight_map"]

        # Group weight names by the shard file holding them so that each
        # shard is read from disk exactly once.
        names_by_shard = {}
        for weight_name, shard_file in weight_map.items():
            names_by_shard.setdefault(shard_file, []).append(weight_name)

        merged = {}
        for shard_file in sorted(names_by_shard):
            shard_tensors = load_file(os.path.join(input_dir, shard_file))
            for weight_name in names_by_shard[shard_file]:
                merged[weight_name] = shard_tensors[weight_name]
        return merged

    single_file_path = os.path.join(input_dir, "model.safetensors")
    if os.path.exists(single_file_path):
        return load_file(single_file_path)

    raise FileNotFoundError(f"Could not find safetensors checkpoint in {input_dir}")


def load_checkpoint(input_dir: str) -> Dict:
    """Load a checkpoint from ``input_dir``.

    Attempts the ``pytorch_model`` bin layout first, then falls back to
    safetensors; each attempt logs its progress to stdout.

    Raises:
        FileNotFoundError: when no supported checkpoint layout is found.
    """
    print("Loading checkpoint from pytorch_model directory")
    try:
        return load_checkpoint_from_pytorch_model(input_dir)
    except FileNotFoundError:
        print(
            "Could not find pytorch_model checkpoints in directory, trying safetensors"
        )

    print("Loading checkpoint from safetensors directory")
    try:
        return _load_checkpoint_from_safetensors(input_dir)
    except FileNotFoundError:
        pass

    raise FileNotFoundError(f"Could not find checkpoint in {input_dir}")


def gemma4_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rename Hugging Face Gemma4 checkpoint keys to the Meta/ExecuTorch layout.

    Keys under the multimodal wrapper prefix ``model.language_model.`` are
    first normalized to ``model.``; keys outside the text decoder are dropped.
    If the checkpoint has no separate ``lm_head.weight``, the output
    projection is tied to the token embedding weights.

    Raises:
        ValueError: for an unexpected key that cannot be mapped and is not
            in the ignore list.
    """
    wrapper_prefix = "model.language_model."
    text_prefixes = (
        "model.layers.",
        "model.embed_tokens.",
        "model.embed_tokens_per_layer.",
        "model.per_layer_model_projection.",
        "model.per_layer_projection_norm.",
        "model.norm.",
        "lm_head.",
    )

    converted: Dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        # Strip the multimodal wrapper so text weights map uniformly.
        normalized_key = key
        if normalized_key.startswith(wrapper_prefix):
            normalized_key = "model." + normalized_key[len(wrapper_prefix):]

        # Anything outside the text decoder (e.g. vision/audio towers) is
        # out of scope for this export and silently skipped.
        if not normalized_key.startswith(text_prefixes):
            continue

        try:
            new_key = get_mapped_key(normalized_key, _GEMMA4_TO_EXECUTORCH)
        except Exception as err:
            if normalized_key.endswith(_IGNORED_UNMAPPED_SUFFIXES):
                continue
            raise ValueError(
                f"Unexpected checkpoint key not mapped for Gemma4 export: {key}"
            ) from err
        converted[new_key] = value

    # Weight tying: reuse the embedding matrix when no explicit lm_head
    # was present in the checkpoint.
    if "output.weight" not in converted:
        converted["output.weight"] = converted["tok_embeddings.weight"]

    return converted


def convert_weights(input_dir: str, output_file: str) -> None:
    """End-to-end conversion: load an HF checkpoint from ``input_dir``,
    remap its keys to the Meta/ExecuTorch layout, and save to ``output_file``.
    """
    print("Loading checkpoint...")
    loaded = load_checkpoint(input_dir)
    print("Converting checkpoint...")
    converted = gemma4_to_meta(loaded)
    print("Saving checkpoint...")
    torch.save(converted, output_file)
    print("Done.")


def main():
    """CLI entry point: parse arguments and run the weight conversion."""
    arg_parser = argparse.ArgumentParser(
        description="Convert Gemma4 weights to ExecuTorch meta format."
    )
    arg_parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing safetensor or PyTorch checkpoint files.",
    )
    arg_parser.add_argument("output", type=str, help="Path to the output checkpoint")
    opts = arg_parser.parse_args()
    convert_weights(opts.input_dir, opts.output)


if __name__ == "__main__":
main()
Loading
Loading