94 changes: 91 additions & 3 deletions tests/end_to_end/tpu/deepseek/Run_DeepSeek.md
Contributor


Could we also add a section on decoding for v3.2?

Collaborator Author


Is there anything different about running decoding from user perspective in V3.2?

@@ -20,7 +20,9 @@ DeepSeek is a novel family of open-weights sparse MoE models by DeepSeek AI. The

* DeepSeek-V3 features advanced techniques, including Multi-Head Latent Attention (MLA), finer-grained and shared experts, Multi-Token Prediction (MTP), and FP8 mixed precision designed for enhanced efficiency and performance.

* DeepSeek V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency.
* DeepSeek-V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency.

* DeepSeek-V3.2 introduces [DeepSeek Sparse Attention](https://arxiv.org/pdf/2512.02556) (DSA), which reduces computational complexity while preserving model performance in long-context scenarios.

* DeepSeek R1 also uses V3 architecture. It utilizes cold-start data and large-scale reinforcement learning to incentivize chain-of-thought reasoning without relying solely on supervised fine-tuning.

@@ -54,13 +56,11 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
dataset_type=synthetic
```


## Checkpoint conversion
To get started, follow the instructions at HuggingFace ([V3](https://huggingface.co/deepseek-ai/DeepSeek-V3), [V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)) to download the model. Currently, V3, V3.1, and R1 ship mixed-precision FP8 & BF16 weights. To convert all FP8 weights to BF16, use the script [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/ckpt_scripts/deepseek_fp8_to_bf16.py). Once downloaded and converted to BF16:
* run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly.
* run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned version in Orbax for decoding.


## Fine-tuning

After you have a MaxText compatible checkpoint, you could fine-tune it with different datasets.
@@ -216,3 +216,91 @@ To run MMLU benchmarks and validate the model's performance, follow the instruct
* Dropping implementation with the flag `sparse_matmul=False` and a reasonable `capacity_factor`, commonly between 1 and 1.25.

See more examples in scripts for [V3](v3-671b/test_deepseek.sh) and [V2-Lite](v2-16b/test_deepseek.sh).
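The dropping implementation caps each expert's load: tokens routed beyond the per-expert capacity are dropped. A minimal sketch of how `capacity_factor` translates into that budget (the batch and expert sizes below are illustrative, not MaxText defaults):

```python
import math

def expert_capacity(tokens_per_batch, num_experts, capacity_factor):
    # Each expert processes at most this many tokens per batch;
    # tokens routed beyond this cap are dropped.
    return math.ceil(tokens_per_batch / num_experts * capacity_factor)

# With 4096 tokens over 256 experts, a perfectly balanced router
# sends 16 tokens to each expert; capacity_factor adds headroom.
print(expert_capacity(4096, 256, 1.0))   # 16
print(expert_capacity(4096, 256, 1.25))  # 20
```

A larger `capacity_factor` tolerates more routing imbalance at the cost of more padded compute.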

## DeepSeek-V3.2

### Continued pre-training for V3.2 Sparse Attention
**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-$k$ tokens for attention. DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**.
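A simplified sketch of what top-$k$ selection does to the attention pattern (illustrative only; the real DSA indexer is a learned, separately parameterized scorer, and this is not the MaxText implementation):

```python
import numpy as np

def dsa_topk_mask(index_scores, k):
    """Turn indexer scores [seq, seq] (query x key) into a sparse causal mask
    in which each query attends to at most its k highest-scoring keys."""
    seq = index_scores.shape[0]
    causal = np.tril(np.ones((seq, seq), dtype=bool))
    scores = np.where(causal, index_scores, -np.inf)  # hide future keys
    topk = np.argsort(scores, axis=-1)[:, -k:]        # k best keys per query
    mask = np.zeros_like(causal)
    np.put_along_axis(mask, topk, True, axis=-1)
    return mask & causal  # early rows may have fewer than k valid keys

mask = dsa_topk_mask(np.random.randn(8, 8), k=3)
print(mask.sum(axis=-1))  # row i attends to min(i + 1, k) keys
```

Attention is then computed only over the masked-in keys, which is what reduces the quadratic cost in long-context scenarios.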

1. **Dense Warm-up Stage**
The indexer is trained exclusively with the dense indexer loss while all other model parameters remain frozen.
```sh
# indexer_loss_scaling_factor must be non-zero to activate indexer training; the default in base.yml is 0.
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
model_name=deepseek3.2-671b \
run_name=matmul_pre_training \
per_device_batch_size=4 \
enable_checkpointing=false \
ici_fsdp_parallelism=128 \
steps=5 \
async_checkpointing=false \
tokenizer_type=huggingface \
tokenizer_path=deepseek-ai/DeepSeek-V3.2 \
attention=flash \
dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=True \
sparse_matmul=True \
dataset_type=synthetic \
indexer_sparse_training=False \
indexer_loss_scaling_factor=0.01 \
trainable_parameters_mask=['.*indexer.*']
```
2. **Sparse Training Stage**
The indexer is trained with the sparse indexer loss, while the remaining model parameters are unfrozen and updated using the standard language-modeling loss.
```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
model_name=deepseek3.2-671b \
per_device_batch_size=4 \
enable_checkpointing=false \
ici_fsdp_parallelism=128 \
steps=5 \
max_target_length=1024 \
async_checkpointing=false \
tokenizer_type=huggingface \
tokenizer_path=deepseek-ai/DeepSeek-V3.2 \
attention=flash \
dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=True \
sparse_matmul=True \
dataset_type=synthetic \
indexer_sparse_training=True \
indexer_loss_scaling_factor=0.01
# indexer_loss_scaling_factor must be non-zero to activate indexer training; the default in base.yml is 0.
```
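The `trainable_parameters_mask=['.*indexer.*']` flag used in the warm-up stage freezes every parameter the regex does not match. Conceptually (a hypothetical sketch with made-up parameter names, not MaxText's actual implementation), it amounts to mapping each parameter path to a trainable/frozen boolean:

```python
import re

def trainable_mask(params, patterns):
    """Mark each leaf of a nested param dict trainable (True) iff any regex
    matches its path; an optimizer wrapper can then freeze the False leaves."""
    def walk(tree, path):
        if isinstance(tree, dict):
            return {k: walk(v, f"{path}/{k}") for k, v in tree.items()}
        return any(re.search(p, path) for p in patterns)
    return walk(params, "params")

# Made-up parameter tree for illustration.
params = {"decoder": {"indexer": {"w": 0}, "mla": {"w": 0}}}
print(trainable_mask(params, [r".*indexer.*"]))
# {'decoder': {'indexer': {'w': True}, 'mla': {'w': False}}}
```

In the sparse stage no mask is passed, so all parameters receive gradient updates.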

### Checkpoint conversion for V3.2
> **Note:** These steps are required because Hugging Face Transformers support for V3.2 is not yet available.

#### 1. Download Model Weights
Download the Hugging Face weights from [deepseek-ai/DeepSeek-V3.2](https://huggingface.co/deepseek-ai/DeepSeek-V3.2) to your local environment. Weights are provided in FP8.
```sh
hf download deepseek-ai/DeepSeek-V3.2 --local-dir <local_fp8_path>
```

#### 2. Dequantize Weights
Convert the weights from FP8 to BF16 using the script [deepseek_fp8_to_bf16.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py) on CPU:

```sh
python3 -m maxtext.checkpoint_conversion.standalone_scripts.deepseek_fp8_to_bf16 --input-fp8-hf-path=<local_fp8_path> --output-bf16-hf-path=<local_bf16_path>
```

Alternatively, you can use the official DeepSeek script [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert on GPU.

#### 3. Convert to MaxText-compatible Orbax format
Execute the following command to finalize the conversion. Ensure your environment variables (`$BASE_OUTPUT_PATH`, `$HF_TOKEN`, and `$DEQUANTIZED_LOCAL_WEIGHTS`) are exported before running.
Setting `scan_layers=true` generates the scanned Orbax format for training and fine-tuning; setting `scan_layers=false` generates the unscanned format for decoding.
```bash
python3 -m maxtext.checkpoint_conversion.to_maxtext \
src/maxtext/configs/base.yml \
model_name=deepseek3.2-671b \
scan_layers=true \
attention=dot_product \
base_output_directory=$BASE_OUTPUT_PATH \
hf_access_token=$HF_TOKEN \
hardware=cpu \
skip_jax_distributed_system=True \
--hf_model_path=$DEQUANTIZED_LOCAL_WEIGHTS \
--eager_load_method=safetensors \
--save_dtype=bfloat16
```
56 changes: 56 additions & 0 deletions tests/end_to_end/tpu/deepseek/v3.2-671b/2_test_deepseek.sh
@@ -0,0 +1,56 @@
#!/bin/bash

# This file is documentation for how to get started with DeepSeek v3.2.

# This file runs Step 2 on v5p-128 on a daily basis.
# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16):
# Scanned format is better for training; unscanned format is better for decoding.
# 2. Run logit check, pre-training, fine-tuning, and decoding.

set -ex

export MODEL_NAME='deepseek3.2-671b'
export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3.2'

# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# e.g., $HOME/maxtext/src/maxtext
export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
# this bucket will store all the files generated by MaxText during a run
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
echo "BASE_OUTPUT_PATH was not set; using default ${BASE_OUTPUT_PATH}"
fi
BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}

# Step 2:
# We define the checkpoint paths. This way it is easier to use these paths in the `train.py` and `decode.py` commands
# export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items
# export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
# Use a hard-coded golden checkpoint rather than checkpoints generated by Step 1, since Step 1 is not part of the daily test.
SCANNED_CKPT_PATH=gs://maxtext-deepseek/deepseek3.2/2026-02-20/scanned/0/items
UNSCANNED_CKPT_PATH=gs://maxtext-deepseek/deepseek3.2/2026-02-20/unscanned/0/items
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
export DATASET_PATH=gs://maxtext-dataset

# Test whether the forward pass logits match the golden logits
# default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
GOLDEN_LOGITS_DISK_LOCATION="/deps/tests/assets/golden_logits/golden_data_${MODEL_NAME}.jsonl"
if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then
GOLDEN_LOGITS_PATH="gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl"
GOLDEN_LOGITS_DISK_LOCATION=/tmp/golden_data.jsonl
gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION}
fi

# Override deepseek3.2-671b.yml with indexer_topk=2:
# OVERRIDE="indexer_topk=2"
# OVERRIDE=""
python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1.5 --rtol=1.5 --max_kl_div=0.3

# Run decoding - tokamax_gmm implementation
# Note decode requires the access token for huggingface tokenizer even if the model is not gated
python3 -m maxtext.inference.decode ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is "
7 changes: 7 additions & 0 deletions tests/utils/forward_pass_logit_checker.py
Original file line number Diff line number Diff line change
@@ -283,6 +283,12 @@ def main(config, test_args): # pylint: disable=W0621
max_logging.log(f"\n--- Comparing forward pass for golden data index: {golden_data_index} ---")
ids, decoder_segment_ids, decoder_positions, golden_logits, seq_len, images = get_data(golden_data_point, config)
max_logging.log("maxtext forward pass")
max_logging.log(f"XXX model: {model}")
max_logging.log(f"XXX state.params: {state.params}")
max_logging.log(f"XXX ids: {ids}")
max_logging.log(f"XXX decoder_positions: {decoder_positions}")
max_logging.log(f"XXX decoder_segment_ids: {decoder_segment_ids}")
max_logging.log(f"XXX images: {images}")
full_train_logits = model.apply(
state.params,
ids,
@@ -292,6 +298,7 @@
enable_dropout=False,
rngs={"aqt": init_rng},
)
max_logging.log(f"XXX FULL TRAIN LOGIT: {full_train_logits}")

full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits, tiled=True)
# if full_train_logits shape is [num_hosts, batch_size, seq_len, vocab_size]