DeepSeek V3.2 user guide update#3565

Open
snehalv2002 wants to merge 1 commit into main from ds3.2-xlml-tests

Conversation

@snehalv2002
Collaborator

Updating the user guide for DeepSeek-V3.2. Explains new feature updates and updates instructions on multi-stage lightning indexer training and checkpoint conversion.

@github-actions

github-actions bot commented Apr 3, 2026

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions bot left a comment

## 📋 Review Summary

This pull request updates the DeepSeek user guide to include instructions for the new DeepSeek-V3.2 model, specifically focusing on indexer training and checkpoint conversion. The updates are timely and provide clear steps for users to leverage the latest sparse attention features.

🔍 General Feedback

  • Consistency: Ensure that the model names (deepseek3.2-671b) and tokenizer paths (deepseek-ai/DeepSeek-V3.2) are consistent across all stages of the guide.
  • Syntax: Be careful with trailing backslashes in shell command examples, as they can cause errors if users copy-paste the last line.
  • Clarity: Using concrete example values (like 0.1 for scaling factors) is generally more user-friendly than placeholders in curly braces.

* DeepSeek V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency.
* DeepSeek-V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency.

* DeepSeek-V3.2 replaces vanilla attention (O[L^2] where L is number of tokens) with DeepSeek Sparse Attention (O[L * k] where k is some number of sparsely selected tokens).

🟢 Use standard Big O notation with parentheses instead of square brackets.
Suggested change
* DeepSeek-V3.2 replaces vanilla attention (O[L^2] where L is number of tokens) with DeepSeek Sparse Attention (O[L * k] where k is some number of sparsely selected tokens).
* DeepSeek-V3.2 replaces vanilla attention (O(L^2) where L is number of tokens) with DeepSeek Sparse Attention (O(L * k) where k is some number of sparsely selected tokens).

```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
run_name=matmul_pre_training \

🟠 Use the `deepseek3.2-671b` model config, as it contains the necessary indexer configuration for this training.
Suggested change
run_name=matmul_pre_training \
model_name=deepseek3.2-671b \

Collaborator

Yes

ici_fsdp_parallelism=128 \
steps=5 \
max_target_length=1024 \
async_checkpointing=false \

🟡 Use the correct tokenizer path for DeepSeek-V3.2.
Suggested change
async_checkpointing=false \
tokenizer_path=deepseek-ai/DeepSeek-V3.2 \

Collaborator

Yes

tokenizer_type=huggingface \
tokenizer_path=deepseek-ai/DeepSeek-V3 \
attention=flash \
dtype=bfloat16 \

🟡 Provide a concrete example value (like 0.1) instead of a placeholder in curly braces, as placeholders can be confusing in documentation examples.
Suggested change
dtype=bfloat16 \
indexer_loss_scaling_factor=0.1 \

python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
run_name=matmul_pre_training \
per_device_batch_size=4 \

🟠 Use the `deepseek3.2-671b` model config for the sparse training stage as well.
Suggested change
per_device_batch_size=4 \
model_name=deepseek3.2-671b \

Collaborator

Yes

steps=5 \
max_target_length=1024 \
async_checkpointing=false \
tokenizer_type=huggingface \

🟡 Same as above, use the V3.2 tokenizer.
Suggested change
tokenizer_type=huggingface \
tokenizer_path=deepseek-ai/DeepSeek-V3.2 \

Collaborator

Yes

dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=False \
sparse_matmul=False \

🔴 The command should not end with a trailing backslash if it is the last line. Additionally, for indexer-only training, the `trainable_parameters_mask` should be present in both stages to isolate the indexer.
Suggested change
sparse_matmul=False \
indexer_loss_scaling_factor=0.1 \
trainable_parameters_mask=['.*indexer.*']
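To illustrate how a regex mask like `'.*indexer.*'` would isolate the indexer parameters, here is a minimal sketch; the parameter names below are hypothetical stand-ins, not MaxText's actual parameter names:

```python
import re

# Hypothetical parameter names for illustration only.
param_names = [
    "decoder.layers.0.attention.indexer.wq",
    "decoder.layers.0.attention.query_proj",
    "decoder.layers.0.mlp.wi",
]

# Only parameters matching the mask stay trainable; everything else is frozen.
mask = re.compile(r".*indexer.*")
trainable = [name for name in param_names if mask.fullmatch(name)]
print(trainable)  # only the indexer parameter matches
```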

* DeepSeek V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency.
* DeepSeek-V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency.

* DeepSeek-V3.2 replaces vanilla attention (O[L^2] where L is number of tokens) with DeepSeek Sparse Attention (O[L * k] where k is some number of sparsely selected tokens).
Collaborator

Instead of "replace vanilla attention", it would be better to say "improves MLA attention". The complexity is still O(L^2), but the indexer is added on top of the MLA attention that DeepSeek uses from V3 onwards.

Collaborator

+1 Let's mention something like the text below, and please feel free to modify:

DeepSeek-V3.2 introduces DeepSeek Sparse Attention (DSA), which reduces computational complexity while preserving model performance in long-context scenarios.

Let's remove the complexity notation to avoid any confusion, as the indexer also has an O(L^2) selection cost. We could direct readers to the paper.
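To make the reviewer's point concrete: even if attention itself only touches the k selected tokens, the lightning indexer still scores every token pair, so an L^2 term survives. A toy, unitless cost model (the numbers are invented for this sketch):

```python
# Illustrative, unitless op counts; not a real FLOP model.
L = 4096   # sequence length
k = 512    # number of sparsely selected tokens

dense_attention = L * L       # vanilla/MLA attention scores every token pair
sparse_attention = L * k      # DSA attends only over the top-k selected tokens
indexer_selection = L * L     # but the indexer still scores all pairs to select them

print(dense_attention, sparse_attention, indexer_selection)
```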

Collaborator

with hyperlink to paper: https://arxiv.org/pdf/2512.02556

dataset_type=synthetic
```

## Indexer training
Collaborator

Highlight that this is only for V3.2 Sparse Attention in the heading itself

sparse_matmul=False \
dataset_type=synthetic \
indexer_sparse_training=False \
indexer_loss_scaling_factor={some non-zero value} \
Collaborator

Replace with default value in base.yml. And add a comment saying can replace with non-zero value.

Collaborator

Or we could put a small value, like 0.01

sparse_matmul=False \
dataset_type=synthetic \
indexer_sparse_training=True \
indexer_loss_scaling_factor={some non-zero value} \
Collaborator

Same as comment above

Comment on lines +103 to +104
megablox=False \
sparse_matmul=False \
Collaborator

We should probably have this set to True in the sparse training stage. These flags control which MoE strategy to use.

max_target_length=1024 \
async_checkpointing=false \
tokenizer_type=huggingface \
tokenizer_path=deepseek-ai/DeepSeek-V3 \
Collaborator

Is there a difference in V3 vs V3.2 tokenizer path in HF? If not then this is fine.

Collaborator

No difference, but let's update to v3.2 to avoid confusion

Collaborator

should use tokenizer_path=deepseek-ai/DeepSeek-V3.2

## Indexer training
DeepSeek-V3.2 introduces DeepSeek Sparse Attention (DSA). Training the lightning indexer to achieve sparsity is a two-stage process.

1. **Dense Warmup Stage**
Collaborator

Can you include a comment that in dense warmup stage, all model weights are frozen except the indexer weights.

Collaborator

@RissyRan RissyRan left a comment

Thanks for your 1st PR!!!

One more thing, could you update the PR description to follow our default template? One example: here


max_target_length=1024 \
async_checkpointing=false \
tokenizer_type=huggingface \
tokenizer_path=deepseek-ai/DeepSeek-V3 \
Collaborator

let's use v3.2 tokenizer path

attention=flash \
dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=False \
Collaborator

Let's use sparse_matmul=True and megablox=True


* **Target Directory:** `LOCAL_WEIGHTS`

### 2. Dequantize Weights
Convert the weights from FP8 to BF16 using the official DeepSeek script.
Collaborator

@shuningjin could you help check this part?

Contributor

Could we also add a section on decoding for v3.2?

Collaborator

@shuningjin shuningjin left a comment

Thanks for the update! Might be good to organize deepseek3.2 into a self-contained section for clarity, and add more explanation on continued pre-training for indexer.

@@ -1 +1 @@
<!--
Collaborator

I would prefer that we organize all commands of deepseek3.2 under one section for clarity. Like

## DeepSeek V3.2
### Checkpoint conversion
### Indexer training
### Decode




### 1. Download Model Weights
Download the Hugging Face weights from [deepseek-ai/DeepSeek-V3.2](https://huggingface.co/deepseek-ai/DeepSeek-V3.2) to your local environment.
* **Target Directory:** `LOCAL_WEIGHTS`
Collaborator

Would be better to be specific

The model weights are quantized in FP8.

```sh
hf download deepseek-ai/DeepSeek-V3.2 --local-dir <local_fp8_path>
```

### 2. Dequantize Weights
Convert the weights from FP8 to BF16 using the official DeepSeek script.
* **Script:** [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py)
* **Output Directory:** `DEQUANTIZED_LOCAL_WEIGHTS`
Collaborator

@shuningjin shuningjin Apr 7, 2026

Convert the weights from FP8 to BF16 using script [deepseek_fp8_to_bf16.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py) on CPU:

```sh
python3 -m maxtext.checkpoint_conversion.standalone_scripts.deepseek_fp8_to_bf16 --input-fp8-hf-path=<local_fp8_path> --output-bf16-hf-path=<local_bf16_path>
```

Alternatively, we can use the official DeepSeek script [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert on GPU.

* **Script:** [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py)
* **Output Directory:** `DEQUANTIZED_LOCAL_WEIGHTS`

### 3. Convert to MaxText
Collaborator

Convert to MaxText-compatible Orbax format

--hf_model_path=$DEQUANTIZED_LOCAL_WEIGHTS \
--eager_load_method=safetensors \
--save_dtype=bfloat16
```
Collaborator

Might be good to add

Setting `scan_layers=true` generates scanned Orbax format for training and fine-tuning. Setting `scan_layers=false` generates unscanned Orbax format for decoding.

dataset_type=synthetic
```

## Indexer training
Collaborator

Might be good to elaborate the training, ref. Here are some suggestions:

(1) Indexer training -> Continued pre-training

(2)

**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-$k$ tokens for attention. DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**.

(3)

1. Dense Warm-up Stage: 

The indexer is trained exclusively using dense indexer loss while all other model parameters remain frozen.

(4)

2. Sparse Training Stage: 

The indexer is trained with sparse indexer loss, while the remaining model parameters are unfrozen and updated using standard language modeling loss.
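The flag-level difference between the two stages could be summarized as follows (a sketch using only flags already shown in this guide; the exact loss-scaling value is a suggestion under discussion in this thread):

```sh
# Stage 1: Dense Warm-up -- indexer trained with dense indexer loss,
# all other parameters frozen via the mask.
indexer_sparse_training=False \
indexer_loss_scaling_factor=0.01 \
trainable_parameters_mask=['.*indexer.*']

# Stage 2: Sparse Training -- indexer trained with sparse indexer loss,
# remaining parameters unfrozen and updated with the language modeling loss.
indexer_sparse_training=True \
indexer_loss_scaling_factor=0.01
```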
