Add Dynamic Memory Sparsification (DMS) training and inference implementation #877
Conversation
Signed-off-by: Konrad Staniszewski <kstaniszewsk@nvidia.com>
Important: Review skipped. Auto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI.
📝 Walkthrough

Introduces a comprehensive Dynamic Memory Sparsification (DMS) system for Qwen3 models. Adds documentation, core DMS modules (attention, caching, training logic), configuration files, a data pipeline, student-teacher distillation training infrastructure, Qwen3 model adaptations, training/evaluation scripts, and a test suite.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Student as Student Model
    participant Dist as Distillation<br/>Loss
    participant Teacher as Teacher Model
    participant LM as LM Loss
    participant DMS as DMS Loss

    Student->>+Student: Forward Pass<br/>(input_ids)
    Student->>Student: Apply DMS<br/>Gating & Attention
    Student-->>-Student: Logits + State
    Student->>+Dist: Student Logits
    Teacher->>+Teacher: Forward Pass<br/>(input_ids, eval mode)
    Teacher-->>-Teacher: Teacher Logits
    Teacher->>Dist: Teacher Logits
    Dist->>Dist: KL Divergence<br/>Loss
    Dist-->>-Student: Distillation Loss
    Student->>+LM: Student Logits +<br/>Labels
    LM->>LM: Cross Entropy
    LM-->>-Student: LM Loss
    Student->>+DMS: DMS State<br/>(frac_closed, decisions)
    DMS->>DMS: Compute DMS<br/>Regularization
    DMS-->>-Student: DMS Loss
    Student->>Student: Combine Losses<br/>(α×Dist + β×LM + γ×DMS)
    Student->>Student: Backward Pass &<br/>Optimize
```
```mermaid
sequenceDiagram
    participant Input as Input Sequence
    participant Prefill as Chunked Prefill
    participant ChunkAtt as Chunk Attention<br/>Processing
    participant Cache as DMS Cache<br/>(Contiguous +<br/>Paged)
    participant Inference as Token Generation<br/>Loop

    Input->>+Prefill: Long Sequence
    Prefill->>Prefill: Split into<br/>Chunks
    loop For Each Chunk
        Prefill->>+ChunkAtt: Chunk Tokens +<br/>Position Info
        ChunkAtt->>ChunkAtt: Compute DMS<br/>Decisions & Gating
        ChunkAtt->>ChunkAtt: FlexAttention with<br/>Eviction Masking
        ChunkAtt-->>-Prefill: Chunk Output
        Prefill->>+Cache: Update Recent &<br/>Paged KV
        Cache->>Cache: Evict old tokens<br/>per decisions
        Cache-->>-Prefill: Cache State
    end
    Prefill-->>-Inference: Final Hidden State +<br/>Populated Cache
    Inference->>+Inference: Generate Loop
    loop For Each Token
        Inference->>+ChunkAtt: Query Token +<br/>Cached KV
        ChunkAtt->>ChunkAtt: FlashAttention with<br/>KV Cache & Paging
        ChunkAtt-->>-Inference: Attention Output
        Inference->>Inference: Compute Logits &<br/>Sample Token
        Inference->>+Cache: Append New Token<br/>to Cache
        Cache->>Cache: Update Cache with<br/>Eviction Check
        Cache-->>-Inference: Updated Cache
    end
    Inference-->>-Inference: Generated Sequence
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 14
🤖 Fix all issues with AI agents
In `@experimental/dms/dms/attention.py`:
- Around line 146-157: In score_mod, the sliding-window check uses "<=" which
allows window_size+1 tokens during training; change the comparison in
within_sliding_window from "q_idx - k_idx <= window_size" to "q_idx - k_idx <
window_size" so training matches eval/test behavior; update the expression in
the score_mod function (using variables window_size, within_sliding_window,
causal, dms_mask_values) to use the strict "<" boundary.
In `@experimental/dms/dms/core.py`:
- Around line 287-298: The function dms_chunked_prefill defines a parameter
named dms_chunked_prefill which shadows the function name; rename the parameter
to dms_chunk_size and update every use inside dms_chunked_prefill (references to
the parameter) to dms_chunk_size, and then update all call sites that pass the
old parameter name to use dms_chunk_size instead (ensure signatures and kwargs
are consistent). Also check any type hints/annotations and docstrings inside
dms_chunked_prefill referencing dms_chunked_prefill and update them to
dms_chunk_size to keep names consistent.
In `@experimental/dms/dms/training/data.py`:
- Around line 445-459: The current __getitem__ mutates self.dataset_iterators
(and uses sample_mapping) which breaks when DataLoader spawns workers; instead
make per-index selection deterministic by removing mutation and computing or
looking up the per-dataset sample index from global index: precompute a static
mapping (e.g. self.sample_index_mapping) in __init__ that for each global index
yields the corresponding (ds_idx, ds_sample_idx), then change __getitem__ to use
ds_idx = self.sample_mapping[index] and ds_sample_idx =
self.sample_index_mapping[index] (or compute ds_sample_idx = index %
len(self.datasets[ds_idx]) if that matches desired behavior) and do not modify
dataset_iterators; update or remove dataset_iterators usage in __getitem__
accordingly.
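As a rough sketch of the deterministic lookup described in the item above (class and attribute names such as `WeightedConcatDataset` and `sample_index_mapping` are illustrative, not the repository's actual identifiers):

```python
import numpy as np
from torch.utils.data import Dataset


class WeightedConcatDataset(Dataset):
    """Illustrative only: deterministic per-index routing instead of mutable iterators."""

    def __init__(self, datasets, normalized_weights, train_samples):
        self.datasets = datasets
        # Which source dataset each global index maps to.
        self.sample_mapping = np.concatenate(
            [np.full(int(w * train_samples), i) for i, w in enumerate(normalized_weights)]
        )
        # Deterministic per-dataset sample index, wrapped to the dataset length.
        counters = {i: 0 for i in range(len(datasets))}
        self.sample_index_mapping = np.empty_like(self.sample_mapping)
        for gidx, ds_idx in enumerate(self.sample_mapping):
            self.sample_index_mapping[gidx] = counters[int(ds_idx)] % len(datasets[int(ds_idx)])
            counters[int(ds_idx)] += 1

    def __len__(self):
        return len(self.sample_mapping)

    def __getitem__(self, index):
        # No mutation: safe under multi-worker DataLoader.
        ds_idx = int(self.sample_mapping[index])
        ds_sample_idx = int(self.sample_index_mapping[index])
        return self.datasets[ds_idx][ds_sample_idx]
```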
In `@experimental/dms/dms/training/engine.py`:
- Around line 222-246: dms_schedule can divide by zero when max_steps is 0; add
a guard in the function (dms_schedule) before computing progress: determine
max_steps as currently done, then if max_steps <= 0 set progress = 1.0 (or
progress = 0.0 if you prefer saturating to start) otherwise compute progress =
min(step / max_steps, 1.0); keep the rest of the logic (cr, frac, target)
unchanged. This prevents ZeroDivisionError when dms_final_step and
training_args.max_steps are zero.
- Around line 594-595: The constructor/function currently uses mutable defaults
forward_fn_kwargs_student: dict[str, Any] = {} and forward_fn_kwargs_teacher:
dict[str, Any] = {}, causing shared-state bugs; change both default values to
None (forward_fn_kwargs_student: Optional[dict[str, Any]] = None,
forward_fn_kwargs_teacher: Optional[dict[str, Any]] = None) and inside the
constructor/body of the class/function (where these params are assigned to
self.forward_fn_kwargs_student / self.forward_fn_kwargs_teacher) initialize them
with an explicit None check (e.g., if forward_fn_kwargs_student is None:
self.forward_fn_kwargs_student = {} else: self.forward_fn_kwargs_student =
forward_fn_kwargs_student) and similarly for forward_fn_kwargs_teacher so each
instance gets its own dict.
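A generic sketch of the `None`-default pattern described above; the class name and attribute names mirror the review text, but the body is illustrative:

```python
from typing import Any


class DistillationEngine:
    """Illustrative only: avoid shared mutable defaults across instances."""

    def __init__(
        self,
        forward_fn_kwargs_student: dict[str, Any] | None = None,
        forward_fn_kwargs_teacher: dict[str, Any] | None = None,
    ):
        # Each instance gets its own dict; a `{}` default would be shared by all instances.
        self.forward_fn_kwargs_student = (
            {} if forward_fn_kwargs_student is None else forward_fn_kwargs_student
        )
        self.forward_fn_kwargs_teacher = (
            {} if forward_fn_kwargs_teacher is None else forward_fn_kwargs_teacher
        )
```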
In `@experimental/dms/models/qwen3/configuration_qwen3_dms.py`:
- Around line 96-97: The constructor currently calls setup_compile_limit_for_dms
when dms_compile_limit is provided, which mutates the global
torch._dynamo.config.cache_size_limit during config instantiation (including
from_pretrained). Remove the side-effect from __init__ by deferring
setup_compile_limit_for_dms to model initialization (e.g., a model factory or an
explicit initialize method) or add clear documentation that __init__ mutates
torch._dynamo.config; modify the code path that currently invokes
setup_compile_limit_for_dms in the config class (the __init__ handling of
dms_compile_limit) and instead call setup_compile_limit_for_dms from the model
loading/initialization routine (the code that constructs or loads the model
after from_pretrained) so deserialization no longer performs global mutation
unless initialization is explicitly performed.
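One possible shape for deferring the side effect, sketched under the assumption that a dedicated loader function exists (the `load_dms_model` helper here is hypothetical; only `setup_compile_limit_for_dms` and `dms_compile_limit` come from the review text):

```python
import torch


def setup_compile_limit_for_dms(compile_limit: int) -> None:
    # Keep the global mutation in one explicit place.
    torch._dynamo.config.cache_size_limit = compile_limit


def load_dms_model(model_cls, config_cls, model_name_or_path: str, **kwargs):
    """Hypothetical loader: apply the compile limit only when a model is actually built."""
    config = config_cls.from_pretrained(model_name_or_path)  # no global side effects here
    if getattr(config, "dms_compile_limit", None):
        setup_compile_limit_for_dms(config.dms_compile_limit)
    return model_cls.from_pretrained(model_name_or_path, config=config, **kwargs)
```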
In `@experimental/dms/models/qwen3/extract.py`:
- Around line 95-96: The code currently calls torch.load(model_path,
weights_only=True) in the combined_model.load_state_dict path which will fail if
model_path points to a .safetensors file; update the loader to check the
extension of model_path and use safetensors.torch.load_file(model_path) for
.safetensors and torch.load otherwise (preserving the weights_only behavior
where applicable), then pass the resulting state dict to
combined_model.load_state_dict; refer to logger.info,
combined_model.load_state_dict, torch.load, safetensors.torch.load_file, and
model_path when making the change.
- Around line 62-64: The code currently hardcodes checkpoint_dir /
"pytorch_model.bin" (model_path) which will fail for safetensors, sharded
checkpoints, or other naming; update the logic that builds model_path (using
cli_args.checkpoint and checkpoint_dir) to detect available checkpoint files
instead of a fixed filename: check for model.safetensors, any
pytorch_model-*-of-*.bin shards, and model.*.bin in checkpoint_dir (or use a
HF/transformers utility to resolve the correct weights file), pick and validate
the first matching path, and keep save_path as before; ensure the detection
happens where model_path is assigned so functions referencing model_path use the
resolved file.
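A hedged sketch of resolving the weights file instead of hardcoding `pytorch_model.bin`; the helper name and candidate patterns are assumptions, and fully sharded checkpoints would additionally need index-file handling that is omitted here:

```python
from pathlib import Path

import torch
from safetensors.torch import load_file


def resolve_and_load_state_dict(checkpoint_dir: Path) -> dict:
    """Pick the first recognizable weights file in a checkpoint directory and load it."""
    candidates = (
        sorted(checkpoint_dir.glob("model.safetensors"))
        + sorted(checkpoint_dir.glob("pytorch_model.bin"))
        + sorted(checkpoint_dir.glob("pytorch_model-*-of-*.bin"))  # first shard only; merging omitted
    )
    if not candidates:
        raise FileNotFoundError(f"No weights file found in {checkpoint_dir}")
    model_path = candidates[0]
    if model_path.suffix == ".safetensors":
        return load_file(str(model_path))
    return torch.load(model_path, weights_only=True)
```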
In `@experimental/dms/models/qwen3/train.py`:
- Around line 216-253: The code currently calls student_model.to(torch.bfloat16)
which mutates combined_model.student_model in-place; change
extract_student_model to avoid mutating the in-memory model by creating a deep
copy, e.g. use copy.deepcopy to make student_copy =
copy.deepcopy(student_model), call student_copy.to(torch.bfloat16) and
save_pretrained on that copy (and leave combined_model.student_model untouched);
add the necessary import for copy or, if you prefer, explicitly document that
extract_student_model is terminal and must not be called before any further
evaluation.
- Around line 82-97: resolve_checkpoint returns None when no checkpoints are
found but leaves cfg["hf_trainer"]["resume_from_checkpoint"] set to "auto",
which later gets passed to TrainingArguments(**hf_trainer_cfg); change the code
that calls resolve_checkpoint (the place handling checkpoint_path/line ~297) so
that when resolve_checkpoint returns a falsy value you explicitly set
cfg["hf_trainer"]["resume_from_checkpoint"] = None (or delete that key) before
constructing TrainingArguments(**hf_trainer_cfg) — ensure the symbol names
involved are resolve_checkpoint, cfg["hf_trainer"], resume_from_checkpoint and
TrainingArguments(**hf_trainer_cfg).
In `@experimental/dms/pyproject.toml`:
- Line 8: The requires-python spec in pyproject.toml currently uses
">=3.10,<=3.13" which excludes patch releases (e.g., 3.13.1); update the
requires-python constraint to ">=3.10,<3.14" so patch releases within 3.13 are
allowed (locate and edit the requires-python entry in
experimental/dms/pyproject.toml).
In `@experimental/dms/scripts/evaluate.sh`:
- Line 3: Update the SPDX copyright header in
experimental/dms/scripts/evaluate.sh (the SPDX-FileCopyrightText line) to
reflect the correct year or range for a new 2026 file—e.g., change "2024" to
"2026" or "2024-2026"—so the header is current and follows project convention.
- Around line 29-32: Add strict bash failure flags to the top of the script so
errors (including pipeline failures) cause immediate exit: replace or augment
the existing "set -x" with "set -euo pipefail" (and keep -x if debugging is
desired) in experimental/dms/scripts/evaluate.sh so that failures from commands
like accelerate launch do not return a zero exit code; ensure MODEL_PATH
handling (variable MODEL_PATH and the usage check) remains unchanged.
In `@experimental/dms/scripts/train_small.sh`:
- Line 28: The exported env var uses the wrong name: replace PYTORCH_ALLOC_CONF
with PYTORCH_CUDA_ALLOC_CONF in the export line so the expandable_segments:True
setting actually takes effect on CUDA; update the export in train_small.sh (the
line exporting PYTORCH_ALLOC_CONF) to export
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to match the usage in train.sh
and prevent silent non-application of the memory optimization.
🧹 Nitpick comments (26)
experimental/dms/pyproject.toml (2)
20-26: Consider narrowing the `DeprecationWarning` filter. A blanket `ignore::DeprecationWarning` can mask actionable deprecation signals from your own code or key dependencies. Consider scoping it to specific noisy modules (e.g., `ignore::DeprecationWarning:transformers`). Not urgent for experimental code, but worth keeping in mind.
10-15: Pin `lm_eval` to a specific version for consistency. The other dependencies are pinned to exact versions for reproducibility (`transformers==4.57.3`, `datasets==4.4.2`, `accelerate==1.4.0`), but `lm_eval[ruler]` is unpinned. Since multiple versions of `lm_eval` exist (0.4.2–0.4.10), a breaking update could silently break evaluation. Pin it to match the pattern of the other dependencies.

experimental/dms/dms/logging.py (1)
22-35: Set `logger.propagate = False` to prevent duplicate log messages. In a training environment where HuggingFace Trainer or `accelerate` typically configure the root logger, messages from DMS loggers will propagate upward and be printed twice (once by this handler, once by the root). Setting `propagate = False` ensures only the DMS formatter is used.

Proposed fix:

```diff
     if not logger.handlers:
         logger.setLevel(logging.INFO)
         handler = logging.StreamHandler(sys.stderr)
         handler.setFormatter(
             logging.Formatter("[%(asctime)s] DMS %(name)s [%(levelname)s]: %(message)s")
         )
         logger.addHandler(handler)
+        logger.propagate = False
     return logger
```

experimental/dms/tests/utils.py (1)
25-27: `sys.path` manipulation is fragile; consider relying on the editable install instead. Since the README instructs users to run `pip install -e .`, the DMS package should already be importable without path hacking. If this is kept as a fallback, consider guarding against duplicate appends and using `sys.path.insert(0, ...)` to ensure the local version takes priority over a stale installed copy.

experimental/dms/dms/training/data.py (3)
429-440: Weighted sample counts may silently drop samples due to `int()` truncation. `int(nw * train_samples)` truncates rather than rounds, so the total number of entries in `sample_mapping` can be less than `train_samples`. For three equally-weighted datasets with `train_samples=4000`, each gets `int(0.333… × 4000) = 1333`, totaling 3999 instead of 4000. Consider distributing the remainder to avoid the shortfall:

♻️ Proposed fix:

```diff
-        self.sample_mapping = []
-        for i, nw in enumerate(self.normalized_weights):
-            nw = nw.item()
-            self.sample_mapping.append(np.full(int(nw * train_samples), i))
+        # Distribute samples proportionally, then assign remainder to largest-weight datasets
+        raw_counts = self.normalized_weights * train_samples
+        counts = raw_counts.astype(int)
+        remainder = train_samples - counts.sum()
+        # Assign remaining samples by largest fractional part
+        fractions = raw_counts - counts
+        for idx in np.argsort(-fractions)[:remainder]:
+            counts[idx] += 1
+
+        self.sample_mapping = []
+        for i, count in enumerate(counts):
+            self.sample_mapping.append(np.full(count, i))
```
478-497: `extract_fn` always picks the last correct solution rather than the first. The loop at lines 482–484 iterates through all `(generation, correctness)` pairs without breaking, so `solution` ends up as the last verified-correct answer. If the intent is to use any correct solution, breaking on the first match is more efficient and deterministic.

♻️ Proposed fix: break on first match:

```diff
     for gen, correctness in zip(ds_elem["generations"], ds_elem["correctness_math_verify"]):
         if correctness:
             solution = gen
+            break
```
286-303: Remove `encoded_prompt_untrimmed` or document its use. Line 289 stores the full token list in `encoded_prompt_untrimmed`, which persists in the cached dataset. However, this field is never referenced anywhere else in the codebase. If nothing downstream consumes it, this wastes disk space and memory—especially for large datasets with long sequences. Either remove this field or add a comment explaining why it's retained.

experimental/dms/scripts/train_small.sh (1)
39-41: No separate dataset preparation step, unlike `train.sh`. `train.sh` runs `--prepare-dataset-only` before training to avoid potential issues with dataset processing during distributed training. This script skips that step. For a single-GPU debug script this is unlikely to cause problems, but for consistency you may want to add it — especially since the first run will block on dataset processing before training begins (no progress feedback).

experimental/dms/tests/test_paged_cache.py (1)
255-403: Consider guarding against a potential edge case with the `dms_window_size` range. Lines 267–269: when `block_size` is at its maximum (15, from `randint(1, 16)`), the `dms_window_size` range becomes `randint(16, 17)`, always yielding 16. This works but means `dms_window_size = block_size + 1` is the only value tested for large `block_size`. Not a bug, but you may want to widen `max_val` or decouple limits to get broader coverage of the `dms_window_size > block_size` invariant.

experimental/dms/example_inference.ipynb (1)
57-81: `clean_cache()` placement — model memory not freed after `example_prefill_generate`. `clean_cache()` is called before `example_prefill_generate`, but the model and tensors created inside the function are not explicitly deleted afterward. Since the notebook's next cell also starts with `clean_cache()`, this works in practice, but calling `clean_cache()` after each example (or using a context manager / `del model`) would be more robust for users running cells out of order.

experimental/dms/models/qwen3/configuration_qwen3_dms.py (1)
86-97: Use `ValueError` instead of `assert` for input validation. These assertions validate user-provided configuration values but can be silently disabled with `python -O`. For a configuration class that guards runtime invariants, raising `ValueError` is more robust.

Proposed fix:

```diff
-        assert self.dms_paged_attention_block_size > 0, (
-            f"dms_paged_attention_block_size: {self.dms_paged_attention_block_size} is not greater than 0"
-        )
-        assert self.dms_window_size > self.dms_paged_attention_block_size, (
-            f"dms_window_size: {self.dms_window_size} "
-            f"is not greater than dms_paged_attention_block_size: {self.dms_paged_attention_block_size}"
-        )
-        assert self.dms_alpha_per in ["head", "layer"], (
-            f"dms_alpha_per: {self.dms_alpha_per} is not supported"
-        )
+        if self.dms_paged_attention_block_size <= 0:
+            raise ValueError(
+                f"dms_paged_attention_block_size must be > 0, got {self.dms_paged_attention_block_size}"
+            )
+        if self.dms_window_size <= self.dms_paged_attention_block_size:
+            raise ValueError(
+                f"dms_window_size ({self.dms_window_size}) must be > "
+                f"dms_paged_attention_block_size ({self.dms_paged_attention_block_size})"
+            )
+        if self.dms_alpha_per not in ("head", "layer"):
+            raise ValueError(f"dms_alpha_per must be 'head' or 'layer', got '{self.dms_alpha_per}'")
```

experimental/dms/tests/test_prefill_and_generate.py (1)
198-219: Inconsistent `.cuda()` usage for `torch.randint` in parameter generation. Lines 202–208 call `.cuda().item()` on random integers, while line 201 uses `.item()` directly. Since `torch.randint` with `.item()` returns a Python int regardless of device, the `.cuda()` calls are unnecessary overhead.

Proposed cleanup:

```diff
     torch.manual_seed(seed)
     batch = torch.randint(1, 5, (1,)).item()
-    heads_kv = torch.randint(1, 5, (1,)).cuda().item()
-    gqa_factor = torch.randint(1, 4, (1,)).cuda().item()
-    seq_len = torch.randint(8, 1024, (1,)).cuda().item()
+    heads_kv = torch.randint(1, 5, (1,)).item()
+    gqa_factor = torch.randint(1, 4, (1,)).item()
+    seq_len = torch.randint(8, 1024, (1,)).item()
     head_dim = 3
-    chunk_size = torch.randint(1, 128, (1,)).cuda().item()
-    dms_block_size = torch.randint(2, 32, (1,)).cuda().item()
-    dms_window_size = torch.randint(dms_block_size + 1, 128, (1,)).cuda().item()
+    chunk_size = torch.randint(1, 128, (1,)).item()
+    dms_block_size = torch.randint(2, 32, (1,)).item()
+    dms_window_size = torch.randint(dms_block_size + 1, 128, (1,)).item()
```

experimental/dms/models/qwen3/train.py (2)
105-123: `_parse_data_blend_elements` — no error handling for malformed blend strings. Line 118 `entry.split(":")` will produce unexpected results if an entry doesn't contain exactly one `":"`. Consider using `split(":", 1)` and validating the result to provide a clearer error message than an index error.

61-61: Mutable default argument `{}` on `train_attn_kwargs`. Line 61 in `dms_attention` and line 114 in `dms_attn_train_mode` (in `attention.py`) use `dict` literals as default arguments. While these specific callsites don't mutate the dict, the pattern is a well-known Python pitfall. Consider using `None` with an `or {}` fallback.

Also applies to: 114-114
experimental/dms/dms/attention.py (1)
32-42: Guarded imports default to `None`, but are used as default arguments on lines 190–191. If `flash_attn` or `dms.attention_prefill` fails to import, `flash_attn_with_kvcache` and `dms_run_prefill_flex` become `None`. These are then used as default parameter values for `dms_attn_eval_mode`, leading to a confusing `TypeError: 'NoneType' object is not callable` at runtime instead of a clear import error. Consider raising a descriptive error when the defaults are `None` and the function is called.

Proposed defensive check:

```diff
 def dms_attn_eval_mode(
     ...
     flash_attn_fn: Callable = flash_attn_with_kvcache,
     prefill_attn_fn: Callable = dms_run_prefill_flex,
     prefill_attn_fn_kwargs: dict = {},
 ):
     """Perform DMS attention in evaluation mode using flash attention or flex prefill."""
+    if flash_attn_fn is None:
+        raise ImportError("flash_attn_with_kvcache is required for eval mode but failed to import")
+    if prefill_attn_fn is None:
+        raise ImportError("dms_run_prefill_flex is required for prefill mode but failed to import")
     assert decisions.dtype in (torch.int32, torch.long), (
```

experimental/dms/models/qwen3/modeling_qwen3_dms.py (3)
96-158: `position_ids` and `use_cache` silently absorbed by `**kwargs`. `Qwen3DecoderLayerDMS.forward` passes `position_ids` and `use_cache` to `self.self_attn(...)`, but `Qwen3AttentionDMS.forward` doesn't declare these parameters — they silently end up in `**kwargs`. This works but makes the interface misleading. Consider either accepting them explicitly (and ignoring) or not passing them.

406-415: Silently replacing non-DMSCache `past_key_values` may mask integration bugs. When `past_key_values` is not `None` and not a `DMSCache`, the code logs a warning and replaces it with a fresh empty cache (Line 415), discarding whatever state was passed. This silent replacement during generation could hide bugs in callers that pass an incompatible cache type. Consider raising an error instead of silently replacing, or at least making this a `logger.error`.

337-350: Redundant slice `self.layers[: self.config.num_hidden_layers]`. `self.layers` is constructed at Line 259–263 with exactly `config.num_hidden_layers` elements, so slicing to `[: self.config.num_hidden_layers]` is a no-op.

experimental/dms/dms/core.py (1)
119-121: `@torch.compile()` on a function receiving `nn.Module` parameters may cause excessive recompilations. `prepare_attention_input` receives multiple `nn.Module` instances (projections, norms) as arguments. Each unique module instance (i.e., from each layer) will likely trigger a separate compile guard, leading to one compilation per layer. This is presumably the reason for `setup_compile_limit_for_dms(compile_limit=72)`, but it's worth documenting this relationship explicitly.

experimental/dms/dms/attention_prefill.py (3)
393-406: LSE-based attention merging could overflow for very long sequences. `denom_local = torch.exp(softmax_lse_local.float())` and `denom_paged = torch.exp(softmax_lse_paged.float())` compute raw sums of exponentials. For extremely long sequences with high attention scores, `exp(lse)` can overflow float32. The numerically stable approach is to subtract the max LSE before exponentiating:

```python
max_lse = torch.maximum(softmax_lse_local, softmax_lse_paged)
denom_local = torch.exp((softmax_lse_local - max_lse).float())
denom_paged = torch.exp((softmax_lse_paged - max_lse).float())
```

In practice, with the `.float()` cast and typical score magnitudes this is unlikely to trigger, but it's worth considering for robustness at very long context lengths.
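For reference, a generic sketch of numerically stable LSE-based merging of two partial attention results; this shows the standard technique rather than the repository's exact code, and it assumes the LSE tensors broadcast against the outputs once a trailing head-dimension axis is added:

```python
import torch


def merge_attention_outputs(out_local, lse_local, out_paged, lse_paged):
    """Combine two attention partials using their log-sum-exp normalizers, stably."""
    max_lse = torch.maximum(lse_local, lse_paged)
    w_local = torch.exp((lse_local - max_lse).float())
    w_paged = torch.exp((lse_paged - max_lse).float())
    denom = w_local + w_paged
    # Weighted average of the partial outputs; unsqueeze adds the missing head_dim axis.
    merged = (
        out_local * (w_local / denom).unsqueeze(-1) + out_paged * (w_paged / denom).unsqueeze(-1)
    )
    return merged.to(out_local.dtype)
```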
177-183: Mask value `-1e5` may be insufficient for extreme attention scores. Using `-1e5` as the masking penalty (Line 183) is less robust than `-inf` or `torch.finfo(score.dtype).min`. If pre-softmax scores are very large, the masked positions could still receive non-negligible attention weight.

34-69: `rewrite_cache_in_left_padding_style` is a trivial wrapper — consider inlining. This function only computes `new_space_size = kv_seq_len + 1` and delegates everything to `_rewrite_cache_in_left_padding_style_aux`. The wrapper adds indirection without meaningful abstraction.

experimental/dms/dms/cache_paged.py (2)
567-604: `expand_blocks` closure is redefined on every iteration of the `while` loop. The `expand_blocks` inner function (Lines 569–585) is recreated each loop iteration. Since it only captures `self.growth_factor`, which doesn't change, move it out of the loop (or make it a method/static helper).

Move `expand_blocks` outside the loop:

```diff
     def _get_free_pages(self, num_pages: int):
         assert self.free_page_ids is not None
         assert self.key_blocks is not None
         assert self.value_blocks is not None
+
+        def expand_blocks(blocks: torch.Tensor):
+            return torch.cat(
+                [
+                    blocks,
+                    torch.zeros(
+                        (
+                            float_ceil(blocks.size(0) * self.growth_factor) - blocks.size(0),
+                            blocks.size(1),
+                            blocks.size(2),
+                            blocks.size(3),
+                        ),
+                        dtype=blocks.dtype,
+                        device=blocks.device,
+                    ),
+                ],
+                dim=0,
+            )
+
         while len(self.free_page_ids) < num_pages:
-
-            def expand_blocks(blocks: torch.Tensor):
-                return torch.cat(
-                    [
-                        blocks,
-                        torch.zeros(
-                            (
-                                float_ceil(blocks.size(0) * self.growth_factor) - blocks.size(0),
-                                blocks.size(1),
-                                blocks.size(2),
-                                blocks.size(3),
-                            ),
-                            dtype=blocks.dtype,
-                            device=blocks.device,
-                        ),
-                    ],
-                    dim=0,
-                )
-
             old_num_blocks = self.key_blocks.size(0)
```
545-557: `torch.cuda.empty_cache()` called unconditionally in `reset()`. If the cache is on a non-CUDA device (e.g., CPU or MPS), this call is unnecessary. Consider guarding it:

```python
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```

This is minor since `torch.cuda.empty_cache()` is effectively a no-op without initialized CUDA, but it's cleaner and avoids potential issues if CUDA initialization has side effects.

experimental/dms/dms/cache.py (1)
519-538: `get_cr` early-returns `1.0` if any single layer's paged cache is uninitialized. Line 525–526: if any layer's `cache_seq_lengths` is `None`, the method returns `1.0` immediately, even if other layers are initialized. This is fine during startup but could mask issues if layers initialize out of order.

experimental/dms/dms/training/engine.py (1)
329-339: `ModifiedTrainer.log` assumes `custom_logs` values are numeric tensors when `dist.is_initialized()`. Line 331 creates a CUDA tensor from `values`, but `custom_logs` might contain non-numeric types if the contract is ever relaxed. Currently, `update_logs` converts everything via `.item()` (Line 657), so all values are floats — this is safe for now but fragile.
```python
def score_mod(score, b, h, q_idx, k_idx):
    causal = q_idx >= k_idx
    within_sliding_window = q_idx - k_idx <= window_size

    causal = causal.to(score.dtype)
    within_sliding_window = within_sliding_window.to(score.dtype)

    modified_score = within_sliding_window * score + (1 - within_sliding_window) * (
        dms_mask_values[b, k_idx] + score
    )

    return (1 - causal) * MASK_VALUE + causal * modified_score
```
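For orientation (illustrative only, not part of the diff): a quick check with plain integers of how the two boundary operators differ for a single query position:

```python
window_size = 4
q_idx = 10

# Keys treated as "inside the sliding window" for query position 10:
inclusive = [k for k in range(q_idx + 1) if q_idx - k <= window_size]  # training-style "<="
strict = [k for k in range(q_idx + 1) if q_idx - k < window_size]      # strict "<"

print(inclusive)  # [6, 7, 8, 9, 10] -> window_size + 1 positions
print(strict)     # [7, 8, 9, 10]    -> window_size positions
```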
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Check window_size comparison operators across training and eval paths
echo "=== Training path (attention.py) ==="
rg -n 'window_size' --type=py -C2 -g '*/dms/attention.py'
echo ""
echo "=== Prefill path (attention_prefill.py) ==="
rg -n 'window_size' --type=py -C2 -g '*/attention_prefill.py'
echo ""
echo "=== Test references ==="
rg -n 'window_size' --type=py -C2 -g '*/test_prefill_and_generate.py'
echo ""
echo "=== Core DMS ==="
rg -n 'window_size' --type=py -C2 -g '*/dms/core.py'
```

Repository: NVIDIA/Model-Optimizer

Length of output: 803

🏁 Script executed:

```bash
#!/bin/bash
# Find files first
echo "=== Looking for files ==="
fd -t f 'attention.py'
echo ""
fd -t f 'attention_prefill.py'
echo ""
fd -t f 'test_prefill_and_generate.py'
echo ""
fd -t f 'core.py' | grep -i dms
echo ""
echo "=== Checking for experimental/dms directory ==="
find . -type d -name 'dms' | head -5
```

Repository: NVIDIA/Model-Optimizer

Length of output: 512

🏁 Script executed:

```bash
#!/bin/bash
echo "=== attention.py (training) - window_size context ==="
rg -n 'window_size' -C3 experimental/dms/dms/attention.py
echo ""
echo "=== attention_prefill.py (eval) - window_size context ==="
rg -n 'window_size' -C3 experimental/dms/dms/attention_prefill.py
echo ""
echo "=== test_prefill_and_generate.py - window_size context ==="
rg -n 'window_size' -C3 experimental/dms/tests/test_prefill_and_generate.py
echo ""
echo "=== core.py - window_size context ==="
rg -n 'window_size' -C3 experimental/dms/dms/core.py
```

Repository: NVIDIA/Model-Optimizer

Length of output: 5564

🏁 Script executed:

```bash
#!/bin/bash
echo "=== attention_prefill.py get_mask function ==="
sed -n '166,250p' experimental/dms/dms/attention_prefill.py
```

Repository: NVIDIA/Model-Optimizer

Length of output: 3798

🏁 Script executed:

```bash
#!/bin/bash
# Double-check the comparison operators in the key files
echo "=== Training path comparison ==="
sed -n '146,152p' experimental/dms/dms/attention.py
echo ""
echo "=== Eval path comparison ==="
sed -n '176,182p' experimental/dms/dms/attention_prefill.py
echo ""
echo "=== Test prefill reference ==="
sed -n '140,145p' experimental/dms/tests/test_prefill_and_generate.py
echo ""
echo "=== Test eviction logic ==="
sed -n '181,185p' experimental/dms/tests/test_prefill_and_generate.py
```

Repository: NVIDIA/Model-Optimizer

Length of output: 1213
Fix off-by-one in sliding window boundary comparison.
Training uses <= (line 148: q_idx - k_idx <= window_size), protecting window_size + 1 positions. Eval and test use < (attention_prefill.py line 179: < compressed_attention_mask, test_prefill_and_generate.py line 143: < window_size), protecting exactly window_size positions. This train/eval mismatch means the model sees one extra "safe" token within the sliding window during training compared to inference, potentially causing subtle quality degradation at the window boundary.
This is desired behavior: for compatibility with already published checkpoints, DMSCache uses window_size + 1 at inference - see this line.
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Check the referenced line and surrounding context in modeling_qwen3_dms.py
echo "=== Checking DMSCache initialization with window_size + 1 ==="
sed -n '375,390p' experimental/dms/models/qwen3/modeling_qwen3_dms.py
echo ""
echo "=== Also check for any comments or documentation about this decision ==="
rg -n "window_size.*\+.*1" experimental/dms/ -C2
```

Length of output: 2956
@kstaniszewsknv, thank you for the clarification! I can confirm that line 383 in modeling_qwen3_dms.py does indeed initialize the DMSCache with dms_window_size + 1, which means the training path using <= window_size (allowing window_size + 1 positions) is correctly aligned with the inference configuration. This ensures consistency between training and inference for the published checkpoints.
I appreciate you pointing out this design decision—it's an important detail for understanding the DMS architecture.
✏️ Learnings added
Learnt from: kstaniszewsknv
Repo: NVIDIA/Model-Optimizer PR: 877
File: experimental/dms/dms/attention.py:146-157
Timestamp: 2026-02-10T20:27:13.500Z
Learning: In experimental/dms, the training mode attention uses `q_idx - k_idx <= window_size` comparison (allowing window_size + 1 positions) to match the inference DMSCache initialization which uses `dms_window_size + 1`. This design is intentional for backward compatibility with published checkpoints.
Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.
```python
def dms_schedule(
    step: int,
    training_args: TrainingArguments,
    dms_initial_cr,
    dms_final_cr,
    dms_final_step: int | None = None,
):
    """Given the current training step, compute the DMS schedule.

    Returns the target fraction of DMS key-value pairs to evict and the compression ratio.
    """
    if dms_final_step is not None:
        max_steps = dms_final_step
    else:
        max_steps = training_args.max_steps

    progress = min(step / max_steps, 1.0)

    cr = dms_initial_cr + (dms_final_cr - dms_initial_cr) * progress

    frac = 1 / cr

    target = 1 - frac  # what fraction of gates to close

    return target, cr
```
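For orientation (illustrative only, not part of the diff), the schedule interpolates the compression ratio linearly and derives the closed-gate target from it; the values below assume `max_steps=1000` and a 1× to 8× ramp:

```python
from transformers import TrainingArguments

args = TrainingArguments(output_dir="/tmp/dms_demo", max_steps=1000)

for step in (0, 500, 1000):
    target, cr = dms_schedule(step, args, dms_initial_cr=1.0, dms_final_cr=8.0)
    print(f"step={step:4d}  cr={cr:.2f}  target_closed_frac={target:.3f}")
# step=   0  cr=1.00  target_closed_frac=0.000
# step= 500  cr=4.50  target_closed_frac=0.778
# step=1000  cr=8.00  target_closed_frac=0.875
```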
dms_schedule could divide by zero if max_steps is 0.

Line 238: `progress = min(step / max_steps, 1.0)` will raise `ZeroDivisionError` if both `dms_final_step` and `training_args.max_steps` are 0.

Add a guard:

```diff
+    if max_steps <= 0:
+        raise ValueError(f"max_steps must be positive, got {max_steps}")
     progress = min(step / max_steps, 1.0)
```
```python
def extract_student_model(
    combined_model: CombinedModel,
    tokenizer: PreTrainedTokenizer,
    save_path: str,
) -> None:
    """Extract the student model from a CombinedModel and save it for inference.

    The saved model includes:
    - Model weights in bfloat16
    - Config with auto_map for trust_remote_code
    - Model implementation files (config.py, model.py)
    - Tokenizer

    Note: The saved model imports from the `dms` package. Make sure `dms` is
    installed (pip install -e .) in any environment where you load this model.
    """
    student_model = combined_model.student_model
    logger.info(f"Extracting student model to: {save_path}")

    student_model.to(torch.bfloat16)
    student_model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    # Update config.json with auto_map
    config_path = Path(save_path) / "config.json"
    with open(config_path) as f:
        config = json.load(f)
    config.pop("architectures", None)
    config["auto_map"] = AUTO_MAP_CONFIG
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)

    # Copy model implementation files for trust_remote_code
    model_dir = Path(__file__).parent
    for src_name in ["configuration_qwen3_dms.py", "modeling_qwen3_dms.py"]:
        shutil.copy(model_dir / src_name, Path(save_path) / src_name)

    logger.info(f"Successfully saved student model to: {save_path}")
```
extract_student_model — to(bfloat16) mutates the in-memory model.
Line 235 converts the student model to bfloat16 in-place. While this is called at the end of training, it permanently mutates combined_model.student_model. If any subsequent code (e.g., evaluation) uses the model, it will unexpectedly be in bfloat16. Consider using a cloned model or documenting this as terminal.
Expected behavior.
@kstaniszewsknv, understood—thank you for clarifying!
✏️ Learnings added
Learnt from: kstaniszewsknv
Repo: NVIDIA/Model-Optimizer PR: 877
File: experimental/dms/models/qwen3/train.py:216-253
Timestamp: 2026-02-10T20:40:53.280Z
Learning: In experimental/dms/models/qwen3/train.py, the extract_student_model function intentionally performs an in-place conversion to bfloat16 (line 235: student_model.to(torch.bfloat16)). This is expected behavior and the function is designed to be called at the end of training.
🧠 Learnings used
Learnt from: kstaniszewsknv
Repo: NVIDIA/Model-Optimizer PR: 877
File: experimental/dms/models/qwen3/extract.py:95-96
Timestamp: 2026-02-10T20:39:28.626Z
Learning: In the DMS training pipeline (experimental/dms), training configurations explicitly save checkpoints in legacy PyTorch format (.bin files), not safetensors format, so checkpoint loading uses torch.load.
Learnt from: kstaniszewsknv
Repo: NVIDIA/Model-Optimizer PR: 877
File: experimental/dms/models/qwen3/train.py:82-97
Timestamp: 2026-02-10T20:40:19.544Z
Learning: In experimental/dms/models/qwen3/train.py, the resolve_checkpoint function is designed to leave "auto" in cfg["hf_trainer"]["resume_from_checkpoint"] when no checkpoints are found. This is expected behavior and TrainingArguments handles the "auto" value correctly.
```text
@@ -0,0 +1,42 @@
#!/bin/bash

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
Copyright year appears stale.
The header says 2024, but this file is newly introduced in 2026. Consider updating to 2026 (or 2024-2026 if the project convention uses a range).
This year was added by pre-commit checks.
@kstaniszewsknv, thank you for clarifying! I understand that the copyright headers are managed automatically by your pre-commit hooks. I've noted this for future reference.
✏️ Learnings added
Learnt from: kstaniszewsknv
Repo: NVIDIA/Model-Optimizer PR: 877
File: experimental/dms/scripts/evaluate.sh:3-3
Timestamp: 2026-02-10T20:41:11.637Z
Learning: In this repository, SPDX copyright headers (SPDX-FileCopyrightText) are automatically added by pre-commit checks, so the year in copyright headers is determined by the automated tooling and may not reflect the current year.
```bash
set -x

MODEL_PATH=$1
test -z "$MODEL_PATH" && echo "Usage: bash scripts/evaluate.sh MODEL_PATH" && exit 1
```
Add set -euo pipefail to catch failures.
Without set -e, a failing accelerate launch will let the script exit with code 0, silently masking evaluation errors. This is especially problematic if the script is used in CI or automated pipelines.
Proposed fix:

```diff
-set -x
+set -euxo pipefail
```
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```text
@@           Coverage Diff           @@
##             main     #877   +/-   ##
=======================================
  Coverage   73.44%   73.44%
=======================================
  Files         197      197
  Lines       20657    20657
=======================================
  Hits        15172    15172
  Misses       5485     5485
```

☔ View full report in Codecov by Sentry.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: kstaniszewsknv <kstaniszewsk@nvidia.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: kstaniszewsknv <kstaniszewsk@nvidia.com>
Signed-off-by: Konrad Staniszewski <kstaniszewsk@nvidia.com>
Signed-off-by: Konrad Staniszewski <kstaniszewsk@nvidia.com>
@kstaniszewsknv can you please confirm if all the code is NV written, and does not copy code from any 3rd party GitHub repositories?
The only code that we partially share with the HF Transformers library is in
Signed-off-by: Konrad Staniszewski <kstaniszewsk@nvidia.com>
What does this PR do?
Type of change: new feature
Overview: Training and inference code for Dynamic Memory Sparsification (DMS), a method from the NeurIPS 2025 paper Inference-Time Hyper-Scaling with KV Cache Compression.
Usage
Detailed in `experimental/dms/README.md` and `experimental/dms/ARCHITECTURE.md`.

Testing

DMS tests in `experimental/dms/tests` covering:

Before your PR is "Ready for review"

experimental/dms

Additional Information
A minimal, optimized implementation of the DMS algorithm for KV-cache compression, as described in:
Inference-time scaling trades efficiency for improved reasoning by generating longer sequences. In Transformer LLMs, generation cost is often bottlenecked by the size of the key-value (KV) cache. DMS addresses this by learning a KV cache eviction policy that compresses the cache while preserving accuracy.
How it works
DMS learns a per-head eviction policy that determines which KV cache entries to keep during generation. Rather than immediately discarding tokens, DMS delays eviction decisions, implicitly merging representations and preserving critical information. During training, the compression ratio is gradually increased from 1× to a target value (e.g., 8×), using knowledge distillation to match the outputs of an uncompressed teacher model.
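To make this concrete, here is a hedged sketch of how such a combined objective can look; the weights `alpha`, `beta`, `gamma`, the hinge-style DMS regularizer, and the function name are illustrative rather than the exact implementation in this PR:

```python
import torch
import torch.nn.functional as F


def dms_training_loss(student_logits, teacher_logits, labels, frac_closed, target_frac,
                      alpha=1.0, beta=1.0, gamma=1.0, temperature=1.0):
    """Illustrative combination of distillation, LM, and DMS-regularization losses."""
    # Knowledge distillation: match the uncompressed teacher's output distribution.
    distill = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature**2

    # Standard next-token cross-entropy against the labels.
    lm = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), labels.view(-1))

    # Push the fraction of closed (evicted) gates toward the scheduled target.
    dms_reg = F.relu(target_frac - frac_closed).mean()

    return alpha * distill + beta * lm + gamma * dms_reg
```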
Summary by CodeRabbit
New Features
Documentation
Tests