diff --git a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md
index 10651f7186..87e6e834cd 100644
--- a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md
+++ b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md
@@ -20,7 +20,9 @@ DeepSeek is a novel family of open-weights sparse MoE models by DeepSeek AI. The
 * DeepSeek-V3 features advanced techniques, including Multi-Head Latent Attention (MLA), finer-grained and shared experts, Multi-Token Prediction (MTP), and FP8 mixed precision designed for enhanced efficiency and performance.
 
-* DeepSeek V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency.
+* DeepSeek-V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency.
+
+* DeepSeek-V3.2 introduces DeepSeek Sparse Attention (DSA), which reduces computational complexity while preserving model performance in long-context scenarios.
 
 * DeepSeek R1 also uses V3 architecture. It utilizes cold-start data and large-scale reinforcement learning to incentivize chain-of-thought reasoning without relying solely on supervised fine-tuning.
@@ -54,12 +56,96 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
   dataset_type=synthetic
 ```
+## Continued pre-training for V3.2 Sparse Attention
+**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-$k$ tokens for attention. DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**.
+
+1. **Dense Warm-up Stage**
+The indexer is trained exclusively using the dense indexer loss while all other model parameters remain frozen.
+```sh
+# indexer_loss_scaling_factor must be non-zero to activate indexer training; the default in base.yml is 0.
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
+  model_name=deepseek3.2-671b \
+  run_name=matmul_pre_training \
+  per_device_batch_size=4 \
+  enable_checkpointing=false \
+  ici_fsdp_parallelism=128 \
+  steps=5 \
+  async_checkpointing=false \
+  tokenizer_type=huggingface \
+  tokenizer_path=deepseek-ai/DeepSeek-V3.2 \
+  attention=flash \
+  dtype=bfloat16 \
+  weight_dtype=bfloat16 \
+  megablox=True \
+  sparse_matmul=True \
+  dataset_type=synthetic \
+  indexer_sparse_training=False \
+  indexer_loss_scaling_factor=0.01 \
+  trainable_parameters_mask=['.*indexer.*']
+```
+2. **Sparse Training Stage**
+The indexer is trained with the sparse indexer loss, while the remaining model parameters are unfrozen and updated using the standard language modeling loss.
+```sh
+# indexer_loss_scaling_factor must be non-zero to activate indexer training; the default in base.yml is 0.
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
+  model_name=deepseek3.2-671b \
+  per_device_batch_size=4 \
+  enable_checkpointing=false \
+  ici_fsdp_parallelism=128 \
+  steps=5 \
+  max_target_length=1024 \
+  async_checkpointing=false \
+  tokenizer_type=huggingface \
+  tokenizer_path=deepseek-ai/DeepSeek-V3.2 \
+  attention=flash \
+  dtype=bfloat16 \
+  weight_dtype=bfloat16 \
+  megablox=True \
+  sparse_matmul=True \
+  dataset_type=synthetic \
+  indexer_sparse_training=True \
+  indexer_loss_scaling_factor=0.01
+```
 ## Checkpoint conversion
 To get started, follow the instructions at HuggingFace ([V3](https://huggingface.co/deepseek-ai/DeepSeek-V3), [V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)) to download the model. Currently for V3, V3.1, and R1, it uses mixed precision fp8 & bf16 weights. To convert all FP8 weights to BF16, use the script [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/ckpt_scripts/deepseek_fp8_to_bf16.py). Once downloaded and converted to BF16:
 
 * run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly.
 
 * run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned version in Orbax for decoding.
 
+## Checkpoint conversion for V3.2
+> **Note:** These steps are required because Transformers code for V3.2 is not yet available.
+
+### 1. Download Model Weights
+Download the Hugging Face weights from [deepseek-ai/DeepSeek-V3.2](https://huggingface.co/deepseek-ai/DeepSeek-V3.2) to your local environment. Weights are provided in FP8.
+`hf download deepseek-ai/DeepSeek-V3.2 --local-dir `
+
+### 2. Dequantize Weights
+* **Script:** Convert the weights from FP8 to BF16 using the script [deepseek_fp8_to_bf16.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py) on CPU:
+```bash
+python3 -m maxtext.checkpoint_conversion.standalone_scripts.deepseek_fp8_to_bf16 --input-fp8-hf-path= --output-bf16-hf-path=
+```
+Alternatively, you can use the official DeepSeek script [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert on GPU.
+
+### 3. Convert to MaxText-compatible Orbax format
+Execute the following command to finalize the conversion. Ensure the environment variables `$BASE_OUTPUT_PATH`, `$HF_TOKEN`, and `$DEQUANTIZED_LOCAL_WEIGHTS` are exported before running.
+Setting `scan_layers=true` generates the scanned Orbax format for training and fine-tuning; setting `scan_layers=false` generates the unscanned format for decoding.
+```bash
+python3 -m maxtext.checkpoint_conversion.to_maxtext \
+  src/maxtext/configs/base.yml \
+  model_name=deepseek3.2-671b \
+  scan_layers=true \
+  attention=dot_product \
+  base_output_directory=$BASE_OUTPUT_PATH \
+  hf_access_token=$HF_TOKEN \
+  hardware=cpu \
+  skip_jax_distributed_system=True \
+  --hf_model_path=$DEQUANTIZED_LOCAL_WEIGHTS \
+  --eager_load_method=safetensors \
+  --save_dtype=bfloat16
+```
 
 ## Fine-tuning