83 changes: 82 additions & 1 deletion tests/end_to_end/tpu/deepseek/Run_DeepSeek.md
**Contributor:**
Could we also add a section on decoding for v3.2?

@@ -20,7 +20,9 @@ DeepSeek is a novel family of open-weights sparse MoE models by DeepSeek AI. The

* DeepSeek-V3 features advanced techniques, including Multi-Head Latent Attention (MLA), finer-grained and shared experts, Multi-Token Prediction (MTP), and FP8 mixed precision designed for enhanced efficiency and performance.

* DeepSeek-V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency.

* DeepSeek-V3.2 replaces vanilla attention (O[L^2] where L is number of tokens) with DeepSeek Sparse Attention (O[L * k] where k is some number of sparsely selected tokens).
**Collaborator:**
Instead of "replace vanilla attention", it would be better to say "improves MLA attention". The complexity is still O(L^2) but the indexer is added on top of MLA attention that Deepseek uses from V3 onwards.

**Collaborator:**
+1. Let's mention something similar to the text below, and please feel free to modify:

DeepSeek-V3.2 introduces DeepSeek Sparse Attention (DSA), which successfully reduces computational complexity while preserving model performance in long-context scenarios.

Let's remove the complexity to avoid any confusion, as Indexer also has L^2 for selection. We could direct readers to paper.

**Collaborator:**
with hyperlink to paper: https://arxiv.org/pdf/2512.02556

🟢 Use standard Big O notation with parentheses instead of square brackets.

Suggested change:
```diff
- * DeepSeek-V3.2 replaces vanilla attention (O[L^2] where L is number of tokens) with DeepSeek Sparse Attention (O[L * k] where k is some number of sparsely selected tokens).
+ * DeepSeek-V3.2 replaces vanilla attention (O(L^2) where L is number of tokens) with DeepSeek Sparse Attention (O(L * k) where k is some number of sparsely selected tokens).
```


* DeepSeek R1 also uses V3 architecture. It utilizes cold-start data and large-scale reinforcement learning to incentivize chain-of-thought reasoning without relying solely on supervised fine-tuning.
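For intuition on the complexity being debated above, here is a quick back-of-the-envelope comparison of attention score counts (illustrative numbers only, not from the source; note the reviewers' caveat that the indexer itself still does O(L^2) selection work, so see the V3.2 paper for the real accounting):

```python
# Dense attention touches roughly L*L query-key pairs, while sparse
# attention over k selected tokens touches roughly L*k pairs (causal
# masking is omitted for simplicity; it changes both counts by ~2x).
L = 128 * 1024   # context length (example value)
k = 2048         # sparsely selected tokens per query (example value)

dense_pairs = L * L
sparse_pairs = L * k
print(dense_pairs // sparse_pairs)  # -> 64: the ratio is simply L / k
```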

@@ -54,12 +56,91 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
dataset_type=synthetic
```

## Indexer training
**Collaborator:**
Highlight that this is only for V3.2 Sparse Attention in the heading itself

**Collaborator:**
Might be good to elaborate on the training, ref. Here are some suggestions:

(1) Indexer training -> Continued pre-training

(2)

**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-$k$ tokens for attention. DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**.

(3)

1. Dense Warm-up Stage: 

The indexer is trained exclusively using dense indexer loss while all other model parameters remain frozen.

(4)

2. Sparse Training Stage: 

The indexer is trained with sparse indexer loss, while the remaining model parameters are unfrozen and updated using standard language modeling loss.
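As a rough illustration of the top-$k$ selection mechanics the Lightning Indexer performs, here is a minimal sketch with hypothetical shapes and random scores (not the actual MaxText or DeepSeek implementation, which uses learned per-head index scores):

```python
import numpy as np

def topk_select(index_scores: np.ndarray, k: int) -> np.ndarray:
    """Return the indices of the k highest-scoring key tokens per query.

    index_scores: [L, L] matrix of indexer logits (query x key).
    """
    # argpartition finds the top-k (unordered) per row without a full sort.
    return np.argpartition(index_scores, -k, axis=-1)[:, -k:]

rng = np.random.default_rng(0)
L, k = 8, 3
scores = rng.normal(size=(L, L))
selected = topk_select(scores, k)
print(selected.shape)  # -> (8, 3): each query attends to only k keys
```

The full attention computation would then gather only the selected keys and values for each query, which is where the sparsity savings come from.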

DeepSeek-V3.2 introduces DeepSeek Sparse Attention. Training the Lightning Indexer to achieve sparsity is a two-stage process.

1. **Dense Warmup Stage**
**Collaborator:**
Can you include a comment that in dense warmup stage, all model weights are frozen except the indexer weights.

```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
run_name=matmul_pre_training \
# 🟠 Reviewer suggestion: use the `deepseek3.2-671b` model config, as it
# contains the necessary indexer configuration for this training:
#   model_name=deepseek3.2-671b \
# Collaborator: Yes

per_device_batch_size=4 \
enable_checkpointing=false \
model_name=deepseek3-671b \
ici_fsdp_parallelism=128 \
steps=5 \
max_target_length=1024 \
async_checkpointing=false \
# 🟡 Reviewer suggestion: use the correct tokenizer path for DeepSeek-V3.2:
#   tokenizer_path=deepseek-ai/DeepSeek-V3.2 \
# Collaborator: Yes

tokenizer_type=huggingface \
tokenizer_path=deepseek-ai/DeepSeek-V3 \
# Collaborator: let's use v3.2 tokenizer path
attention=flash \
dtype=bfloat16 \
# 🟡 Reviewer suggestion: provide a concrete example value (like 0.1) instead
# of a placeholder in curly braces, as placeholders can be confusing in
# documentation examples:
#   indexer_loss_scaling_factor=0.1 \
weight_dtype=bfloat16 \
megablox=False \
# Collaborator: Let's use sparse_matmul=True and megablox=True
sparse_matmul=False \
dataset_type=synthetic \
indexer_sparse_training=False \
indexer_loss_scaling_factor={some non-zero value} \
# Collaborator: Replace with the default value in base.yml, and add a comment
# saying it can be replaced with a non-zero value.
# Collaborator: Or we could put a small value, like 0.01
trainable_parameters_mask=['.*indexer.*']
```
2. **Sparse Training Stage**
```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
run_name=matmul_pre_training \
per_device_batch_size=4 \
# 🟠 Reviewer suggestion: use the `deepseek3.2-671b` model config for the
# sparse training stage as well:
#   model_name=deepseek3.2-671b \
# Collaborator: Yes

enable_checkpointing=false \
model_name=deepseek3-671b \
ici_fsdp_parallelism=128 \
steps=5 \
max_target_length=1024 \
async_checkpointing=false \
tokenizer_type=huggingface \
# 🟡 Reviewer suggestion: same as above, use the V3.2 tokenizer:
#   tokenizer_path=deepseek-ai/DeepSeek-V3.2 \
# Collaborator: Yes

tokenizer_path=deepseek-ai/DeepSeek-V3 \
# Collaborator: Is there a difference in V3 vs V3.2 tokenizer path in HF?
# If not then this is fine.
# Collaborator: No difference, but let's update to v3.2 to avoid confusion
# Collaborator: should use tokenizer_path=deepseek-ai/DeepSeek-V3.2

attention=flash \
dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=False \
sparse_matmul=False \
# 🔴 Reviewer suggestion: the command should not end with a trailing backslash
# if it is the last line. Additionally, for indexer-only training, the
# `trainable_parameters_mask` should be present in both stages to isolate
# the indexer:
#   indexer_loss_scaling_factor=0.1 \
#   trainable_parameters_mask=['.*indexer.*']
# Collaborator (on lines +103 to +104): We should probably have this set to
# True in the sparse training stage. These flags control which MoE strategy
# to use.

dataset_type=synthetic \
indexer_sparse_training=True \
indexer_loss_scaling_factor={some non-zero value} \
# Collaborator: Same as comment above

```
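The `trainable_parameters_mask=['.*indexer.*']` flag used above selects which parameters receive gradient updates by matching a regex against parameter names. A minimal sketch of that mechanism, with hypothetical parameter names (MaxText's actual masking implementation differs):

```python
import re

def build_trainable_mask(param_names, patterns):
    """Mark a parameter trainable iff its name matches any regex pattern."""
    compiled = [re.compile(p) for p in patterns]
    return {name: any(r.search(name) for r in compiled) for name in param_names}

params = [
    "decoder.layers.0.attention.indexer.w_k",   # hypothetical names
    "decoder.layers.0.attention.mla.w_q",
    "decoder.layers.0.moe.experts.w_up",
]
mask = build_trainable_mask(params, [r".*indexer.*"])
print(mask)
# Only the indexer weight is marked trainable; gradients for everything else
# would be zeroed, matching the dense warm-up stage where all non-indexer
# weights are frozen.
```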

## Checkpoint conversion
To get started, follow the instructions at HuggingFace ([V3](https://huggingface.co/deepseek-ai/DeepSeek-V3), [V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)) to download the model. V3, V3.1, and R1 currently ship with mixed-precision FP8 & BF16 weights. To convert all FP8 weights to BF16, use the script [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/ckpt_scripts/deepseek_fp8_to_bf16.py). Once downloaded and converted to BF16:
* run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly.
* run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned version in Orbax for decoding.

## Checkpoint conversion for V3.2
> **Note:** These steps are required because Transformers code for V3.2 is not yet available.
**Collaborator:**
Can remove this note. I would suggest following

## Checkpoint conversion
For finetuning, supervised finetuning, or decoding, we need to convert the checkpoint from HuggingFace.
1. To get started, follow the instructions at HuggingFace to download the model. For faster download, you may use `hf_xet` or `hf_transfer`. The model weights are quantized in MXFP4.
```
hf download [openai/gpt-oss-20b|openai/gpt-oss-120b] --local-dir <local_mxfp4_path>
```
2. Please convert it from MXFP4 to BF16 using script [dequantize_mxfp4.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/dequantize_mxfp4.py) on gpu.
```
python3 -m maxtext.checkpoint_conversion.standalone_scripts.dequantize_mxfp4 --input-path=<local_mxfp4_path> --output-path=<local_bf16_path> --dtype-str=bf16
```
3. Once downloaded and converted to BF16:
* run [convert_gpt_oss_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt_oss_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) scanned format for training and fine-tuning.
```
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gpt_oss_ckpt --base-model-path <local_bf16_path> \
--maxtext-model-path <GCS/path/to/scanned/maxtext/ckpt> --model-size [gpt-oss-20b|gpt-oss-120b]
```
* run [convert_gpt_oss_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt_oss_unscanned_ckpt.py) to convert the checkpoint to unscanned format in Orbax for decoding.
```
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gpt_oss_unscanned_ckpt --base-model-path <local_bf16_path> \
--maxtext-model-path <GCS/path/to/unscanned/maxtext/ckpt> --model-size [gpt-oss-20b|gpt-oss-120b]
```


### 1. Download Model Weights
Download the Hugging Face weights from [deepseek-ai/DeepSeek-V3.2](https://huggingface.co/deepseek-ai/DeepSeek-V3.2) to your local environment.
* **Target Directory:** `LOCAL_WEIGHTS`
**Collaborator:**
Would be better to be specific

The model weights are quantized in FP8.
```
hf download deepseek-ai/DeepSeek-V3.2 --local-dir <local_fp8_path>
```


### 2. Dequantize Weights
Convert the weights from FP8 to BF16 using the official DeepSeek script.
**Collaborator:**
@shuningjin could you help check this part?

* **Script:** [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py)
* **Output Directory:** `DEQUANTIZED_LOCAL_WEIGHTS`
**Collaborator (@shuningjin, Apr 7, 2026):**
Convert the weights from FP8 to BF16 using script [deepseek_fp8_to_bf16.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py) on CPU:

```
python3 -m maxtext.checkpoint_conversion.standalone_scripts.deepseek_fp8_to_bf16 --input-fp8-hf-path=<local_fp8_path> --output-bf16-hf-path=<local_bf16_path>
```

Alternatively, we can use the official DeepSeek script [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert on GPU.
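For context on what the dequantization step does: DeepSeek's FP8 checkpoints store per-block scale factors alongside the quantized weights, and the cast multiplies each block by its scale before writing BF16. A toy sketch of that idea, using illustrative 2x2 tiles (the real scripts operate on much larger blocks and handle safetensors I/O):

```python
import numpy as np

def dequantize_blockwise(w_q: np.ndarray, scales: np.ndarray, block: int) -> np.ndarray:
    """Multiply each (block x block) tile of w_q by its per-tile scale."""
    out = w_q.astype(np.float32).copy()
    for i in range(scales.shape[0]):
        for j in range(scales.shape[1]):
            out[i*block:(i+1)*block, j*block:(j+1)*block] *= scales[i, j]
    return out

w_q = np.ones((4, 4), dtype=np.float32)       # stand-in for FP8 values
scales = np.array([[2.0, 0.5], [1.0, 4.0]])   # one scale per 2x2 tile
w = dequantize_blockwise(w_q, scales, block=2)
print(w[0, 0], w[0, 2], w[2, 2])  # -> 2.0 0.5 4.0
```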


### 3. Convert to MaxText
**Collaborator:**
Convert to MaxText-compatible Orbax format

Execute the following command to finalize the conversion. Ensure your environment variables (`$BASE_OUTPUT_PATH`, `$HF_TOKEN`, and `$DEQUANTIZED_LOCAL_WEIGHTS`) are exported before running.

```bash
python3 -m maxtext.checkpoint_conversion.to_maxtext \
src/maxtext/configs/base.yml \
model_name=deepseek3.2-671b \
scan_layers=true \
attention=dot_product \
base_output_directory=$BASE_OUTPUT_PATH \
hf_access_token=$HF_TOKEN \
hardware=cpu \
skip_jax_distributed_system=True \
--hf_model_path=$DEQUANTIZED_LOCAL_WEIGHTS \
--eager_load_method=safetensors \
--save_dtype=bfloat16
```
**Collaborator:**
Might be good to add

Setting `scan_layers=true` generates the scanned Orbax format for training and fine-tuning. Setting `scan_layers=false` generates the unscanned Orbax format for decoding.
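To illustrate the difference the Collaborator mentions: a scanned checkpoint stacks per-layer weights into one array with a leading layer axis (which `jax.lax.scan` can loop over during training), while an unscanned checkpoint keeps each layer as a separate entry, convenient for layer-by-layer decoding. A toy numpy sketch with hypothetical names and shapes:

```python
import numpy as np

num_layers, d = 4, 8
# Unscanned: one entry per layer (hypothetical parameter naming).
unscanned = {f"layers_{i}/w": np.full((d, d), float(i)) for i in range(num_layers)}

# Scanned: the same weights stacked along a leading layer axis, which a
# scan-based training loop can iterate over without Python-level unrolling.
scanned = {"layers/w": np.stack([unscanned[f"layers_{i}/w"] for i in range(num_layers)])}

print(scanned["layers/w"].shape)  # -> (4, 8, 8)
```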


## Fine-tuning
