Conversation

Contributor

@kstaniszewsknv kstaniszewsknv commented Feb 10, 2026

What does this PR do?

Type of change: new feature

Overview: Training and inference code for Dynamic Memory Sparsification (DMS), the method from the NeurIPS 2025 paper Inference-Time Hyper-Scaling with KV Cache Compression

Usage

Detailed in experimental/dms/README.md and experimental/dms/ARCHITECTURE.md

Testing

DMS tests in experimental/dms/tests covering:

  • prefill
  • generation
  • gradient propagation
  • chunked prefill

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: No, DMS is currently an experimental feature with its description in experimental/dms

Additional Information

A minimal, optimized implementation of the DMS algorithm for KV-cache compression, as described in:

Inference-Time Hyper-Scaling with KV Cache Compression
Adrian Łańcucki, Konrad Staniszewski, Piotr Nawrot, Edoardo M. Ponti
Paper: https://arxiv.org/abs/2506.05345
NeurIPS: https://neurips.cc/virtual/2025/loc/san-diego/poster/119605

Inference-time scaling trades efficiency for improved reasoning by generating longer sequences. In Transformer LLMs, generation cost is often bottlenecked by the size of the key-value (KV) cache. DMS addresses this by learning a KV cache eviction policy that compresses the cache while preserving accuracy.

How it works

DMS learns a per-head eviction policy that determines which KV cache entries to keep during generation. Rather than immediately discarding tokens, DMS delays eviction decisions, implicitly merging representations and preserving critical information. During training, the compression ratio is gradually increased from 1× to a target value (e.g., 8×), using knowledge distillation to match the outputs of an uncompressed teacher model.
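A minimal sketch of that compression-ratio ramp, mirroring the dms_schedule logic quoted later in this review thread (the function and argument names below are illustrative, not the training engine's exact API):

def compression_schedule(step: int, max_steps: int, initial_cr: float = 1.0, final_cr: float = 8.0):
    # Linearly interpolate the compression ratio over training, then convert it
    # into the fraction of KV-cache gates that should be closed (evicted).
    progress = min(step / max_steps, 1.0)
    cr = initial_cr + (final_cr - initial_cr) * progress
    target_evict_frac = 1.0 - 1.0 / cr
    return target_evict_frac, cr

# Halfway through training with an 8x target: cr = 4.5, so ~77.8% of gates should be closed.
print(compression_schedule(step=500, max_steps=1000))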

Summary by CodeRabbit

  • New Features

    • Introduces Dynamic Memory Sparsification (DMS), an algorithm for efficient LLM inference and training with adaptive attention gating.
    • Adds DMS-enabled Qwen3 models with memory-efficient KV cache management and paged block-based storage.
    • Includes student-teacher distillation training infrastructure with noise scheduling and compression ratio control.
    • Provides configuration system and training/evaluation scripts for DMS adaptation.
  • Documentation

    • Added architecture guide, README, and example inference notebook.
  • Tests

    • Added comprehensive test suite for chunked prefill, cache management, and prefill/inference validation.

Signed-off-by: Konrad Staniszewski <kstaniszewsk@nvidia.com>
@kstaniszewsknv kstaniszewsknv requested a review from a team as a code owner February 10, 2026 19:28
Contributor

coderabbitai bot commented Feb 10, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough


Introduces a comprehensive Dynamic Memory Sparsification (DMS) system for Qwen3 models. Adds documentation, core DMS modules (attention, caching, training logic), configuration files, data pipeline, student-teacher distillation training infrastructure, Qwen3 model adaptations, training/evaluation scripts, and test suite.

Changes

  • Documentation (experimental/dms/ARCHITECTURE.md, experimental/dms/README.md): Comprehensive guides covering DMS algorithm, configuration, installation, training workflow, repository structure, and advanced features.
  • Configuration Files (experimental/dms/configs/qwen3_1.7b.yaml, experimental/dms/configs/qwen3_8b.yaml): YAML training configurations for Qwen3 models with DMS hyperparameters, data settings, and HuggingFace Trainer options.
  • Core DMS Package (experimental/dms/dms/__init__.py, experimental/dms/dms/logging.py): Package initialization and logging utility for DMS components.
  • DMS Attention Module (experimental/dms/dms/attention.py, experimental/dms/dms/attention_prefill.py): FlexAttention-based training mode and FlashAttention-based inference mode with eviction-aware sparse masking and cache rewriting.
  • DMS Cache Management (experimental/dms/dms/cache.py, experimental/dms/dms/cache_paged.py): Contiguous and paged cache layers with block-based storage, eviction decision logic, and mode management (prefill/inference).
  • DMS Core Logic (experimental/dms/dms/core.py): Attention input/output processing, gating with noise injection, chunked prefill orchestration, and training state management.
  • Training Data Pipeline (experimental/dms/dms/training/__init__.py, experimental/dms/dms/training/data.py): Dataset loading, tokenization, concatenation, shuffling, and multi-dataset blending with caching for efficient data preparation.
  • Training Engine (experimental/dms/dms/training/engine.py): Model loading, gradient configuration, noise scheduling, student-teacher distillation loss computation, trainer state, and combined model wrapper.
  • Qwen3 Model Integration (experimental/dms/models/__init__.py, experimental/dms/models/qwen3/__init__.py, experimental/dms/models/qwen3/configuration_qwen3_dms.py, experimental/dms/models/qwen3/modeling_qwen3_dms.py): DMS-enabled Qwen3 configuration class and model classes (Qwen3AttentionDMS, Qwen3DecoderLayerDMS, Qwen3ForCausalLMDMS, and task-specific variants).
  • Training and Model Extraction (experimental/dms/models/qwen3/train.py, experimental/dms/models/qwen3/extract.py): End-to-end training workflow with config loading, dataset creation, model building, and student model extraction from checkpoints.
  • Training and Evaluation Scripts (experimental/dms/scripts/train.sh, experimental/dms/scripts/train_small.sh, experimental/dms/scripts/evaluate.sh): Shell scripts for distributed training, single-GPU debug training, and evaluation via lm-eval-harness.
  • Project Configuration (experimental/dms/pyproject.toml): Python package metadata, dependencies (transformers, datasets, accelerate, lm_eval), and pytest configuration.
  • Inference Example and Tests (experimental/dms/example_inference.ipynb, experimental/dms/tests/conftest.py, experimental/dms/tests/test_chunked_prefill.py, experimental/dms/tests/test_dms_utils.py, experimental/dms/tests/test_paged_cache.py, experimental/dms/tests/test_prefill_and_generate.py, experimental/dms/tests/utils.py): Jupyter notebook demonstrating chunked prefill and prefill-only inference workflows, pytest configuration, and comprehensive unit tests validating chunked prefill equivalence, gating utilities, paged cache operations, and prefill-to-generation correctness.

Sequence Diagram(s)

sequenceDiagram
    participant Student as Student Model
    participant Dist as Distillation<br/>Loss
    participant Teacher as Teacher Model
    participant LM as LM Loss
    participant DMS as DMS Loss
    
    Student->>+Student: Forward Pass<br/>(input_ids)
    Student->>Student: Apply DMS<br/>Gating & Attention
    Student-->>-Student: Logits + State
    
    Student->>+Dist: Student Logits
    Teacher->>+Teacher: Forward Pass<br/>(input_ids, eval mode)
    Teacher-->>-Teacher: Teacher Logits
    Teacher->>Dist: Teacher Logits
    Dist->>Dist: KL Divergence<br/>Loss
    Dist-->>-Student: Distillation Loss
    
    Student->>+LM: Student Logits +<br/>Labels
    LM->>LM: Cross Entropy
    LM-->>-Student: LM Loss
    
    Student->>+DMS: DMS State<br/>(frac_closed, decisions)
    DMS->>DMS: Compute DMS<br/>Regularization
    DMS-->>-Student: DMS Loss
    
    Student->>Student: Combine Losses<br/>(α×Dist + β×LM + γ×DMS)
    Student->>Student: Backward Pass &<br/>Optimize
sequenceDiagram
    participant Input as Input Sequence
    participant Prefill as Chunked Prefill
    participant ChunkAtt as Chunk Attention<br/>Processing
    participant Cache as DMS Cache<br/>(Contiguous +<br/>Paged)
    participant Inference as Token Generation<br/>Loop
    
    Input->>+Prefill: Long Sequence
    Prefill->>Prefill: Split into<br/>Chunks
    
    loop For Each Chunk
        Prefill->>+ChunkAtt: Chunk Tokens +<br/>Position Info
        ChunkAtt->>ChunkAtt: Compute DMS<br/>Decisions & Gating
        ChunkAtt->>ChunkAtt: FlexAttention with<br/>Eviction Masking
        ChunkAtt-->>-Prefill: Chunk Output
        Prefill->>+Cache: Update Recent &<br/>Paged KV
        Cache->>Cache: Evict old tokens<br/>per decisions
        Cache-->>-Prefill: Cache State
    end
    
    Prefill-->>-Inference: Final Hidden State +<br/>Populated Cache
    
    Inference->>+Inference: Generate Loop
    loop For Each Token
        Inference->>+ChunkAtt: Query Token +<br/>Cached KV
        ChunkAtt->>ChunkAtt: FlashAttention with<br/>KV Cache & Paging
        ChunkAtt-->>-Inference: Attention Output
        Inference->>Inference: Compute Logits &<br/>Sample Token
        Inference->>+Cache: Append New Token<br/>to Cache
        Cache->>Cache: Update Cache with<br/>Eviction Check
        Cache-->>-Inference: Updated Cache
    end
    Inference-->>-Inference: Generated Sequence

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 87.69% which is sufficient. The required threshold is 80.00%.
  • Title check: ✅ Passed. The title clearly and concisely describes the main change: adding a complete DMS training and inference implementation, which matches the substantial additions across documentation, core modules, models, training infrastructure, and tests.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 14

🤖 Fix all issues with AI agents
In `@experimental/dms/dms/attention.py`:
- Around line 146-157: In score_mod, the sliding-window check uses "<=" which
allows window_size+1 tokens during training; change the comparison in
within_sliding_window from "q_idx - k_idx <= window_size" to "q_idx - k_idx <
window_size" so training matches eval/test behavior; update the expression in
the score_mod function (using variables window_size, within_sliding_window,
causal, dms_mask_values) to use the strict "<" boundary.

In `@experimental/dms/dms/core.py`:
- Around line 287-298: The function dms_chunked_prefill defines a parameter
named dms_chunked_prefill which shadows the function name; rename the parameter
to dms_chunk_size and update every use inside dms_chunked_prefill (references to
the parameter) to dms_chunk_size, and then update all call sites that pass the
old parameter name to use dms_chunk_size instead (ensure signatures and kwargs
are consistent). Also check any type hints/annotations and docstrings inside
dms_chunked_prefill referencing dms_chunked_prefill and update them to
dms_chunk_size to keep names consistent.

In `@experimental/dms/dms/training/data.py`:
- Around line 445-459: The current __getitem__ mutates self.dataset_iterators
(and uses sample_mapping) which breaks when DataLoader spawns workers; instead
make per-index selection deterministic by removing mutation and computing or
looking up the per-dataset sample index from global index: precompute a static
mapping (e.g. self.sample_index_mapping) in __init__ that for each global index
yields the corresponding (ds_idx, ds_sample_idx), then change __getitem__ to use
ds_idx = self.sample_mapping[index] and ds_sample_idx =
self.sample_index_mapping[index] (or compute ds_sample_idx = index %
len(self.datasets[ds_idx]) if that matches desired behavior) and do not modify
dataset_iterators; update or remove dataset_iterators usage in __getitem__
accordingly.

In `@experimental/dms/dms/training/engine.py`:
- Around line 222-246: dms_schedule can divide by zero when max_steps is 0; add
a guard in the function (dms_schedule) before computing progress: determine
max_steps as currently done, then if max_steps <= 0 set progress = 1.0 (or
progress = 0.0 if you prefer saturating to start) otherwise compute progress =
min(step / max_steps, 1.0); keep the rest of the logic (cr, frac, target)
unchanged. This prevents ZeroDivisionError when dms_final_step and
training_args.max_steps are zero.
- Around line 594-595: The constructor/function currently uses mutable defaults
forward_fn_kwargs_student: dict[str, Any] = {} and forward_fn_kwargs_teacher:
dict[str, Any] = {}, causing shared-state bugs; change both default values to
None (forward_fn_kwargs_student: Optional[dict[str, Any]] = None,
forward_fn_kwargs_teacher: Optional[dict[str, Any]] = None) and inside the
constructor/body of the class/function (where these params are assigned to
self.forward_fn_kwargs_student / self.forward_fn_kwargs_teacher) initialize them
with an explicit None check (e.g., if forward_fn_kwargs_student is None:
self.forward_fn_kwargs_student = {} else: self.forward_fn_kwargs_student =
forward_fn_kwargs_student) and similarly for forward_fn_kwargs_teacher so each
instance gets its own dict.

In `@experimental/dms/models/qwen3/configuration_qwen3_dms.py`:
- Around line 96-97: The constructor currently calls setup_compile_limit_for_dms
when dms_compile_limit is provided, which mutates the global
torch._dynamo.config.cache_size_limit during config instantiation (including
from_pretrained). Remove the side-effect from __init__ by deferring
setup_compile_limit_for_dms to model initialization (e.g., a model factory or an
explicit initialize method) or add clear documentation that __init__ mutates
torch._dynamo.config; modify the code path that currently invokes
setup_compile_limit_for_dms in the config class (the __init__ handling of
dms_compile_limit) and instead call setup_compile_limit_for_dms from the model
loading/initialization routine (the code that constructs or loads the model
after from_pretrained) so deserialization no longer performs global mutation
unless initialization is explicitly performed.

In `@experimental/dms/models/qwen3/extract.py`:
- Around line 95-96: The code currently calls torch.load(model_path,
weights_only=True) in the combined_model.load_state_dict path which will fail if
model_path points to a .safetensors file; update the loader to check the
extension of model_path and use safetensors.torch.load_file(model_path) for
.safetensors and torch.load otherwise (preserving the weights_only behavior
where applicable), then pass the resulting state dict to
combined_model.load_state_dict; refer to logger.info,
combined_model.load_state_dict, torch.load, safetensors.torch.load_file, and
model_path when making the change.
- Around line 62-64: The code currently hardcodes checkpoint_dir /
"pytorch_model.bin" (model_path) which will fail for safetensors, sharded
checkpoints, or other naming; update the logic that builds model_path (using
cli_args.checkpoint and checkpoint_dir) to detect available checkpoint files
instead of a fixed filename: check for model.safetensors, any
pytorch_model-*-of-*.bin shards, and model.*.bin in checkpoint_dir (or use a
HF/transformers utility to resolve the correct weights file), pick and validate
the first matching path, and keep save_path as before; ensure the detection
happens where model_path is assigned so functions referencing model_path use the
resolved file.

In `@experimental/dms/models/qwen3/train.py`:
- Around line 216-253: The code currently calls student_model.to(torch.bfloat16)
which mutates combined_model.student_model in-place; change
extract_student_model to avoid mutating the in-memory model by creating a deep
copy, e.g. use copy.deepcopy to make student_copy =
copy.deepcopy(student_model), call student_copy.to(torch.bfloat16) and
save_pretrained on that copy (and leave combined_model.student_model untouched);
add the necessary import for copy or, if you prefer, explicitly document that
extract_student_model is terminal and must not be called before any further
evaluation.
- Around line 82-97: resolve_checkpoint returns None when no checkpoints are
found but leaves cfg["hf_trainer"]["resume_from_checkpoint"] set to "auto",
which later gets passed to TrainingArguments(**hf_trainer_cfg); change the code
that calls resolve_checkpoint (the place handling checkpoint_path/line ~297) so
that when resolve_checkpoint returns a falsy value you explicitly set
cfg["hf_trainer"]["resume_from_checkpoint"] = None (or delete that key) before
constructing TrainingArguments(**hf_trainer_cfg) — ensure the symbol names
involved are resolve_checkpoint, cfg["hf_trainer"], resume_from_checkpoint and
TrainingArguments(**hf_trainer_cfg).

In `@experimental/dms/pyproject.toml`:
- Line 8: The requires-python spec in pyproject.toml currently uses
">=3.10,<=3.13" which excludes patch releases (e.g., 3.13.1); update the
requires-python constraint to ">=3.10,<3.14" so patch releases within 3.13 are
allowed (locate and edit the requires-python entry in
experimental/dms/pyproject.toml).

In `@experimental/dms/scripts/evaluate.sh`:
- Line 3: Update the SPDX copyright header in
experimental/dms/scripts/evaluate.sh (the SPDX-FileCopyrightText line) to
reflect the correct year or range for a new 2026 file—e.g., change "2024" to
"2026" or "2024-2026"—so the header is current and follows project convention.
- Around line 29-32: Add strict bash failure flags to the top of the script so
errors (including pipeline failures) cause immediate exit: replace or augment
the existing "set -x" with "set -euo pipefail" (and keep -x if debugging is
desired) in experimental/dms/scripts/evaluate.sh so that failures from commands
like accelerate launch do not return a zero exit code; ensure MODEL_PATH
handling (variable MODEL_PATH and the usage check) remains unchanged.

In `@experimental/dms/scripts/train_small.sh`:
- Line 28: The exported env var uses the wrong name: replace PYTORCH_ALLOC_CONF
with PYTORCH_CUDA_ALLOC_CONF in the export line so the expandable_segments:True
setting actually takes effect on CUDA; update the export in train_small.sh (the
line exporting PYTORCH_ALLOC_CONF) to export
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to match the usage in train.sh
and prevent silent non-application of the memory optimization.
🧹 Nitpick comments (26)
experimental/dms/pyproject.toml (2)

20-26: Consider narrowing the DeprecationWarning filter.

Blanket ignore::DeprecationWarning can mask actionable deprecation signals from your own code or key dependencies. Consider scoping it to specific noisy modules (e.g., ignore::DeprecationWarning:transformers). Not urgent for experimental code, but worth keeping in mind.


10-15: Pin lm_eval to a specific version for consistency.

The other dependencies are pinned to exact versions for reproducibility (transformers==4.57.3, datasets==4.4.2, accelerate==1.4.0), but lm_eval[ruler] is unpinned. Since multiple versions of lm_eval exist (0.4.2–0.4.10), a breaking update could silently break evaluation. Pin it to match the pattern of the other dependencies.

experimental/dms/dms/logging.py (1)

22-35: Set logger.propagate = False to prevent duplicate log messages.

In a training environment where HuggingFace Trainer or accelerate typically configure the root logger, messages from DMS loggers will propagate upward and be printed twice (once by this handler, once by the root). Setting propagate = False ensures only the DMS formatter is used.

Proposed fix
     if not logger.handlers:
         logger.setLevel(logging.INFO)
         handler = logging.StreamHandler(sys.stderr)
         handler.setFormatter(
             logging.Formatter("[%(asctime)s] DMS %(name)s [%(levelname)s]: %(message)s")
         )
         logger.addHandler(handler)
+        logger.propagate = False
 
     return logger
experimental/dms/tests/utils.py (1)

25-27: sys.path manipulation is fragile; consider relying on the editable install instead.

Since the README instructs users to run pip install -e ., the DMS package should already be importable without path hacking. If this is kept as a fallback, consider guarding against duplicate appends and using sys.path.insert(0, ...) to ensure the local version takes priority over a stale installed copy.
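A sketch of the guarded alternative described above; the repo-root computation is an assumption about the test layout, not the actual contents of tests/utils.py:

import sys
from pathlib import Path

# Hypothetical repo root; adjust to wherever the local dms package actually lives.
REPO_ROOT = str(Path(__file__).resolve().parents[1])

# Append only once, and prepend so the local checkout shadows a stale installed copy.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)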

experimental/dms/dms/training/data.py (3)

429-440: Weighted sample counts may silently drop samples due to int() truncation.

int(nw * train_samples) truncates rather than rounds, so the total number of entries in sample_mapping can be less than train_samples. For three equally-weighted datasets with train_samples=4000, each gets int(0.333… × 4000) = 1333, totaling 3999 instead of 4000.

Consider distributing the remainder to avoid the shortfall:

♻️ Proposed fix
-        self.sample_mapping = []
-        for i, nw in enumerate(self.normalized_weights):
-            nw = nw.item()
-            self.sample_mapping.append(np.full(int(nw * train_samples), i))
+        # Distribute samples proportionally, then assign remainder to largest-weight datasets
+        raw_counts = self.normalized_weights * train_samples
+        counts = raw_counts.astype(int)
+        remainder = train_samples - counts.sum()
+        # Assign remaining samples by largest fractional part
+        fractions = raw_counts - counts
+        for idx in np.argsort(-fractions)[:remainder]:
+            counts[idx] += 1
+
+        self.sample_mapping = []
+        for i, count in enumerate(counts):
+            self.sample_mapping.append(np.full(count, i))

478-497: extract_fn always picks the last correct solution rather than the first.

The loop at lines 482–484 iterates through all (generation, correctness) pairs without breaking, so solution ends up as the last verified-correct answer. If the intent is to use any correct solution, breaking on the first match is more efficient and deterministic.

♻️ Proposed fix: break on first match
     for gen, correctness in zip(ds_elem["generations"], ds_elem["correctness_math_verify"]):
         if correctness:
             solution = gen
+            break

286-303: Remove encoded_prompt_untrimmed or document its use.

Line 289 stores the full token list in encoded_prompt_untrimmed, which persists in the cached dataset. However, this field is never referenced anywhere else in the codebase. If nothing downstream consumes it, this wastes disk space and memory—especially for large datasets with long sequences. Either remove this field or add a comment explaining why it's retained.

experimental/dms/scripts/train_small.sh (1)

39-41: No separate dataset preparation step, unlike train.sh.

train.sh runs --prepare-dataset-only before training to avoid potential issues with dataset processing during distributed training. This script skips that step. For a single-GPU debug script this is unlikely to cause problems, but for consistency you may want to add it — especially since the first run will block on dataset processing before training begins (no progress feedback).

experimental/dms/tests/test_paged_cache.py (1)

255-403: Consider guarding against potential edge case with dms_window_size range.

Lines 267–269: when block_size is at its maximum (15, from randint(1, 16)), the dms_window_size range becomes randint(16, 17), always yielding 16. This works but means dms_window_size = block_size + 1 is the only value tested for large block_size. Not a bug, but you may want to widen max_val or decouple limits to get broader coverage of the dms_window_size > block_size invariant.

experimental/dms/example_inference.ipynb (1)

57-81: clean_cache() placement — model memory not freed after example_prefill_generate.

clean_cache() is called before example_prefill_generate, but the model and tensors created inside the function are not explicitly deleted afterward. Since the notebook's next cell also starts with clean_cache(), this works in practice, but calling clean_cache() after each example (or using a context manager / del model) would be more robust for users running cells out of order.

experimental/dms/models/qwen3/configuration_qwen3_dms.py (1)

86-97: Use ValueError instead of assert for input validation.

These assertions validate user-provided configuration values but can be silently disabled with python -O. For a configuration class that guards runtime invariants, raising ValueError is more robust.

Proposed fix
-        assert self.dms_paged_attention_block_size > 0, (
-            f"dms_paged_attention_block_size: {self.dms_paged_attention_block_size} is not greater than 0"
-        )
-        assert self.dms_window_size > self.dms_paged_attention_block_size, (
-            f"dms_window_size: {self.dms_window_size} "
-            f"is not greater than dms_paged_attention_block_size: {self.dms_paged_attention_block_size}"
-        )
-        assert self.dms_alpha_per in ["head", "layer"], (
-            f"dms_alpha_per: {self.dms_alpha_per} is not supported"
-        )
+        if self.dms_paged_attention_block_size <= 0:
+            raise ValueError(
+                f"dms_paged_attention_block_size must be > 0, got {self.dms_paged_attention_block_size}"
+            )
+        if self.dms_window_size <= self.dms_paged_attention_block_size:
+            raise ValueError(
+                f"dms_window_size ({self.dms_window_size}) must be > "
+                f"dms_paged_attention_block_size ({self.dms_paged_attention_block_size})"
+            )
+        if self.dms_alpha_per not in ("head", "layer"):
+            raise ValueError(f"dms_alpha_per must be 'head' or 'layer', got '{self.dms_alpha_per}'")
experimental/dms/tests/test_prefill_and_generate.py (1)

198-219: Inconsistent .cuda() usage for torch.randint in parameter generation.

Lines 202–208 call .cuda().item() on random integers, while line 201 uses .item() directly. Since torch.randint with .item() returns a Python int regardless of device, the .cuda() calls are unnecessary overhead.

Proposed cleanup
     torch.manual_seed(seed)
     batch = torch.randint(1, 5, (1,)).item()
-    heads_kv = torch.randint(1, 5, (1,)).cuda().item()
-    gqa_factor = torch.randint(1, 4, (1,)).cuda().item()
-    seq_len = torch.randint(8, 1024, (1,)).cuda().item()
+    heads_kv = torch.randint(1, 5, (1,)).item()
+    gqa_factor = torch.randint(1, 4, (1,)).item()
+    seq_len = torch.randint(8, 1024, (1,)).item()
     head_dim = 3
-    chunk_size = torch.randint(1, 128, (1,)).cuda().item()
-    dms_block_size = torch.randint(2, 32, (1,)).cuda().item()
-    dms_window_size = torch.randint(dms_block_size + 1, 128, (1,)).cuda().item()
+    chunk_size = torch.randint(1, 128, (1,)).item()
+    dms_block_size = torch.randint(2, 32, (1,)).item()
+    dms_window_size = torch.randint(dms_block_size + 1, 128, (1,)).item()
experimental/dms/models/qwen3/train.py (2)

105-123: _parse_data_blend_elements — no error handling for malformed blend strings.

Line 118 entry.split(":") will produce unexpected results if an entry doesn't contain exactly one ":". Consider using split(":", 1) and validating the result to provide a clearer error message than an index error.
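A sketch of the stricter parsing suggested here; the helper name and the assumed "name:weight" entry format are illustrative, not the module's actual API:

def parse_blend_entry(entry: str) -> tuple[str, float]:
    # Split only on the first ":" so dataset names containing ":" still parse.
    parts = entry.split(":", 1)
    if len(parts) != 2 or not parts[0] or not parts[1]:
        raise ValueError(f"Malformed data blend entry {entry!r}; expected 'name:weight'")
    name, weight_str = parts
    try:
        weight = float(weight_str)
    except ValueError as err:
        raise ValueError(f"Non-numeric weight in blend entry {entry!r}") from err
    return name, weight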


61-61: Mutable default argument {} on train_attn_kwargs.

Line 61 in dms_attention and line 114 in dms_attn_train_mode (in attention.py) use dict literals as default arguments. While these specific callsites don't mutate the dict, the pattern is a well-known Python pitfall. Consider using None with an or {} fallback.

Also applies to: 114-114
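A sketch of the None-default pattern suggested in both comments above; the signature is illustrative, not the real dms_attention API:

from typing import Any


def dms_attention_sketch(query, key, value, train_attn_kwargs: dict[str, Any] | None = None):
    # Build a fresh dict per call instead of sharing one module-level literal.
    if train_attn_kwargs is None:
        train_attn_kwargs = {}
    # ... the real implementation would forward train_attn_kwargs to the attention kernel.
    return train_attn_kwargs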

experimental/dms/dms/attention.py (1)

32-42: Guarded imports default to None, but are used as default arguments on lines 190–191.

If flash_attn or dms.attention_prefill fails to import, flash_attn_with_kvcache and dms_run_prefill_flex become None. These are then used as default parameter values for dms_attn_eval_mode, leading to a confusing TypeError: 'NoneType' object is not callable at runtime instead of a clear import error.

Consider raising a descriptive error when the defaults are None and the function is called:

Proposed defensive check
 def dms_attn_eval_mode(
     ...
     flash_attn_fn: Callable = flash_attn_with_kvcache,
     prefill_attn_fn: Callable = dms_run_prefill_flex,
     prefill_attn_fn_kwargs: dict = {},
 ):
     """Perform DMS attention in evaluation mode using flash attention or flex prefill."""
+    if flash_attn_fn is None:
+        raise ImportError("flash_attn_with_kvcache is required for eval mode but failed to import")
+    if prefill_attn_fn is None:
+        raise ImportError("dms_run_prefill_flex is required for prefill mode but failed to import")
     assert decisions.dtype in (torch.int32, torch.long), (
experimental/dms/models/qwen3/modeling_qwen3_dms.py (3)

96-158: position_ids and use_cache silently absorbed by **kwargs.

Qwen3DecoderLayerDMS.forward passes position_ids and use_cache to self.self_attn(...), but Qwen3AttentionDMS.forward doesn't declare these parameters — they silently end up in **kwargs. This works but makes the interface misleading. Consider either accepting them explicitly (and ignoring) or not passing them.


406-415: Silently replacing non-DMSCache past_key_values may mask integration bugs.

When past_key_values is not None and not a DMSCache, the code logs a warning and replaces it with a fresh empty cache (Line 415), discarding whatever state was passed. This silent replacement during generation could hide bugs in callers that pass an incompatible cache type. Consider raising an error instead of silently replacing, or at least making this a logger.error.
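A sketch of the stricter behavior floated above (the helper itself is hypothetical; only the cache-type check mirrors the comment):

def require_dms_cache(past_key_values, cache_cls) -> None:
    # Fail loudly instead of silently swapping in a fresh cache.
    if past_key_values is not None and not isinstance(past_key_values, cache_cls):
        raise TypeError(
            f"Expected {cache_cls.__name__} or None, got {type(past_key_values).__name__}"
        )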


337-350: Redundant slice self.layers[: self.config.num_hidden_layers].

self.layers is constructed at Line 259–263 with exactly config.num_hidden_layers elements, so slicing to [:self.config.num_hidden_layers] is a no-op.

experimental/dms/dms/core.py (1)

119-121: @torch.compile() on a function receiving nn.Module parameters may cause excessive recompilations.

prepare_attention_input receives multiple nn.Module instances (projections, norms) as arguments. Each unique module instance (i.e., from each layer) will likely trigger a separate compile guard, leading to one compilation per layer. This is presumably the reason for setup_compile_limit_for_dms(compile_limit=72), but it's worth documenting this relationship explicitly.

experimental/dms/dms/attention_prefill.py (3)

393-406: LSE-based attention merging could overflow for very long sequences.

denom_local = torch.exp(softmax_lse_local.float()) and denom_paged = torch.exp(softmax_lse_paged.float()) compute raw sum-of-exponentials. For extremely long sequences with high attention scores, exp(lse) can overflow float32. The numerically stable approach is to subtract the max LSE before exponentiating:

max_lse = torch.maximum(softmax_lse_local, softmax_lse_paged)
denom_local = torch.exp((softmax_lse_local - max_lse).float())
denom_paged = torch.exp((softmax_lse_paged - max_lse).float())

In practice with the .float() cast and typical score magnitudes this is unlikely to trigger, but it's worth considering for robustness at very long context lengths.


177-183: Mask value -1e5 may be insufficient for extreme attention scores.

Using -1e5 as the masking penalty (Line 183) is less robust than -inf or torch.finfo(score.dtype).min. If pre-softmax scores are very large, the masked positions could still receive non-negligible attention weight.
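A small illustration of the dtype-derived penalty (the tensors here are placeholders, not the prefill kernel's internals):

import torch

scores = torch.randn(4, 4, dtype=torch.bfloat16)
keep = torch.tril(torch.ones(4, 4, dtype=torch.bool))  # positions allowed to attend
mask_value = torch.finfo(scores.dtype).min  # dtype-aware, instead of the hard-coded -1e5
masked_scores = torch.where(keep, scores, mask_value)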


34-69: rewrite_cache_in_left_padding_style is a trivial wrapper — consider inlining.

This function only computes new_space_size = kv_seq_len + 1 and delegates everything to _rewrite_cache_in_left_padding_style_aux. The wrapper adds indirection without meaningful abstraction.

experimental/dms/dms/cache_paged.py (2)

567-604: expand_blocks closure is redefined on every iteration of the while loop.

The expand_blocks inner function (Lines 569–585) is recreated each loop iteration. Since it only captures self.growth_factor which doesn't change, move it out of the loop (or make it a method/static helper).

Move expand_blocks outside the loop
     def _get_free_pages(self, num_pages: int):
         assert self.free_page_ids is not None
         assert self.key_blocks is not None
         assert self.value_blocks is not None
+
+        def expand_blocks(blocks: torch.Tensor):
+            return torch.cat(
+                [
+                    blocks,
+                    torch.zeros(
+                        (
+                            float_ceil(blocks.size(0) * self.growth_factor) - blocks.size(0),
+                            blocks.size(1),
+                            blocks.size(2),
+                            blocks.size(3),
+                        ),
+                        dtype=blocks.dtype,
+                        device=blocks.device,
+                    ),
+                ],
+                dim=0,
+            )
+
         while len(self.free_page_ids) < num_pages:
-
-            def expand_blocks(blocks: torch.Tensor):
-                return torch.cat(
-                    [
-                        blocks,
-                        torch.zeros(
-                            (
-                                float_ceil(blocks.size(0) * self.growth_factor) - blocks.size(0),
-                                blocks.size(1),
-                                blocks.size(2),
-                                blocks.size(3),
-                            ),
-                            dtype=blocks.dtype,
-                            device=blocks.device,
-                        ),
-                    ],
-                    dim=0,
-                )
-
             old_num_blocks = self.key_blocks.size(0)

545-557: torch.cuda.empty_cache() called unconditionally in reset().

If the cache is on a non-CUDA device (e.g., CPU or MPS), this call is unnecessary. Consider guarding it:

if torch.cuda.is_available():
    torch.cuda.empty_cache()

This is minor since torch.cuda.empty_cache() is effectively a no-op without initialized CUDA, but it's cleaner and avoids potential issues if CUDA initialization has side effects.

experimental/dms/dms/cache.py (1)

519-538: get_cr early-returns 1.0 if any single layer's paged cache is uninitialized.

Line 525–526: if any layer's cache_seq_lengths is None, the method returns 1.0 immediately, even if other layers are initialized. This is fine during startup but could mask issues if layers initialize out of order.

experimental/dms/dms/training/engine.py (1)

329-339: ModifiedTrainer.log assumes custom_logs values are numeric tensors when dist.is_initialized().

Line 331 creates a CUDA tensor from values, but custom_logs might contain non-numeric types if the contract is ever relaxed. Currently, update_logs converts everything via .item() (Line 657), so all values are floats — this is safe for now but fragile.

Comment on lines +146 to +157
def score_mod(score, b, h, q_idx, k_idx):
    causal = q_idx >= k_idx
    within_sliding_window = q_idx - k_idx <= window_size

    causal = causal.to(score.dtype)
    within_sliding_window = within_sliding_window.to(score.dtype)

    modified_score = within_sliding_window * score + (1 - within_sliding_window) * (
        dms_mask_values[b, k_idx] + score
    )

    return (1 - causal) * MASK_VALUE + causal * modified_score
Contributor

@coderabbitai coderabbitai bot Feb 10, 2026


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check window_size comparison operators across training and eval paths
echo "=== Training path (attention.py) ==="
rg -n 'window_size' --type=py -C2 -g '*/dms/attention.py'

echo ""
echo "=== Prefill path (attention_prefill.py) ==="
rg -n 'window_size' --type=py -C2 -g '*/attention_prefill.py'

echo ""
echo "=== Test references ==="
rg -n 'window_size' --type=py -C2 -g '*/test_prefill_and_generate.py'

echo ""
echo "=== Core DMS ==="
rg -n 'window_size' --type=py -C2 -g '*/dms/core.py'

Repository: NVIDIA/Model-Optimizer

Length of output: 803


🏁 Script executed:

#!/bin/bash
# Find files first
echo "=== Looking for files ==="
fd -t f 'attention.py' 
echo ""
fd -t f 'attention_prefill.py'
echo ""
fd -t f 'test_prefill_and_generate.py'
echo ""
fd -t f 'core.py' | grep -i dms

echo ""
echo "=== Checking for experimental/dms directory ==="
find . -type d -name 'dms' | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 512


🏁 Script executed:

#!/bin/bash
echo "=== attention.py (training) - window_size context ==="
rg -n 'window_size' -C3 experimental/dms/dms/attention.py

echo ""
echo "=== attention_prefill.py (eval) - window_size context ==="
rg -n 'window_size' -C3 experimental/dms/dms/attention_prefill.py

echo ""
echo "=== test_prefill_and_generate.py - window_size context ==="
rg -n 'window_size' -C3 experimental/dms/tests/test_prefill_and_generate.py

echo ""
echo "=== core.py - window_size context ==="
rg -n 'window_size' -C3 experimental/dms/dms/core.py

Repository: NVIDIA/Model-Optimizer

Length of output: 5564


🏁 Script executed:

#!/bin/bash
echo "=== attention_prefill.py get_mask function ==="
sed -n '166,250p' experimental/dms/dms/attention_prefill.py

Repository: NVIDIA/Model-Optimizer

Length of output: 3798


🏁 Script executed:

#!/bin/bash
# Double-check the comparison operators in the key files
echo "=== Training path comparison ==="
sed -n '146,152p' experimental/dms/dms/attention.py

echo ""
echo "=== Eval path comparison ==="
sed -n '176,182p' experimental/dms/dms/attention_prefill.py

echo ""
echo "=== Test prefill reference ==="
sed -n '140,145p' experimental/dms/tests/test_prefill_and_generate.py

echo ""
echo "=== Test eviction logic ==="
sed -n '181,185p' experimental/dms/tests/test_prefill_and_generate.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1213


Fix off-by-one in sliding window boundary comparison.

Training uses <= (line 148: q_idx - k_idx <= window_size), protecting window_size + 1 positions. Eval and test use < (attention_prefill.py line 179: < compressed_attention_mask, test_prefill_and_generate.py line 143: < window_size), protecting exactly window_size positions. This train/eval mismatch means the model sees one extra "safe" token within the sliding window during training compared to inference, potentially causing subtle quality degradation at the window boundary.

🤖 Prompt for AI Agents
In `@experimental/dms/dms/attention.py` around lines 146 - 157, In score_mod, the
sliding-window check uses "<=" which allows window_size+1 tokens during
training; change the comparison in within_sliding_window from "q_idx - k_idx <=
window_size" to "q_idx - k_idx < window_size" so training matches eval/test
behavior; update the expression in the score_mod function (using variables
window_size, within_sliding_window, causal, dms_mask_values) to use the strict
"<" boundary.

Contributor Author


This is desired behavior: for compatibility with already published checkpoints, DMSCache uses window_size + 1 at inference - see this line.

Contributor


🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check the referenced line and surrounding context in modeling_qwen3_dms.py
echo "=== Checking DMSCache initialization with window_size + 1 ==="
sed -n '375,390p' experimental/dms/models/qwen3/modeling_qwen3_dms.py

echo ""
echo "=== Also check for any comments or documentation about this decision ==="
rg -n "window_size.*\+.*1" experimental/dms/ -C2

Length of output: 2956


@kstaniszewsknv, thank you for the clarification! I can confirm that line 383 in modeling_qwen3_dms.py does indeed initialize the DMSCache with dms_window_size + 1, which means the training path using <= window_size (allowing window_size + 1 positions) is correctly aligned with the inference configuration. This ensures consistency between training and inference for the published checkpoints.

I appreciate you pointing out this design decision—it's an important detail for understanding the DMS architecture.


✏️ Learnings added
Learnt from: kstaniszewsknv
Repo: NVIDIA/Model-Optimizer PR: 877
File: experimental/dms/dms/attention.py:146-157
Timestamp: 2026-02-10T20:27:13.500Z
Learning: In experimental/dms, the training mode attention uses `q_idx - k_idx <= window_size` comparison (allowing window_size + 1 positions) to match the inference DMSCache initialization which uses `dms_window_size + 1`. This design is intentional for backward compatibility with published checkpoints.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

Comment on lines +222 to +246
def dms_schedule(
    step: int,
    training_args: TrainingArguments,
    dms_initial_cr,
    dms_final_cr,
    dms_final_step: int | None = None,
):
    """Given the current training step, compute the DMS schedule.

    Returns the target fraction of DMS key-value pairs to evict and the compression ratio.
    """
    if dms_final_step is not None:
        max_steps = dms_final_step
    else:
        max_steps = training_args.max_steps

    progress = min(step / max_steps, 1.0)

    cr = dms_initial_cr + (dms_final_cr - dms_initial_cr) * progress

    frac = 1 / cr

    target = 1 - frac  # what fraction of gates to close

    return target, cr
Contributor


⚠️ Potential issue | 🟡 Minor

dms_schedule could divide by zero if max_steps is 0.

Line 238: progress = min(step / max_steps, 1.0) will raise ZeroDivisionError if both dms_final_step and training_args.max_steps are 0.

Add a guard
+    if max_steps <= 0:
+        raise ValueError(f"max_steps must be positive, got {max_steps}")
     progress = min(step / max_steps, 1.0)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 def dms_schedule(
     step: int,
     training_args: TrainingArguments,
     dms_initial_cr,
     dms_final_cr,
     dms_final_step: int | None = None,
 ):
     """Given the current training step, compute the DMS schedule.

     Returns the target fraction of DMS key-value pairs to evict and the compression ratio.
     """
     if dms_final_step is not None:
         max_steps = dms_final_step
     else:
         max_steps = training_args.max_steps
+    if max_steps <= 0:
+        raise ValueError(f"max_steps must be positive, got {max_steps}")
     progress = min(step / max_steps, 1.0)
     cr = dms_initial_cr + (dms_final_cr - dms_initial_cr) * progress
     frac = 1 / cr
     target = 1 - frac  # what fraction of gates to close
     return target, cr
🤖 Prompt for AI Agents
In `@experimental/dms/dms/training/engine.py` around lines 222 - 246, dms_schedule
can divide by zero when max_steps is 0; add a guard in the function
(dms_schedule) before computing progress: determine max_steps as currently done,
then if max_steps <= 0 set progress = 1.0 (or progress = 0.0 if you prefer
saturating to start) otherwise compute progress = min(step / max_steps, 1.0);
keep the rest of the logic (cr, frac, target) unchanged. This prevents
ZeroDivisionError when dms_final_step and training_args.max_steps are zero.

Comment on lines +216 to +253
def extract_student_model(
    combined_model: CombinedModel,
    tokenizer: PreTrainedTokenizer,
    save_path: str,
) -> None:
    """Extract the student model from a CombinedModel and save it for inference.

    The saved model includes:
    - Model weights in bfloat16
    - Config with auto_map for trust_remote_code
    - Model implementation files (config.py, model.py)
    - Tokenizer

    Note: The saved model imports from the `dms` package. Make sure `dms` is
    installed (pip install -e .) in any environment where you load this model.
    """
    student_model = combined_model.student_model
    logger.info(f"Extracting student model to: {save_path}")

    student_model.to(torch.bfloat16)
    student_model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    # Update config.json with auto_map
    config_path = Path(save_path) / "config.json"
    with open(config_path) as f:
        config = json.load(f)
    config.pop("architectures", None)
    config["auto_map"] = AUTO_MAP_CONFIG
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)

    # Copy model implementation files for trust_remote_code
    model_dir = Path(__file__).parent
    for src_name in ["configuration_qwen3_dms.py", "modeling_qwen3_dms.py"]:
        shutil.copy(model_dir / src_name, Path(save_path) / src_name)

    logger.info(f"Successfully saved student model to: {save_path}")
Contributor

@coderabbitai coderabbitai bot Feb 10, 2026


⚠️ Potential issue | 🟡 Minor

extract_student_model's .to(bfloat16) mutates the in-memory model.

Line 235 converts the student model to bfloat16 in-place. While this is called at the end of training, it permanently mutates combined_model.student_model. If any subsequent code (e.g., evaluation) uses the model, it will unexpectedly be in bfloat16. Consider using a cloned model or documenting this as terminal.

🤖 Prompt for AI Agents
In `@experimental/dms/models/qwen3/train.py` around lines 216 - 253, The code
currently calls student_model.to(torch.bfloat16) which mutates
combined_model.student_model in-place; change extract_student_model to avoid
mutating the in-memory model by creating a deep copy, e.g. use copy.deepcopy to
make student_copy = copy.deepcopy(student_model), call
student_copy.to(torch.bfloat16) and save_pretrained on that copy (and leave
combined_model.student_model untouched); add the necessary import for copy or,
if you prefer, explicitly document that extract_student_model is terminal and
must not be called before any further evaluation.

Contributor Author


Expected behavior.

Contributor


@kstaniszewsknv, understood—thank you for clarifying!


✏️ Learnings added
Learnt from: kstaniszewsknv
Repo: NVIDIA/Model-Optimizer PR: 877
File: experimental/dms/models/qwen3/train.py:216-253
Timestamp: 2026-02-10T20:40:53.280Z
Learning: In experimental/dms/models/qwen3/train.py, the extract_student_model function intentionally performs an in-place conversion to bfloat16 (line 235: student_model.to(torch.bfloat16)). This is expected behavior and the function is designed to be called at the end of training.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


🧠 Learnings used
Learnt from: kstaniszewsknv
Repo: NVIDIA/Model-Optimizer PR: 877
File: experimental/dms/models/qwen3/extract.py:95-96
Timestamp: 2026-02-10T20:39:28.626Z
Learning: In the DMS training pipeline (experimental/dms), training configurations explicitly save checkpoints in legacy PyTorch format (.bin files), not safetensors format, so checkpoint loading uses torch.load.

Learnt from: kstaniszewsknv
Repo: NVIDIA/Model-Optimizer PR: 877
File: experimental/dms/models/qwen3/train.py:82-97
Timestamp: 2026-02-10T20:40:19.544Z
Learning: In experimental/dms/models/qwen3/train.py, the resolve_checkpoint function is designed to leave "auto" in cfg["hf_trainer"]["resume_from_checkpoint"] when no checkpoints are found. This is expected behavior and TrainingArguments handles the "auto" value correctly.

@@ -0,0 +1,42 @@
#!/bin/bash

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Contributor

@coderabbitai coderabbitai bot Feb 10, 2026


⚠️ Potential issue | 🟡 Minor

Copyright year appears stale.

The header says 2024, but this file is newly introduced in 2026. Consider updating to 2026 (or 2024-2026 if the project convention uses a range).

🤖 Prompt for AI Agents
In `@experimental/dms/scripts/evaluate.sh` at line 3, Update the SPDX copyright
header in experimental/dms/scripts/evaluate.sh (the SPDX-FileCopyrightText line)
to reflect the correct year or range for a new 2026 file—e.g., change "2024" to
"2026" or "2024-2026"—so the header is current and follows project convention.

Contributor Author


This year was added by pre-commit checks.

Contributor


@kstaniszewsknv, thank you for clarifying! I understand that the copyright headers are managed automatically by your pre-commit hooks. I've noted this for future reference.


✏️ Learnings added
Learnt from: kstaniszewsknv
Repo: NVIDIA/Model-Optimizer PR: 877
File: experimental/dms/scripts/evaluate.sh:3-3
Timestamp: 2026-02-10T20:41:11.637Z
Learning: In this repository, SPDX copyright headers (SPDX-FileCopyrightText) are automatically added by pre-commit checks, so the year in copyright headers is determined by the automated tooling and may not reflect the current year.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

Comment on lines +29 to +32
set -x

MODEL_PATH=$1
test -z "$MODEL_PATH" && echo "Usage: bash scripts/evaluate.sh MODEL_PATH" && exit 1
Contributor


⚠️ Potential issue | 🟠 Major

Add set -euo pipefail to catch failures.

Without set -e, a failing accelerate launch will let the script exit with code 0, silently masking evaluation errors. This is especially problematic if the script is used in CI or automated pipelines.

Proposed fix
-set -x
+set -euxo pipefail
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-set -x
+set -euxo pipefail
 MODEL_PATH=$1
 test -z "$MODEL_PATH" && echo "Usage: bash scripts/evaluate.sh MODEL_PATH" && exit 1
🤖 Prompt for AI Agents
In `@experimental/dms/scripts/evaluate.sh` around lines 29 - 32, Add strict bash
failure flags to the top of the script so errors (including pipeline failures)
cause immediate exit: replace or augment the existing "set -x" with "set -euo
pipefail" (and keep -x if debugging is desired) in
experimental/dms/scripts/evaluate.sh so that failures from commands like
accelerate launch do not return a zero exit code; ensure MODEL_PATH handling
(variable MODEL_PATH and the usage check) remains unchanged.


codecov bot commented Feb 10, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.44%. Comparing base (5e43b2a) to head (30c3c48).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #877   +/-   ##
=======================================
  Coverage   73.44%   73.44%           
=======================================
  Files         197      197           
  Lines       20657    20657           
=======================================
  Hits        15172    15172           
  Misses       5485     5485           

☔ View full report in Codecov by Sentry.

kstaniszewsknv and others added 2 commits February 10, 2026 20:49
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: kstaniszewsknv <kstaniszewsk@nvidia.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: kstaniszewsknv <kstaniszewsk@nvidia.com>
@kstaniszewsknv kstaniszewsknv changed the title from "Initial DMS train" to "Add Dynamic Memory Sparsification (DMS) training and inference implementation" Feb 10, 2026
Signed-off-by: Konrad Staniszewski <kstaniszewsk@nvidia.com>
Signed-off-by: Konrad Staniszewski <kstaniszewsk@nvidia.com>
@kevalmorabia97
Collaborator

@kstaniszewsknv can you please confirm if all the code is NV written, and does not copy code from any 3rd party GitHub repositories?

@kstaniszewsknv
Contributor Author

> @kstaniszewsknv can you please confirm if all the code is NV written, and does not copy code from any 3rd party GitHub repositories?

The only code that we partially share with the HF Transformers library is in experimental/dms/models/qwen3/configuration_qwen3_dms.py and experimental/dms/models/qwen3/modeling_qwen3_dms.py (Apache 2.0).
We have already released a longer version of those two files here https://huggingface.co/nvidia/Qwen3-8B-DMS-8x/tree/main.
Those two files contain appropriate license notes (consulted with legal team).

Signed-off-by: Konrad Staniszewski <kstaniszewsk@nvidia.com>
@kstaniszewsknv kstaniszewsknv merged commit 549cea8 into main Feb 10, 2026
34 checks passed
@kstaniszewsknv kstaniszewsknv deleted the kstaniszewsknv/dms branch February 10, 2026 22:09