diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index 3294ba653..7e9c855cb 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -242,6 +242,17 @@ To add a system prompt, use the `--system_prompt ` argument.
 
 For large scale data generation, please see [SLURM prepare data](SLURM_prepare_data.md) for SLURM support.
 
+### Configuring Draft Model
+
+For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. For example, to use a 2-layer EAGLE module with an MLP intermediate size of 8192, set `eagle_config.json` to:
+
+```json
+{
+  "num_hidden_layers": 2,
+  "intermediate_size": 8192
+}
+```
+
 ### Draft Vocabulary Compression
 
 We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set:
@@ -252,15 +263,7 @@ python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct
 
 This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
 
-### Configuring Draft Model
-
-For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`:
-
-```json
-{
-  "draft_vocab_size": 32000
-}
-```
+Then, simply set `{"draft_vocab_size": 32000}` in `eagle_config.json` and include `--draft_vocab_cache <path to d2t.pt>` when running `./launch_train.sh`. The draft model will use the provided vocab table during training and export.
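+
+For reference, a training run that picks up the compressed draft vocabulary might look like the sketch below. Only `--draft_vocab_cache` is specific to this step; the other flags (`--model`, `--data`, `--mode`) are assumed from the surrounding training examples and should be adjusted to your setup:
+
+```bash
+./launch_train.sh \
+    --model meta-llama/Llama-3.2-1B-Instruct \
+    --data synthetic_conversations/daring-anteater.jsonl \
+    --mode eagle3 \
+    --draft_vocab_cache save_dir/d2t.pt
+```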
 
 ### Interact with `modelopt.torch.speculative`
 
diff --git a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
index 48d12aeb2..debbe6881 100644
--- a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
+++ b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh
@@ -19,5 +19,5 @@
 python3 collect_hidden_states/compute_hidden_states_hf.py \
     --model meta-llama/Llama-3.2-1B-Instruct \
-    --input-file synthetic_conversations/daring-anteater.jsonl \
+    --input-data synthetic_conversations/daring-anteater.jsonl \
     --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
\ No newline at end of file
diff --git a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
index 31e2294d9..dac0ab9a9 100644
--- a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
+++ b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
@@ -30,7 +30,7 @@ split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FI
 
 for i in $(seq 0 $((DP_SIZE-1)))
 do
-CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
+CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-data /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
 done
 wait
diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh
index 487d0d69d..75a27deb6 100644
--- a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh
+++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh
@@ -20,6 +20,6 @@
 export TLLM_LOG_LEVEL="error";
 python3 collect_hidden_states/compute_hidden_states_trtllm.py \
     --model meta-llama/Llama-3.2-1B-Instruct \
-    --input-file synthetic_conversations/daring-anteater.jsonl \
+    --input-data synthetic_conversations/daring-anteater.jsonl \
     --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
\ No newline at end of file
diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
index 4b0fd1060..d06cfc061 100644
--- a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
+++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
@@ -33,7 +33,7 @@ split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FI
 
 for i in $(seq 0 $((DP_SIZE-1)))
 do
-export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &
+export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-data /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &
 done
 wait
diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py
index 3625072b1..3ef715637 100644
--- a/examples/speculative_decoding/eagle_utils.py
+++ 
b/examples/speculative_decoding/eagle_utils.py @@ -14,8 +14,6 @@ # limitations under the License. import inspect -import json -import os from collections.abc import Callable from pathlib import Path from typing import TYPE_CHECKING @@ -29,16 +27,20 @@ import transformers from datasets import load_dataset from packaging.version import Version -from PIL import Image from scripts.ar_validate import validate_ar from torch.utils.data import Dataset -from transformers import AutoProcessor, Trainer, TrainerCallback +from transformers import Trainer, TrainerCallback from transformers.trainer_pt_utils import LabelSmoother import modelopt from modelopt.torch.speculative.utils import get_ttt_msk_func from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import is_master +from modelopt.torch.utils.plugins.transformers_dataset import ( + LanguageDataCollator, + ShardedDataset, + VisionLanguageDataCollator, +) try: import wandb @@ -47,459 +49,124 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index -REMOVE_THINK_CHAT_TEMPLATE = ( - "{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}" -) - - -def preprocess(examples, tokenizer, **kwargs): - tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") - new_examples = { - "input_ids": [], - "attention_mask": [], - "loss_mask": [], - "labels": [], - } - for i in range(len(examples)): - messages = [] - source = examples[i]["conversations"] - - # Detect format: either role/content or from/value - def get_role_content(item): - if "role" in item and "content" in item: - return item["role"], item["content"] - elif "from" in item and "value" in item: - return item["from"], item["value"] - else: - raise ValueError(f"Unknown conversation format: {item}") - - for sentence in source: - role, content = get_role_content(sentence) - messages.append({"role": role.lower(), "content": content}) - conversation = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=False, - ) - - output = tokenizer( - conversation, - return_tensors="pt", - add_special_tokens=False, - truncation=True, - ) - input_ids = output.input_ids[0] - attention_mask = output.attention_mask[0] - loss_mask = torch.ones_like(input_ids) - labels = torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)]) - new_examples["input_ids"].append(input_ids) - new_examples["attention_mask"].append(attention_mask) - new_examples["loss_mask"].append(loss_mask) - new_examples["labels"].append(labels) - - return new_examples - - -def preprocess_vlm(examples, tokenizer, processor, img_dir): - tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") - new_examples = { - "input_ids": [], - "attention_mask": [], - "loss_mask": [], - "labels": [], - "pixel_values": [], - "image_flags": [], - } - for i in range(len(examples)): - messages = [] - source = examples[i]["conversations"] - - # Detect format: either role/content or from/value - def get_role_content(item): - if "role" in item and "content" in item: - return item["role"], item["content"] - elif "from" in item and "value" in item: - return item["from"], item["value"] - else: - raise ValueError(f"Unknown conversation format: {item}") - - # align role to user-assistant format - def convert_role(role): - role_map = { - "human": "user", - "gpt": "assistant", - } - return role_map[role.lower()] if role.lower() in role_map else role.lower() - - for sentence in source: - role, content = 
get_role_content(sentence) - new_role = convert_role(role) - messages.append({"role": new_role, "content": content}) - conversation = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=False, - ) - - img_filename = os.path.join(img_dir, examples[i]["image"]) - img = Image.open(img_filename) - output = processor(images=img, text=conversation, return_tensors="pt") - input_ids = output.input_ids[0] - attention_mask = output.attention_mask[0] - loss_mask = torch.ones_like(input_ids) - labels = torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)]) - # TODO: add labels and answer-only loss masking? - - new_examples["input_ids"].append(input_ids) - new_examples["attention_mask"].append(attention_mask) - new_examples["loss_mask"].append(loss_mask) - new_examples["labels"].append(labels) - new_examples["pixel_values"].append(output.pixel_values) - new_examples["image_flags"].append( - torch.ones((output.pixel_values.shape[0],), dtype=torch.int64) - ) - return new_examples +class OfflineSupervisedDataset(Dataset): + """Offline dataset for supervised fine-tuning. -class SupervisedDataset(Dataset): - """Dataset for supervised fine-tuning. + This dataset loads data on-the-fly from pre-processed .pt data files. Args: - raw_data (list): A list of raw data examples. - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. + dumped_files (list): A list of file paths to the dumped .pt files. """ def __init__( self, - raw_data, - tokenizer: transformers.PreTrainedTokenizer, - vlm_processor=None, - img_dir=None, + dumped_files, ): super().__init__() - - print_rank_0("Formatting inputs...") - sources = raw_data - self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess - self.data_dict = self.preprocess_fn( - sources, tokenizer, processor=vlm_processor, img_dir=img_dir - ) + self.dumped_files = dumped_files def __len__(self): - return len(self.data_dict["input_ids"]) + return len(self.dumped_files) def __getitem__(self, i) -> dict[str, torch.Tensor]: - return {k: self.data_dict[k][i] for k in self.data_dict} - - -class LazySupervisedDataset(Dataset): - """Lazy dataset for supervised fine-tuning. - - This dataset loads data on-the-fly when requested, which can be memory-efficient but slower. - - Args: - raw_data (list): A list of raw data examples. - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. 
- """ - - def __init__( - self, - raw_data, - tokenizer: transformers.PreTrainedTokenizer, - vlm_processor=None, - img_dir=None, - ): - super().__init__() - print_rank_0("Formatting inputs...Skip in lazy mode") - self.tokenizer = tokenizer - self.raw_data = raw_data - self.cached_data_dict = {} - self.vlm_processor = vlm_processor - self.img_dir = img_dir - self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess + offline_data = torch.load(self.dumped_files[i]) - def __len__(self): - return len(self.raw_data) - - def __getitem__(self, i) -> dict[str, torch.Tensor]: - if i in self.cached_data_dict: - return self.cached_data_dict[i] - ret = self.preprocess_fn( - [self.raw_data[i]], self.tokenizer, processor=self.vlm_processor, img_dir=self.img_dir - ) - ret = {k: ret[k][0] for k in ret} - self.cached_data_dict[i] = ret + labels = torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID) + labels[..., :-1] = offline_data["input_ids"][..., 1:] + ret = { + "input_ids": offline_data["input_ids"], + "base_model_hidden_states": offline_data["hidden_states"], + "aux_hidden_states": offline_data["aux_hidden_states"], + "attention_mask": torch.ones_like(offline_data["input_ids"]), + "loss_mask": torch.ones_like(offline_data["input_ids"]), + "labels": labels, + } return ret -class OfflineSupervisedDataset(Dataset): - """Lazy offline dataset for supervised fine-tuning. +class EagleOfflineDataCollator: + """Data collator that truncate or pads data for offline training.""" - This dataset loads data on-the-fly from pre-processed .pt data files as well as - input conversations in JSON format. + def __init__(self, train_len): + self.train_len = train_len - Args: - data_entries (list): A list of tuples (raw_data_example, file_path). - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. - """ + def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0): + """Pad or truncate a tensor to length along a given dimension.""" + dim = dim % x.ndim # support negative dimension - def __init__( - self, - data_entries, - tokenizer: transformers.PreTrainedTokenizer, - vlm_processor=None, - img_dir=None, - ): - super().__init__() - print_rank_0("Formatting inputs...Skip in offline mode") - self.tokenizer = tokenizer - self.data_entries = data_entries - self.vlm_processor = vlm_processor - self.img_dir = img_dir - self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess + # allocate output tensor + out_shape = list(x.shape) + out_shape[dim] = length + out = x.new_zeros(out_shape) - # Does not cache the hidden states, as those have an extremely large memory footprint. 
- self.cached_data_dict = {} + # consturct copy slice + slc = [slice(None)] * x.ndim + slc[dim] = slice(0, min(length, x.size(dim))) - def __len__(self): - return len(self.data_entries) + # populate output tensor + out[tuple(slc)] = x[tuple(slc)] + return out - def __getitem__(self, i) -> dict[str, torch.Tensor]: - # Load the conversational data, using the cache - raw_data, offline_file_path = self.data_entries[i] - # Extend the data sample with the hidden states from the .pt file - max_length = self.tokenizer.model_max_length - offline_data = torch.load(offline_file_path) - offline_data["input_ids"] = offline_data["input_ids"][:max_length] - offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :] - offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :] + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: + base_batch = { + k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features]) + for k in ["input_ids", "attention_mask", "loss_mask", "labels"] + } - ret = { - "input_ids": offline_data["input_ids"], - "attention_mask": torch.ones_like(offline_data["input_ids"]), - "loss_mask": torch.ones_like(offline_data["input_ids"]), - "labels": torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID), - "kwargs": { - "base_model_outputs": { - "base_model_hidden_states": offline_data["hidden_states"], - "aux_hidden_states": offline_data["aux_hidden_states"], - } - }, + base_model_outputs = { + k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features]) + for k in ["base_model_hidden_states", "aux_hidden_states"] } - return ret + + batch = { + **base_batch, + "base_model_outputs": base_model_outputs, + } + return batch def make_eagle_supervised_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, - max_length=None, + train_len=None, ) -> dict: - """Make dataset and collator for supervised fine-tuning. - - Args: - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. - data_args: Data arguments. + if data_args.offline_data_path is None: + train_dataset = ShardedDataset("json", data_files=data_args.data_path) + + if not data_args.vlm_processor: + data_collator = LanguageDataCollator( + tokenizer=tokenizer, + train_len=train_len, + return_labels=True, + ) + else: + data_collator = VisionLanguageDataCollator( + processor=data_args.vlm_processor, + train_len=train_len, + local_image_path=data_args.vlm_img_dir, + return_labels=True, + ) - Returns: - dict: A dictionary containing train and eval datasets. 
- """ - if data_args.vlm_processor: - vlm_processor = AutoProcessor.from_pretrained( - data_args.vlm_processor, trust_remote_code=True, use_fast=True - ) - vlm_img_dir = data_args.vlm_img_dir else: - vlm_processor, vlm_img_dir = None, None - # Load the conversations from the source file - print_rank_0("Loading input conversations...") - data_json = [] - data_path_p = Path(data_args.data_path) - if data_path_p.is_dir(): - # Load all .jsonl files in the directory and combine them - for jsonl_file in sorted(data_path_p.glob("*.jsonl")): - with open(jsonl_file) as f: - data_json.extend(json.loads(line) for line in f) - else: - with open(data_args.data_path) as f: - if data_args.data_path.endswith("jsonl"): - data_json = [json.loads(line) for line in f] - else: - data_json = json.load(f) - - if data_args.offline_data_path is not None: print_rank_0("Loading pre-processed data for offline training...") - dataset_cls = OfflineSupervisedDataset + assert not data_args.vlm_processor, "Offline data is not supported for VLM." - # Glob for all .pt files in the data_path directory - assert data_args.offline_data_path is not None, ( - "offline_data_path must be provided for offline training." - ) offline_data_path = Path(data_args.offline_data_path) - # Collect all pt file paths - all_files = {str(p) for p in offline_data_path.glob("*.pt")} - all_files |= {str(p) for p in offline_data_path.glob("**/*.pt")} - if not all_files: + dumped_files = [str(p) for p in offline_data_path.glob("*.pt")] + if not dumped_files: raise ValueError(f"No .pt files found in {data_args.offline_data_path}") - # Build a map from conv_id to file_path for fast lookup - print("building conv_id_to_file map...") - conv_id_to_file = {} - for pt_path in all_files: - pt_name = Path(pt_path).name - # Expect conv_id.pt - if pt_name.endswith(".pt"): - conv_id = pt_name[:-3] - conv_id_to_file[conv_id] = pt_path - - valid_entries = [] - print("filtering valid entries...") - for entry in data_json: - conv_id = entry.get("conversation_id") - if conv_id is None: - conv_id = entry.get("uuid") - if conv_id is None: - conv_id = entry.get("id") - if conv_id is None: - raise ValueError(f"Conversation ID required but not found for entry {entry}") - - file_path = conv_id_to_file.get(str(conv_id)) - if file_path is None: - continue - valid_entries.append((entry, file_path)) - - if len(valid_entries) == 0: - msg = """No valid files found in the offline data path that match the conversation IDs - in the provided data json. Please ensure that the offline data path is correct and - contains .pt files named after the conversation IDs, and that the input conversations - json has the correct format (with 'conversation_id' or 'id' fields).""" - raise ValueError(msg) - elif len(valid_entries) < len(data_json): - print_rank_0( - f"Warning: Only {len(valid_entries)} out of {len(data_json)} conversations" - " have corresponding .pt files in the offline data path. Continuing..." 
- ) - - num_train = int(len(valid_entries) * 0.95) - train_dataset = dataset_cls( - valid_entries[:num_train], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - eval_dataset = dataset_cls( - valid_entries[num_train:], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - - data_collator = DataCollatorForOffline(max_length=max_length) - else: - print_rank_0("Loading input conversations...") - dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset - - train_dataset = dataset_cls( - data_json[: int(len(data_json) * 0.95)], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - eval_dataset = dataset_cls( - data_json[int(len(data_json) * 0.95) :], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - - data_collator = DataCollatorWithPadding(max_length=max_length) + train_dataset = OfflineSupervisedDataset(dumped_files) + data_collator = EagleOfflineDataCollator(train_len=train_len) return { "train_dataset": train_dataset, - "eval_dataset": eval_dataset, "data_collator": data_collator, } -class DataCollatorWithPadding: - def __init__(self, max_length): - self.max_length = max_length - - def paddingtensor2d(self, intensors, length): - n, dim = intensors.shape - if n > length: - return intensors[:length, :] - padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype) - outtensors = torch.cat((intensors, padding_tensor)) - return outtensors - - def paddingtensor(self, intensors, length): - if intensors.shape[0] > length: - return intensors[:length] - padding_tensor = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype) - outtensors = torch.cat((intensors, padding_tensor)) - return outtensors - - def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: - batch_input_ids = torch.stack( - [self.paddingtensor(item["input_ids"], self.max_length) for item in features] - ) - batch_attention_mask = torch.stack( - [self.paddingtensor(item["attention_mask"], self.max_length) for item in features] - ) - batch_loss_mask = torch.stack( - [self.paddingtensor(item["loss_mask"], self.max_length) for item in features] - ) - - batch_labels = torch.stack( - [self.paddingtensor(item["labels"], self.max_length) for item in features] - ) - - batch = { - "input_ids": batch_input_ids, - "attention_mask": batch_attention_mask, - "loss_mask": batch_loss_mask, - "labels": batch_labels, - } - - # Collate VLM data - if "pixel_values" in features[0]: - # pixel values and image flags should be flattened inside a batch - batch["pixel_values"] = torch.cat([item["pixel_values"] for item in features], dim=0) - batch["image_flags"] = torch.cat([item["image_flags"] for item in features], dim=0) - - return batch - - -class DataCollatorForOffline(DataCollatorWithPadding): - def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: - base_batch = super().__call__(features) - if "kwargs" not in features[0]: - raise ValueError("No kwargs found in batch features. 
Offline data required.") - - features = [item["kwargs"]["base_model_outputs"] for item in features] - - batch_hidden_states = torch.stack( - [ - self.paddingtensor2d(item["base_model_hidden_states"], self.max_length) - for item in features - ] - ) - batch_aux_hidden_states = torch.stack( - [self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features] - ) - - batch = { - **base_batch, - "base_model_outputs": { - "base_model_hidden_states": batch_hidden_states, - "aux_hidden_states": batch_aux_hidden_states, - }, - } - - return batch - - class EagleTrainerWithAccLog(Trainer): """Wrapper around Trainer that logs training accuracy.""" diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index c937d5b09..c0b9ea00e 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -102,6 +102,14 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi DP_SHARD_SIZE="${1#*=}" ;; + --log_steps*) + if [[ "$1" != *=* ]]; then shift; fi + LOG_STEPS="${1#*=}" + ;; + --draft_vocab_cache*) + if [[ "$1" != *=* ]]; then shift; fi + DRAFT_VOCAB_CACHE="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -138,6 +146,8 @@ AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} ESTIMATE_AR=${ESTIMATE_AR:-False} CP_SIZE=${CP_SIZE:-1} DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))} +LOG_STEPS=${LOG_STEPS:-100} +DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} if [[ "$MODE" == "medusa" ]]; then SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" @@ -179,6 +189,13 @@ else fi +if [[ "$DRAFT_VOCAB_CACHE" != "" ]]; then + DRAFT_VOCAB_CACHE_ARGS="--draft_vocab_cache $DRAFT_VOCAB_CACHE" +else + DRAFT_VOCAB_CACHE_ARGS="" +fi + + # Disable tokenizers parallelism to avoid warning export TOKENIZERS_PARALLELISM=False CMD="accelerate launch --mixed_precision bf16 main.py \ @@ -201,12 +218,13 @@ CMD="accelerate launch --mixed_precision bf16 main.py \ --weight_decay 0.0 \ --warmup_steps 100 \ --lr_scheduler_type linear \ - --logging_steps 100 \ + --logging_steps $LOG_STEPS \ --tf32 True \ --data_path $DATA \ --disable_tqdm $DISABLE_TQDM \ --estimate_ar $ESTIMATE_AR \ --ar_validate_steps $AR_VALIDATE_STEPS \ + $DRAFT_VOCAB_CACHE_ARGS \ $VLM_ARGS \ $OFFLINE_TRAINING_ARGS \ $SPECULATIVE_ARGS \ diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 8706ca049..a880148a7 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -48,6 +48,7 @@ import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs from modelopt.torch.utils import print_rank_0 torch.manual_seed(0) @@ -76,9 +77,9 @@ class DataArguments: }, ) lazy_preprocess: bool = True - draft_vocab_cache_dir: str = field( - default="draft_vocab_cache", - metadata={"help": "Path to the d2t cache directory."}, + draft_vocab_cache: str | None = field( + default=None, + metadata={"help": "Path to d2t.pt cache file."}, ) vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."}) vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."}) @@ -97,7 +98,7 @@ class TrainingArguments(transformers.TrainingArguments): ) dataloader_drop_last: bool = field(default=True) bf16: bool = field(default=True) - mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3" + mode: 
Literal["eagle3", "medusa"] = "eagle3" estimate_ar: bool = field( default=False, metadata={"help": "Whether to estimate AR during training for logging."} ) @@ -147,22 +148,21 @@ def train(): training_args.parallelism_config.sp_backend = None print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}") - # Detecting last checkpoint. - last_checkpoint = None - if os.path.isdir(training_args.output_dir): - last_checkpoint = get_last_checkpoint(training_args.output_dir) + # Detect checkpoint to resume from + last_checkpoint = ( + get_last_checkpoint(training_args.output_dir) + if os.path.isdir(training_args.output_dir) + else None + ) + if last_checkpoint: print_rank_0(f"Last checkpoint detected: {last_checkpoint}") - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint + checkpoint = training_args.resume_from_checkpoint or last_checkpoint use_offline_training = data_args.offline_data_path is not None if checkpoint: - model = transformers.AutoModelForCausalLM.from_pretrained( + _, model = load_vlm_or_llm_with_kwargs( checkpoint, torch_dtype="auto", trust_remote_code=True ) tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) @@ -170,7 +170,7 @@ def train(): # To avoid OOM for large models, we load and convert model on CPU first. # Model will be moved to GPU during HF trainer.init(). offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {} - model = transformers.AutoModelForCausalLM.from_pretrained( + model_config, model = load_vlm_or_llm_with_kwargs( model_args.model_name_or_path, torch_dtype="auto", device_map="cpu", @@ -180,79 +180,48 @@ def train(): if use_offline_training: # When doing offline training, we need to set num_hidden_layers # since we override it when loading the model for space savings - model_config = transformers.AutoConfig.from_pretrained( - model_args.model_name_or_path, trust_remote_code=True - ) model.config.num_orig_hidden_layers = model_config.num_hidden_layers tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, model_max_length=training_args.training_seq_len, trust_remote_code=True, ) - if tokenizer.chat_template is None: - tokenizer.chat_template = ( - "{%- for message in messages %}" - "{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}" - "{%- endfor %}" - ) - if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = tokenizer.eos_token_id - if training_args.mode == "medusa": config = { "medusa_num_heads": medusa_args.medusa_num_heads, "medusa_num_layers": medusa_args.medusa_num_layers, } mtsp.convert(model, [("medusa", config)]) - elif training_args.mode in ["eagle1", "eagle3"]: - from modelopt.torch.speculative.config import ( - default_eagle_config, - eagle3_default_config, - kimik2_eagle_default_config, + elif training_args.mode == "eagle3": + custom_config = ( + json.load(open(eagle_args.eagle_config)) if eagle_args.eagle_config else {} ) - if eagle_args.eagle_decoder_type == "kimik2": - eagle_architecture_config = kimik2_eagle_default_config - else: - eagle_architecture_config = { - "eagle1": default_eagle_config, - "eagle3": eagle3_default_config, - }[training_args.mode] - - if eagle_args.eagle_config: - with open(eagle_args.eagle_config) as f: - custom_config = json.load(f) - eagle_architecture_config.update(custom_config) - config = { "eagle_decoder_type": 
eagle_args.eagle_decoder_type, "eagle_offline": use_offline_training, - "eagle_architecture_config": eagle_architecture_config, + "eagle_architecture_config": custom_config, } mtsp.convert(model, [("eagle", config)]) # read draft vocab cache if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size: - try: - model_name = os.path.basename(os.path.normpath(model_args.model_name_or_path)) - vocab_cache_path = os.path.join( - data_args.draft_vocab_cache_dir, model_name, "d2t.pt" + if not os.path.isfile(data_args.draft_vocab_cache): + raise FileNotFoundError( + f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}" ) - vocab_cache = torch.load(vocab_cache_path) - model.eagle_module.d2t = vocab_cache - print_rank_0(f"Loaded draft vocab cache from {vocab_cache_path}.") - except Exception as e: - raise e + model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache) + print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") else: raise Exception(f"{training_args.mode} is not supported!") print_rank_0("Loading dataset...") if training_args.mode == "medusa": data_module = make_medusa_supervised_data_module(tokenizer, data_args) - elif training_args.mode in ["eagle1", "eagle3"]: + elif training_args.mode == "eagle3": data_module = make_eagle_supervised_data_module( - tokenizer, data_args, max_length=training_args.training_seq_len + tokenizer, data_args, train_len=training_args.training_seq_len ) trainer = EagleTrainerWithAccLog( diff --git a/examples/speculative_decoding/scripts/ar_validate.py b/examples/speculative_decoding/scripts/ar_validate.py index 38b886693..d5c37a895 100644 --- a/examples/speculative_decoding/scripts/ar_validate.py +++ b/examples/speculative_decoding/scripts/ar_validate.py @@ -18,10 +18,11 @@ from accelerate import Accelerator from datasets import load_dataset from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer import modelopt.torch.opt as mto from modelopt.torch.speculative.plugins.transformers import HFARValidation +from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs mto.enable_huggingface_checkpointing() @@ -71,7 +72,7 @@ def main(): accelerator = Accelerator() # Load model and tokenizer - model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto") + _, model = load_vlm_or_llm_with_kwargs(args.model_path, device_map="auto") tokenizer = AutoTokenizer.from_pretrained(args.model_path) model.eval() model = accelerator.prepare(model) diff --git a/examples/speculative_decoding/scripts/export_hf_checkpoint.py b/examples/speculative_decoding/scripts/export_hf_checkpoint.py index dfc293ee9..fc3421583 100644 --- a/examples/speculative_decoding/scripts/export_hf_checkpoint.py +++ b/examples/speculative_decoding/scripts/export_hf_checkpoint.py @@ -18,10 +18,10 @@ import argparse import torch -from transformers import AutoModelForCausalLM import modelopt.torch.opt as mto from modelopt.torch.export import export_hf_checkpoint +from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs def parse_args(): @@ -38,11 +38,11 @@ def parse_args(): mto.enable_huggingface_checkpointing() args = parse_args() -model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto") +_, model = load_vlm_or_llm_with_kwargs(args.model_path, torch_dtype="auto") model.eval() with torch.inference_mode(): export_hf_checkpoint( - model, # The quantized model. 
- export_dir=args.export_path, # The directory where the exported files will be stored. + model, + export_dir=args.export_path, ) print(f"Exported checkpoint to {args.export_path}") diff --git a/modelopt/torch/speculative/eagle/conversion.py b/modelopt/torch/speculative/eagle/conversion.py index ffaa195f2..2b085d5e3 100644 --- a/modelopt/torch/speculative/eagle/conversion.py +++ b/modelopt/torch/speculative/eagle/conversion.py @@ -20,6 +20,7 @@ from modelopt.torch.opt.conversion import ModelLikeModule from modelopt.torch.opt.dynamic import _DMRegistryCls from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict +from modelopt.torch.speculative.config import eagle3_default_config, kimik2_eagle_default_config from ..config import EagleConfig @@ -38,6 +39,14 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu EagleDMRegistry.register({original_cls: "base_model_class"})(EagleDMRegistry[cls]) break + # merge custom config with default config + default_arch_config = { + "llama": eagle3_default_config, + "kimik2": kimik2_eagle_default_config, + }[config.eagle_decoder_type] + custom_config = config.eagle_architecture_config + config.eagle_architecture_config = {**default_arch_config, **custom_config} + eagle_model = EagleDMRegistry.convert(model) eagle_model.modify( eagle_offline=config.eagle_offline, diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 3090297aa..f8b7e33df 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -31,6 +31,7 @@ import contextlib import copy +from dataclasses import dataclass from typing import Any import torch @@ -244,7 +245,6 @@ def __init__(self, config, decoder_layer_cls, bias=False): assert config.draft_vocab_size <= config.vocab_size, ( "EAGLE module's vocab size should be <= base model vocab size!" ) - # Initialize the buffers to zero. # Their values depend on specific tokenzier and calibrate dataset, and should be set in training script. 
if config.draft_vocab_size < config.vocab_size: @@ -405,6 +405,25 @@ def forward( return post_norm_h, pre_norm_h, past_key_values +@dataclass +class EagleBaseModelOutput: + out_hiddens: torch.Tensor + aux_hiddens: torch.Tensor | None = None + logits: torch.Tensor | None = None + input_embeds: torch.Tensor | None = None + loss: torch.Tensor | None = None + + @classmethod + def from_offline_dict(cls, d: dict): + return cls( + out_hiddens=d.get("base_model_hidden_states"), + aux_hiddens=d.get("aux_hidden_states"), + logits=d.get("base_model_logits"), + input_embeds=d.get("base_model_input_embeds"), + loss=None, + ) + + @EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) class HFEagleModel(EagleModel): """Eagle Model Class for huggingface models.""" @@ -425,16 +444,26 @@ def _base_model_lm_head(self): @property def _base_llm_config(self): """Return the llm config for the base model, from LLM or VLM.""" - return self.config.llm_config if hasattr(self.config, "llm_config") else self.config + return ( + getattr(self.config, "text_config", None) + or getattr(self.config, "llm_config", None) + or self.config + ) def _find_base_model_parts(self): """Find model parts from different models and set base_{part}_path attributes.""" base_model_parts_mapping = { - "base_model_path": ["model", "backbone", "language_model.backbone"], + "base_model_path": [ + "model.language_model", + "model", + "backbone", + "language_model.backbone", + ], "base_model_embeddings_path": [ "model.embed_tokens", "backbone.embeddings", "language_model.backbone.embeddings", + "model.language_model.embed_tokens", ], "base_model_lm_head_path": ["lm_head", "language_model.lm_head"], } @@ -480,6 +509,8 @@ def _collect_aux_hidden_states_forward_hook(self, module, input, output) -> None def pop_and_gather_aux_hiddens(self): """Pop auxiliary hidden states from base model and gather them on the draft model device.""" + if not self.eagle_config.use_aux_hidden_state: + return None # In PTQ, forward method will be called with try and except to find max batch size. # This leads to uncleared aux hidden states in the front of the list. # To fix it, we only return the last num_aux_h items in the list. @@ -488,9 +519,11 @@ def pop_and_gather_aux_hiddens(self): self._aux_hidden_states.clear() # Gather aux hidden states on the draft model device - aux_h_list = [h.to(self.eagle_module.fc.weight.device) for h in aux_h_list] + aux_hiddens = torch.cat( + [h.to(self.eagle_module.fc.weight.device) for h in aux_h_list], dim=-1 + ) - return aux_h_list + return aux_hiddens def _get_eagle_device(self): """Return the device where we should place eagle module.""" @@ -559,21 +592,13 @@ def modify( ): self.config.quantization_config.quantization_config.ignore.append("re:.*eagle_module.*") - # Use default aux_hidden_state layers if use_aux_hidden_state is True - # but no layer id is given + # Set default aux_hidden_state layers if ( self.eagle_config.use_aux_hidden_state and len(self.eagle_config.eagle_aux_hidden_state_layer_ids) == 0 ): self._set_default_aux_hidden_state_layers() - if self._base_llm_config.hidden_size != self.eagle_config.hidden_size: - raise ValueError( - "EAGLE module hidden size " - f"{self.eagle_config.hidden_size} must match base model hidden size " - f"{self._base_llm_config.hidden_size}!" 
- ) - # Freeze all parameters if self.eagle_freeze_base_model: for name, param in self.named_parameters(): @@ -615,25 +640,26 @@ def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step): return self._cached_attn_blk_masks[ttt_step] def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length + self, attention_mask, input_shape, past_key_values_length, device, dtype ): """Expand the 2-D attention mask to 4-D and apply causal mask.""" # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None + # construct causal mask if input_shape[-1] > 1: combined_attention_mask = make_causal_mask( input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, + dtype, + device=device, past_key_values_length=past_key_values_length, ) - + # merge causal mask with padding mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) + expanded_attn_mask = expand_mask(attention_mask, dtype, tgt_len=input_shape[-1]).to( + device + ) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None @@ -642,54 +668,66 @@ def _prepare_decoder_attention_mask( return combined_attention_mask - def _get_eagle_module_inputs( + def _prepare_eagle_inputs( self, input_ids, - eagle_input_hidden_states, attention_mask, position_ids, eagle_cache, + base_outputs, ): """Helper function to prepare eagle inputs for the 0th eagle forward pass.""" - b, seq_length, _ = eagle_input_hidden_states.shape - past_key_values_length = eagle_cache.get_seq_length() if eagle_cache is not None else 0 - seq_length_with_past = seq_length + past_key_values_length + b, seq_length = input_ids.shape + past_kv_len = eagle_cache.get_seq_length() if eagle_cache is not None else 0 + seq_len_with_past = seq_length + past_kv_len - # Prepare eagle_input_ids: Shift left 1 token - zeropadding = torch.zeros( - input_ids.shape[0], 1, dtype=input_ids.dtype, device=input_ids.device - ) - eagle_input_ids = torch.cat((input_ids[:, 1:], zeropadding), dim=1) + # Prepare eagle_input_embeds: Shift left 1 token + with torch.no_grad(): + if base_outputs.input_embeds is None: + eagle_input_embeds = self._base_model_embeddings(input_ids.roll(-1, 1)) + else: + eagle_input_embeds = base_outputs.input_embeds.roll(-1, 1) + + # Prepare eagle_input_hiddens + if self.eagle_config.use_aux_hidden_state: + # Eagle3: concat base model intermediate (pre-norm) hiddens + eagle_input_hiddens = self.eagle_module.fc(base_outputs.aux_hiddens) + else: + # Eagle1: use base model output (post-norm)hiddens + eagle_input_hiddens = base_outputs.out_hiddens # Prepare attention_mask - if attention_mask is not None: # Shift left 1 token for attention_mask - zeropadding = torch.zeros( - attention_mask.shape[0], 1, dtype=attention_mask.dtype, device=attention_mask.device + if attention_mask is None: + eagle_attention_mask = torch.ones( # default: all tokens are valid + (b, seq_len_with_past), dtype=torch.bool, device=eagle_input_hiddens.device ) - attention_mask = torch.cat((attention_mask[:, 1:], zeropadding), dim=1) else: - attention_mask = torch.ones( # Initialize default attention_mask - (b, seq_length_with_past), dtype=torch.bool, device=eagle_input_hidden_states.device - ) - + eagle_attention_mask = attention_mask.roll(-1, 1) # Shift left 1 token # Expand the 2-D attention mask to 4-D and apply causal mask. 
- attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (b, seq_length), eagle_input_hidden_states, past_key_values_length + eagle_attention_mask = self._prepare_decoder_attention_mask( + eagle_attention_mask, + (b, seq_length), + past_kv_len, + eagle_input_hiddens.device, + eagle_input_hiddens.dtype, ) # Prepare position_ids if position_ids is None: - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=eagle_input_hidden_states.device, + eagle_position_ids = ( + torch.arange( + past_kv_len, + seq_len_with_past, + dtype=torch.long, + device=eagle_input_hiddens.device, + ) + .unsqueeze(0) + .view(-1, seq_length) ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: - position_ids = position_ids.view(-1, seq_length).long() + eagle_position_ids = position_ids.view(-1, seq_length).long() - return eagle_input_ids, attention_mask, position_ids + return eagle_input_embeds, eagle_input_hiddens, eagle_attention_mask, eagle_position_ids def _compute_ttt_attention_mask( self, batch_size, seq_length, ttt_step @@ -699,7 +737,7 @@ def _compute_ttt_attention_mask( dtypemin = torch.finfo(self._base_llm_config.dtype).min q_len = seq_length kv_len = seq_length * (1 + ttt_step) - if self.eagle_module.config._attn_implementation == "flex_attention": + if self.eagle_config._attn_implementation == "flex_attention": # Return block mask for flex attention block_mask = create_block_mask(msk_func, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len) return block_mask @@ -715,40 +753,10 @@ def _compute_ttt_attention_mask( tensor_mask, 0, dtype=self._base_llm_config.dtype, device=self.device ).masked_fill(~tensor_mask, dtypemin) + # Note: (hg) repeat mask for kimi-k2 compatibility tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1) return tensor_mask - def _llm_or_vlm_embedding(self, input_ids, kwargs): - """Return input embeddings with possibly vision embeddings for VLM.""" - tok_embeds = self._base_model_embeddings(input_ids) - - # LLM only have token embeddings - if "pixel_values" not in kwargs: - return tok_embeds - - # Otherwise, insert vision embeddings in tok_embeds - if self.config.model_type == "NemotronH_Nano_VL_V2": - vit_embeds = self.extract_feature(kwargs["pixel_values"]) - vit_embeds = vit_embeds[kwargs["image_flags"] == 1] - bs, seq_len, hid_size = tok_embeds.shape - tok_embeds = tok_embeds.reshape(bs * seq_len, hid_size) - input_ids = input_ids.reshape(bs * seq_len) - selected = input_ids == self.img_context_token_id - try: - tok_embeds[selected] = tok_embeds[selected] * 0.0 + vit_embeds.reshape(-1, hid_size) - except Exception as e: - vit_embeds = vit_embeds.reshape(-1, hid_size) - print( - f"warning: {e}, tok_embeds[selected].shape={tok_embeds[selected].shape}, " - f"vit_embeds.shape={vit_embeds.shape}" - ) - n_token = selected.sum() - tok_embeds[selected] = tok_embeds[selected] * 0.0 + vit_embeds[:n_token] - del vit_embeds - return tok_embeds.reshape(bs, seq_len, hid_size) - else: - raise ValueError(f"VLM model type {self.config.model_type} not supported") - def _base_model_forward( self, input_ids, @@ -769,6 +777,7 @@ def _base_model_forward( **kwargs, ) past_key_values = getattr(outputs, "past_key_values", None) + base_input_embeds = outputs.hidden_states[0] base_model_hidden_states = outputs.hidden_states[-1] base_model_logits = outputs.logits @@ -780,9 +789,16 @@ def _base_model_forward( labels = labels.view(-1) base_model_loss = loss_fct(loss_logits, labels) - return base_model_hidden_states, 
base_model_logits, base_model_loss, past_key_values + return EagleBaseModelOutput( + input_embeds=base_input_embeds, + aux_hiddens=self.pop_and_gather_aux_hiddens(), + out_hiddens=base_model_hidden_states, + logits=base_model_logits, + loss=base_model_loss, + ), past_key_values def _map_logits_to_draft_vocab(self, full_logits): + assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized" reverse_mapping = ( torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device) + self.eagle_module.d2t @@ -839,125 +855,95 @@ def forward( """Forward pass of the EagleModel. Returns: - hidden_states: The hidden state from the base model. - logits: logits from the base model. - eagle_hidden_states: The hidden state from eagle_module. - eagle_logits: logits from the eagle_module. + loss: Loss of base model or eagle model. + logits: Base model logits. + past_key_values: Base model past key values with eagle cache attached. + hidden_states: Base model hidden states. + train_acc: Drafter training accuracies. """ - if past_key_values is not None and hasattr(past_key_values, "eagle_cache"): - eagle_cache = past_key_values.eagle_cache - else: - eagle_cache = None + eagle_cache = getattr(past_key_values, "eagle_cache", None) if self.training: - assert eagle_cache is None, "eagle_cache should be None in training" assert past_key_values is None, "past_key_values should be None in training" if loss_mask is None: - loss_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device) + # By default, mask out padding tokens in loss computation + loss_mask = ( + attention_mask.clone().detach() + if attention_mask is not None + else torch.ones_like(input_ids, dtype=torch.bool) + ) - # ====First, we run base model forward==== - if "base_model_outputs" in kwargs: + # ====First, run base model forward==== + if self.eagle_offline: # Parse base model outputs forwarded from teacher - base_outputs = kwargs["base_model_outputs"] - base_model_hidden_states = base_outputs["base_model_hidden_states"] - if "base_model_logits" in base_outputs: - base_model_logits = base_outputs["base_model_logits"] - else: - base_model_logits = self.lm_head(base_model_hidden_states) - base_model_loss, past_key_values = None, None + assert "base_model_outputs" in kwargs + base_outputs = EagleBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"]) + if base_outputs.logits is None: + base_outputs.logits = self.lm_head(base_outputs.out_hiddens) + past_key_values = None else: - base_model_hidden_states, base_model_logits, base_model_loss, past_key_values = ( - self._base_model_forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - self.eagle_freeze_base_model, - labels, - **kwargs, - ) + base_outputs, past_key_values = self._base_model_forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + self.eagle_freeze_base_model, + labels, + **kwargs, ) if not isinstance(past_key_values, Cache): past_key_values = _get_empty_cache(self._base_llm_config) if not isinstance(eagle_cache, Cache): eagle_cache = _get_empty_cache(self.eagle_module.config) + past_key_values.eagle_cache = eagle_cache - # ====Run eagle forward==== + # ====Prepare inputs for the first eagle forward pass==== eagle_loss = None train_accs = [[] for _ in range(self.eagle_config.parallel_draft_step)] - # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers - b, seq_length, h = base_model_hidden_states.shape - if self.eagle_config.use_aux_hidden_state: - if 
"base_model_outputs" in kwargs: - aux_hidden_states = kwargs["base_model_outputs"]["aux_hidden_states"] - else: - aux_hidden_states = torch.cat(self.pop_and_gather_aux_hiddens(), dim=-1) - eagle_input_hidden_states = self.eagle_module.fc(aux_hidden_states) - else: - eagle_input_hidden_states = base_model_hidden_states - - # Get eagle inputs for the first eagle forward pass - eagle_input_ids, attention_mask_0, position_ids = self._get_eagle_module_inputs( + b, seq_length, _ = base_outputs.out_hiddens.shape + ( + eagle_input_embeds, + eagle_input_hiddens, + eagle_attn_mask_0, + eagle_position_ids, + ) = self._prepare_eagle_inputs( input_ids, - eagle_input_hidden_states, attention_mask, position_ids, eagle_cache, + base_outputs, ) - with torch.no_grad(): - inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs) - - past_key_values.eagle_cache = eagle_cache - # ====Perform training-time-testing with 3 extra eagle forward passes==== + # ====Run eagle forward with extra training-time-test steps==== for ttt_step in range(self.num_ttt_steps): # TODO: (hg) during cp training, this mask is not used. Maybe turn it off then. - attention_mask = ( - attention_mask_0 + eagle_attention_mask = ( + eagle_attn_mask_0 if ttt_step == 0 else self._get_ttt_attention_mask(b, seq_length, ttt_step) ) with enable_cp_ttt_patch() if self.training else contextlib.nullcontext(): - _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward( - eagle_input_hidden_states, - inputs_embeds, - attention_mask, - position_ids, + _, eagle_input_hiddens, eagle_logits, eagle_cache = self._eagle_forward( + eagle_input_hiddens, + eagle_input_embeds, + eagle_attention_mask, + eagle_position_ids, eagle_cache, ) - eagle_input_hidden_states = torch.cat( - ( - torch.zeros( - (b, 1, h), - dtype=eagle_input_hidden_states.dtype, - device=eagle_input_hidden_states.device, - ), - eagle_input_hidden_states[:, :-1, :], - ), - dim=1, - ) + eagle_input_hiddens = eagle_input_hiddens.roll(1, 1) for i in range(self.eagle_config.parallel_draft_step): eagle_logit = eagle_logits[i] classification_loss, acc = self._eagle_loss( # base model predict +1 tok, while eagle predict +2 # so we shift base model outputs compared to eagle outputs - base_model_logits[:, 1 + i :], - eagle_logit[:, : -(1 + i)], # additionally, we mask the first n tok of eagle outputs at nth TTT step - torch.cat( - ( - torch.zeros( - b, ttt_step, dtype=loss_mask.dtype, device=loss_mask.device - ), - loss_mask[:, 1 + ttt_step :] - if i == 0 - else loss_mask[:, 1 + ttt_step : -i], - ), - dim=1, - ), + base_outputs.logits[:, 1 + i + ttt_step :], + eagle_logit[:, ttt_step : -(1 + i)], + loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i], ) + # Apply loss decay factor to focus on early steps classification_loss *= self.eagle_loss_decay_factor ** (ttt_step + i) eagle_loss = ( classification_loss if eagle_loss is None else eagle_loss + classification_loss @@ -965,24 +951,19 @@ def forward( train_accs[i].append(acc) if not self.training: break - # Finally, we merge base model loss and eagle loss, raise error if both are None - if base_model_loss is not None and eagle_loss is not None: - loss = base_model_loss + eagle_loss - elif base_model_loss is not None: - loss = base_model_loss - elif eagle_loss is not None: - loss = eagle_loss - else: + + # Merge base model loss and eagle loss + if base_outputs.loss is None and eagle_loss is None: loss = None - assert not self.training, ValueError( - "Both base_model_loss and eagle_loss are skipped. 
At least one loss must be computed." - ) + assert not self.training, "At least one loss must be computed for training." + else: + loss = (base_outputs.loss or 0) + (eagle_loss or 0) return ModelOutput( loss=loss, - logits=base_model_logits, + logits=base_outputs.logits, past_key_values=past_key_values, - hidden_states=base_model_hidden_states, + hidden_states=base_outputs.out_hiddens, train_acc=train_accs, ) @@ -994,10 +975,8 @@ def _eagle_loss( ): """Function for EAGLE loss computing.""" if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized" base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) - loss_mask = loss_mask[:, :, None] - loss_mask = loss_mask[:, : eagle_logits.shape[1]] + loss_mask = loss_mask[:, : eagle_logits.shape[1], None] classification_loss = nn.Softmax(dim=2)(base_model_logits) * nn.LogSoftmax(dim=2)( eagle_logits ) @@ -1047,34 +1026,28 @@ def pseudo_speculative_generate( # EAGLE-3 # Only the first iteration input_hidden_states are from aux_hidden_state layers # Gather _aux_hidden_states from all devices before concatenation - gathered_aux_hidden_states = self.pop_and_gather_aux_hiddens() - eagle_input_hidden_states = self.eagle_module.fc( - torch.cat(gathered_aux_hidden_states, dim=-1) - ) - + eagle_input_hidden_states = self.eagle_module.fc(self.pop_and_gather_aux_hiddens()) else: eagle_input_hidden_states = base_model_hidden_states draft_tokens = [] for step in range(steps): - # Get eagle inputs for the first eagle forward pass - _, eagle_attention_mask, eagle_position_ids = self._get_eagle_module_inputs( - input_ids, - eagle_input_hidden_states, - None, - None, + b, seq_length = eagle_ids.shape + eagle_attention_mask = self._prepare_decoder_attention_mask( None, + (b, seq_length), + 0, + eagle_input_hidden_states.device, + eagle_input_hidden_states.dtype, ) # Use SDPA attention during generation for both stability and performance - with temporary_set_config_value( - self.eagle_module.config, "_attn_implementation", "sdpa" - ): + with temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"): _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward( eagle_input_hidden_states, self._base_model_embeddings(eagle_ids), eagle_attention_mask, - eagle_position_ids, + None, ) # parallel logits are only used after the last step diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index d259a1fce..e067641ed 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -25,6 +25,7 @@ import torch import torch.distributed +import transformers from huggingface_hub import snapshot_download from torch import nn from torch.nn.attention import SDPBackend, sdpa_kernel @@ -42,6 +43,9 @@ def calibrate_frequent_vocab(tokenizer, text, target_vocab_size, output_file=None): """Given a calibration text, find the most common vocabs and return the mapping.""" conversations = tokenizer.apply_chat_template(text) + # Transformers5.x returns a BatchEncoding from apply_chat_template + if hasattr(conversations, "input_ids"): + conversations = conversations.input_ids counter = Counter(conversations) vocab = counter.most_common(target_vocab_size) mapping = torch.zeros(target_vocab_size, dtype=torch.int64) @@ -468,3 +472,16 @@ def enable_cp_ttt_patch(): yield finally: modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False + + +def load_vlm_or_llm_with_kwargs(model_name_or_path: str, **kwargs): 
+ """Load a VLM or LLM with kwargs. Returns the model and model config.""" + model_config = transformers.AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=True + ) + if "vl" in model_config.model_type.lower(): + model_cls = transformers.AutoModelForVision2Seq + else: + model_cls = transformers.AutoModelForCausalLM + + return model_config, model_cls.from_pretrained(model_name_or_path, **kwargs) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py new file mode 100644 index 000000000..e147ebf2c --- /dev/null +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -0,0 +1,303 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Processing large data to tokenize for pretraining.""" + +import copy +import itertools +import os + +import torch +import transformers +from datasets import load_dataset +from transformers.trainer_pt_utils import LabelSmoother + +from modelopt.torch.utils import print_rank_0 + +REMOVE_THINK_CHAT_TEMPLATE = ( + "{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}" +) + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +def _sharegpt_to_openai_messages(conversations: list[dict]): + """Optionally align sharedgpt format to openai format.""" + role_mapping = { + "user": "user", + "User": "user", + "human": "user", + "assistant": "assistant", + "Assistant": "assistant", + "gpt": "assistant", + "system": "system", + "System": "system", + } + messages = [] + for msg in conversations: + role = role_mapping[msg["role"]] + content = msg["content"] + messages.append({"role": role, "content": content}) + return messages + + +class ShardedDataset(torch.utils.data.Dataset): + """Subclass of torch.utils.data.Dataset to load data from HuggingFace dataset.""" + + def __init__( + self, + name: str, + subset: str | None = None, + data_files: str | None = None, + split: str = "train", + num_shards: int = 1, + shard_index: int = 0, + num_streaming_samples: int | None = None, + ): + """Initialize the ShardedDataset.""" + self.name = name + self.subset = subset + self.split = split + self.data_files = data_files + self.num_shards = num_shards + self.shard_index = shard_index + self.num_streaming_samples = num_streaming_samples + + self._load_dataset() + + def __len__(self): + if self.num_streaming_samples is not None: + return self.num_streaming_samples + else: + return len(self._raw_samples) + + def __getitem__(self, index): + index = index // self.num_shards + + if self.num_streaming_samples is not None: + while index >= len(self._raw_samples): + self._raw_samples.append(next(self._stream_iterator)) + + return self._raw_samples[index] + + def _load_dataset(self): + dataset = load_dataset( + self.name, + self.subset, + data_files=self.data_files, + split=self.split, + # num_proc=4, # TODO: Make this configurable + 
+            streaming=self.num_streaming_samples is not None,
+        )
+
+        shard = dataset.shard(num_shards=self.num_shards, index=self.shard_index)
+
+        if self.num_streaming_samples is not None:
+            self._raw_samples = []
+            self._stream_samples = shard
+            self._stream_iterator = itertools.cycle(self._stream_samples)
+        else:
+            self._raw_samples = shard
+
+
+class LanguageDataCollator:
+    """Data collator for language modeling tasks.
+
+    Accepts samples in OpenAI or ShareGPT formats and returns
+    tokenized outputs with padding and truncation, including
+    input_ids and attention_mask.
+    """
+
+    def __init__(
+        self,
+        tokenizer: transformers.PreTrainedTokenizerBase,
+        train_len: int = 4096,
+        chat_template: str | None = None,
+        add_generation_prompt: bool = False,
+        answer_only_loss: bool = False,
+        json_key: str = "text",
+        return_labels: bool = False,
+    ):
+        """Initialize the LanguageDataCollator."""
+        if not isinstance(tokenizer, transformers.PreTrainedTokenizerBase):
+            raise ValueError(
+                "The tokenizer must be a transformers.PreTrainedTokenizerBase but got {}".format(
+                    type(tokenizer)
+                )
+            )
+        self.tokenizer = tokenizer
+        self.train_len = train_len
+        self.add_generation_prompt = add_generation_prompt
+        self.answer_only_loss = answer_only_loss
+        self.json_key = json_key
+        self.return_labels = return_labels
+
+        if chat_template is not None:
+            self.tokenizer.chat_template = chat_template
+        else:
+            self._post_process_chat_template()
+
+        self._post_process_tokenizer()
+        if self.tokenizer.chat_template is None:
+            raise ValueError("No valid chat template!")
+
+    def _post_process_tokenizer(self):
+        if self.tokenizer.pad_token_id is None:
+            print_rank_0("The tokenizer has no pad_token_id, using eos_token_id instead.")
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+        if hasattr(self.tokenizer, "pad_token") and self.tokenizer.pad_token is None:
+            if self.tokenizer.eos_token == "<|eot_id|>":  # nosec
+                self.tokenizer.pad_token = "<|end_of_text|>"  # nosec
+            else:
+                raise ValueError("The tokenizer has no pad_token!")
+
+    def _post_process_chat_template(self):
+        # [WAR]: For DeepSeek-V3/R1 tokenizer, we modify the chat_template such that the
+        # <think> tokens are preserved for supervised learning.
+        self.tokenizer.chat_template = self.tokenizer.chat_template.replace(
+            REMOVE_THINK_CHAT_TEMPLATE, ""
+        )
+
+    def _process_chat_sample(self, examples: list):
+        tokenized_examples = self.tokenizer.apply_chat_template(
+            examples,
+            return_tensors="pt",
+            return_dict=True,
+            padding="max_length",
+            truncation=True,
+            max_length=self.train_len,
+            add_generation_prompt=self.add_generation_prompt,
+            return_assistant_tokens_mask=self.answer_only_loss,
+        )
+        if self.return_labels:
+            input_ids = tokenized_examples["input_ids"]
+            labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID)
+            labels[..., :-1] = input_ids[..., 1:]
+            tokenized_examples["labels"] = labels
+        return tokenized_examples
+
+    def _process_text_sample(self, examples: list):
+        tokenized_examples = self.tokenizer(
+            examples,
+            return_tensors="pt",
+            padding="max_length",
+            truncation=True,
+            max_length=self.train_len,
+        )
+        return tokenized_examples
+
+    def __call__(self, examples):
+        """Call the LanguageDataCollator."""
+        batch = []
+        has_plain_text = False
+
+        for example in examples:
+            if not isinstance(example, dict):
+                raise ValueError("The sample must be a Dict but got {}".format(type(example)))
+            text = example.get(self.json_key, None)
+            if isinstance(text, str):
+                # Plain-text samples are tokenized directly instead of via the chat template;
+                # batches are assumed to be homogeneous (all plain text or all conversations).
+                batch.append(text)
+                has_plain_text = True
+            else:
+                messages = example.get("messages", None)
+                if messages is None:
+                    conversations = example.get("conversations", None)
+                    if conversations is None:
+                        raise ValueError(
+                            "The sample must be in either OpenAI messages format or ShareGPT conversations format."
+                        )
+                    else:
+                        messages = _sharegpt_to_openai_messages(conversations)
+                batch.append(messages)
+
+        if has_plain_text:
+            return self._process_text_sample(batch)
+        return self._process_chat_sample(batch)
+
+
+class VisionLanguageDataCollator(LanguageDataCollator):
+    """Subclass of LanguageDataCollator that collates vision-language data."""
+
+    def __init__(
+        self,
+        processor: str,
+        train_len: int = 8192,
+        chat_template: str | None = None,
+        add_generation_prompt: bool = False,
+        answer_only_loss: bool = False,
+        local_image_path: str = "",
+        return_labels: bool = False,
+    ):
+        """Initialize the VisionLanguageDataCollator."""
+        self.processor = transformers.AutoProcessor.from_pretrained(processor)
+        self.chat_template = chat_template
+        self.local_image_path = local_image_path
+
+        super().__init__(
+            tokenizer=self.processor.tokenizer,
+            train_len=train_len,
+            chat_template=chat_template,
+            add_generation_prompt=add_generation_prompt,
+            answer_only_loss=answer_only_loss,
+            return_labels=return_labels,
+        )
+
+    def _process_multimodal_sample(self, examples):
+        tokenized_messages = self.processor.apply_chat_template(
+            examples,
+            tokenize=True,
+            return_tensors="pt",
+            return_dict=True,
+            padding="max_length",
+            truncation=True,
+            max_length=self.train_len,
+            add_generation_prompt=self.add_generation_prompt,
+            return_assistant_tokens_mask=self.answer_only_loss,
+        )
+
+        return tokenized_messages
+
+    def __call__(self, examples):
+        """Call the VisionLanguageDataCollator."""
+        batch = []
+
+        for example in examples:
+            messages = example.get("messages", None)
+            if messages is None:
+                conversations = example.get("conversations", None)
+                if conversations is None:
+                    raise ValueError(
+                        "The sample must be in either OpenAI messages format or ShareGPT conversations format."
+                    )
+                else:
+                    messages = _sharegpt_to_openai_messages(conversations)
+
+            copy_messages = copy.deepcopy(messages)
+
+            for msg in copy_messages:
+                if isinstance(msg["content"], str):
+                    msg["content"] = [{"type": "text", "text": msg["content"]}]
+
+                for ctn in msg["content"]:
+                    if ctn["type"] == "image" and "image" in ctn:
+                        ctn["image"] = os.path.abspath(
+                            os.path.join(self.local_image_path, ctn["image"])
+                        )
+                    # If any value in ctn is None, delete that key.
+                    # The HF dataloader adds Nones to align keys, which causes errors in the processor.
+                    keys_to_delete = [k for k, v in ctn.items() if v is None]
+                    for k in keys_to_delete:
+                        del ctn[k]
+
+            batch.append(copy_messages)
+
+        return self._process_multimodal_sample(batch)
diff --git a/tests/examples/speculative_decoding/conftest.py b/tests/examples/speculative_decoding/conftest.py
index bc75b8783..80417f404 100644
--- a/tests/examples/speculative_decoding/conftest.py
+++ b/tests/examples/speculative_decoding/conftest.py
@@ -21,18 +21,20 @@
 @pytest.fixture(scope="session", autouse=True)
 def tiny_daring_anteater_path(tmp_path_factory):
-    dataset_path = MODELOPT_ROOT / "examples/speculative_decoding/Daring-Anteater"
+    dataset_path = (
+        MODELOPT_ROOT / "examples/speculative_decoding/input_conversations/daring-anteater.jsonl"
+    )
     if not os.path.exists(dataset_path):
         try:
             run_example_command(
-                ["git", "clone", "https://huggingface.co/datasets/nvidia/Daring-Anteater"],
+                ["python", "prepare_input_conversations/add_daring_anteater.py"],
                 "speculative_decoding",
             )
         except Exception as e:
             # Ignore rate-limiting errors
-            pytest.skip(f"Failed to clone Daring-Anteater dataset: {e}")
+            pytest.skip(f"Failed to prepare dataset: {e}")
 
     output_path = tmp_path_factory.mktemp("daring_anteater") / "train.jsonl"
-    with open(dataset_path / "train.jsonl") as src, open(output_path, "w") as dst:
+    with open(dataset_path) as src, open(output_path, "w") as dst:
         for i, line in enumerate(src):
             if i >= 128:
                 break
diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py
index 3775b8a4c..3cbbc69c8 100644
--- a/tests/examples/speculative_decoding/test_eagle.py
+++ b/tests/examples/speculative_decoding/test_eagle.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import json
+import os
 
 import pytest
 import safetensors.torch
@@ -30,6 +31,35 @@ def eagle_output_dir(tmp_path_factory):
     return tmp_path_factory.mktemp("eagle_output_dir")
 
 
+@pytest.fixture(scope="module")
+def draft_vocab_cache_dir(tmp_path_factory):
+    """Draft vocab cache directory shared in this module."""
+    return tmp_path_factory.mktemp("draft_vocab_cache")
+
+
+def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path, draft_vocab_cache_dir):
+    """Test calibration of draft vocabulary."""
+    run_example_command(
+        [
+            "python",
+            "./scripts/calibrate_draft_vocab.py",
+            "--model",
+            tiny_llama_path,
+            "--data",
+            tiny_daring_anteater_path,
+            "--draft_vocab_size",
+            "100",
+            "--save_dir",
+            draft_vocab_cache_dir,
+        ],
+        "speculative_decoding",
+    )
+
+    model_name = os.path.basename(os.path.normpath(tiny_llama_path))
+    d2t = torch.load(os.path.join(draft_vocab_cache_dir, model_name, "d2t.pt"))
+    assert d2t.shape[0] == 100, f"Expected draft vocab size 100, got {d2t.shape[0]}"
+
+
 # fmt: off
 @pytest.mark.parametrize("cp_size", [1, 2])
 def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagle_output_dir, cp_size):
@@ -77,8 +107,8 @@ def test_ar_validate(eagle_output_dir):
         [
             "python", "./scripts/ar_validate.py",
             "--model_path", eagle_output_dir / "eagle-tinyllama-cp1",
-            "--osl", "20",
-            "--num_samples", "10",
+            "--osl", "10",
+            "--num_samples", "5",
             "--steps", "3"
         ],
         "speculative_decoding",
@@ -112,17 +142,3 @@ def test_convert_to_vllm_ckpt(tiny_llama_path, eagle_output_dir):
         ],
         "speculative_decoding",
     )
-
-@pytest.mark.skip(reason="Needs dataset conversion to role-content format; consolidate data loading first.")
-def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path,tmp_path):
-    """Test calibration of draft vocabulary."""
-    run_example_command(
-        [
-            "python", "./scripts/calibrate_draft_vocab.py",
-            "--model", tiny_llama_path,
-            "--data", tiny_daring_anteater_path,
-            "--draft_vocab_size", "100",
-            "--save_dir", tmp_path / "draft_vocab_cache",
-        ],
-        "speculative_decoding",
-    )
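
Reviewer note (illustrative, not part of this diff): the new `transformers_dataset` plugin is easiest to evaluate with a small end-to-end sketch. The snippet below assumes the module is importable as `modelopt.torch.utils.plugins.transformers_dataset` from an installed ModelOpt, that a local JSONL exists at the placeholder path shown, and that each row carries OpenAI-style `messages` or ShareGPT-style `conversations` with `role`/`content` keys.

```python
# Minimal sketch: wiring ShardedDataset and LanguageDataCollator into a plain
# PyTorch DataLoader. Paths and dataset field names below are assumptions.
import transformers
from torch.utils.data import DataLoader

from modelopt.torch.utils.plugins.transformers_dataset import (
    LanguageDataCollator,
    ShardedDataset,
)

tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Load a local JSONL of conversations through the HF "json" loader. num_shards /
# shard_index can be set per data-parallel rank so each rank sees a disjoint shard.
dataset = ShardedDataset(
    name="json",
    data_files="input_conversations/daring-anteater.jsonl",  # assumed path
    split="train",
)

# The collator applies the tokenizer's chat template to OpenAI "messages" or
# ShareGPT "conversations", padding/truncating every sample to train_len tokens.
collator = LanguageDataCollator(tokenizer, train_len=2048, return_labels=True)

loader = DataLoader(dataset, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # e.g. torch.Size([2, 2048]); "labels" is shifted input_ids
```

For very large corpora, passing `num_streaming_samples` switches `ShardedDataset` to a streaming HF dataset that is consumed lazily via an internal iterator instead of being materialized up front.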