Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions cookbook/rl/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
logger = get_logger()

MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0')))
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS',4))
Expand All @@ -31,15 +31,16 @@
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 16)) # global completion-level mini-batch-size
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) # global completion-level mini-batch-size
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # per-device-micro-batch-size (completion-level), batch_size in forward_backward
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
ADAPTER_NAME = 'default'
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50))

def create_gsm8k_dataset():
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
dataset.set_template('Template', model_id=MODEL_ID, max_length=2048)
dataset.set_template('Template', model_id=MODEL_ID, max_length=400)
dataset.map(GSM8KProcessor())
dataset.encode(add_generation_prompt=True)
return dataset
Expand Down Expand Up @@ -68,13 +69,21 @@ def main():
sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)

# lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)
lora_config = LoraConfig(
target_modules=[
'q_proj', 'k_proj', 'v_proj', 'o_proj',
'gate_proj', 'up_proj', 'down_proj',
'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj',
],
r=32, lora_alpha=64, lora_dropout=0.05,
)
if USE_MEGATRON:
from twinkle.model.megatron import MegatronModel
model = MegatronModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', mixed_precision='bf16')
else:
model = TransformersModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model')
from transformers import Qwen3_5ForConditionalGeneration
model = TransformersModel(model_id=MODEL_ID, model_cls=Qwen3_5ForConditionalGeneration, device_mesh=model_mesh, remote_group='model')

model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
if USE_MEGATRON:
Expand All @@ -91,8 +100,9 @@ def main():
model_id=MODEL_ID,
engine_args={
'gpu_memory_utilization': 0.8,
'max_model_len': 4096,
'max_model_len': 4496,
'max_lora_rank': 32,  # same as lora_config (r=32)
# NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976
'enable_lora': True,
},
device_mesh=sampler_mesh,
Expand Down Expand Up @@ -172,6 +182,8 @@ def main():

if optim_step >= MAX_STEPS:
break
if optim_step % SAVE_STEPS == 0:
model.save(f'grpo-gsm8k-checkpoint-{optim_step}')
log_dict = metrics.calculate()
log_dict.update(model.calculate_metric(is_training=True))
metrics.reset()
Expand Down
6 changes: 3 additions & 3 deletions src/twinkle/model/megatron/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -1587,7 +1587,7 @@ def _trim_vocab(name, tensor):

if base_sync_done and adapter_name:
if merge_and_sync:

# LoRA Training and sync full model(merge_adapter)
def weight_generator():
for _model in self.strategy.unwrap_model(self.model):
if isinstance(_model, PeftModel):
Expand Down Expand Up @@ -1616,7 +1616,7 @@ def weight_generator():
yield name, tensor

else:

# First full base-model sync.
def _raw_weights():
for name, tensor in self.get_hf_state_dict(adapter_name=''):
if name is None or tensor is None:
Expand All @@ -1627,7 +1627,7 @@ def _raw_weights():
yield _trim_vocab(name, tensor)

def weight_generator():
if is_peft_format:
if is_peft_format and not merge_and_sync:
yield from _add_base_layer_suffix(_raw_weights())
else:
yield from _raw_weights()
Expand Down
35 changes: 21 additions & 14 deletions src/twinkle/model/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,21 +1159,28 @@ def send_weights(
# Get state dict from unwrapped model
model = self.strategy.unwrap_model(self.model)

def _normalize(name: str, keep_base_layer: bool) -> str:
name = name.replace('base_model.model.', '')
if not keep_base_layer:
name = name.replace('.base_layer', '')
return name

def _is_lora_key(name: str) -> bool:
return 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name

if base_sync_done and adapter_name:
if merge_and_sync:

# LoRA training and sync full model (merge_adapter)
# merge and skip lora weights (already merged)
# trim prefix(base_model.model.) and suffix(.base_layer)
def weight_generator():
if isinstance(model, PeftModel):
model.merge_adapter()
for name, tensor in model.state_dict().items():
# Skip LoRA-specific weights for base model sync
if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
if _is_lora_key(name):
continue
tensor = Torch.to_local_tensor(tensor)
# Keep original names (including .base_layer for PEFT models).
# The sampler side will strip .base_layer based on whether
# vLLM has enable_lora=True/False.
yield name, tensor
yield _normalize(name, keep_base_layer=False), tensor
if isinstance(model, PeftModel):
model.unmerge_adapter()
else:
Expand All @@ -1188,19 +1195,19 @@ def weight_generator():
yield name, tensor

else:
# Full model mode: send all weights (base model sync).
# First full base-model sync. Whether to keep ``.base_layer.``
# depends on whether the sampler uses ``enable_lora``:
# merge_and_sync=True → enable_lora=False → strip .base_layer
# merge_and_sync=False → enable_lora=True → keep .base_layer
keep_base_layer = not merge_and_sync
state_dict = model.state_dict()

def weight_generator():
for name, tensor in state_dict.items():
# Skip LoRA-specific weights for base model sync
if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
if _is_lora_key(name):
continue
tensor = Torch.to_local_tensor(tensor)
# Keep original names (including .base_layer for PEFT models).
# The sampler side will strip .base_layer based on whether
# vLLM has enable_lora=True/False.
yield name, tensor
yield _normalize(name, keep_base_layer=keep_base_layer), tensor

# Run async send_weights in a dedicated event loop thread.
# We cannot use the Ray worker's event loop because it may already
Expand Down
6 changes: 2 additions & 4 deletions src/twinkle/preprocessor/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,8 @@ class GSM8KProcessor(Preprocessor):
Extracts the ground truth number and stores it in user_data for reward.
Only includes system + user messages; assistant response is generated on-policy.
"""
system_prompt = ('You are a helpful math assistant. Solve the problem step by step. '
'Show your reasoning in <think> </think> tags, then give the final '
'numerical answer after ####.\n'
'For example:\n<think> ... reasoning ... </think>\n#### 42')
system_prompt = ('You are a helpful math assistant. Solve the problem step by step '
'and put your final answer within \\boxed{}.')

def __init__(self, system=None, add_assistant=False):
self.system = system
Expand Down
17 changes: 9 additions & 8 deletions src/twinkle/reward/gsm8k.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@
class GSM8KAccuracyReward(Reward):
"""Accuracy reward for GSM8K: checks if the model's answer matches ground truth.

Extracts the last '#### <number>' from model output and compares with ground truth.
Extracts the answer from \\boxed{} (preferred) or #### format.
Returns 1.0 for correct, 0.0 for incorrect.
"""

@staticmethod
def extract_answer(completion: str) -> str:
"""Extract the last #### answer from model completion."""
# Only check last 500 chars for efficiency
"""Extract the answer from model completion, preferring \\boxed{} over ####."""
text = completion[-500:] if len(completion) > 500 else completion
boxed = re.findall(r'\\boxed\{([^}]+)\}', text)
if boxed:
return boxed[-1].replace(',', '').replace(' ', '').strip()
matches = re.findall(r'####\s*([\-\d,\.\s]+)', text)
if matches:
return matches[-1].replace(',', '').replace(' ', '').strip()
Expand Down Expand Up @@ -54,9 +56,9 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:


class GSM8KFormatReward(Reward):
"""Format reward: checks if output contains <think>...</think> tag.
"""Format reward: checks if output contains \\boxed{} or #### answer format.

Returns 1.0 if format is correct, 0.0 otherwise.
Returns 1.0 if a valid answer format is present, 0.0 otherwise.
"""

def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
Expand All @@ -68,7 +70,6 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
if msg.get('role') == 'assistant':
completion = msg.get('content', '')
break
has_think = bool(re.search(r'<think>.*?</think>', completion, re.DOTALL))
has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion))
rewards.append(1.0 if (has_think and has_answer) else 0.0)
has_answer = bool(re.search(r'\\boxed\{[^}]+\}', completion) or re.search(r'####\s*[\-\d,\.]+', completion))
rewards.append(1.0 if has_answer else 0.0)
return rewards
64 changes: 14 additions & 50 deletions src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,17 @@ def _load_weights(
"""Load a batch of weights into vLLM.

Two modes:
- LoRA mode (``peft_config`` and ``base_sync_done``): Loads weights as
a tensor-based LoRA adapter via ``add_lora()``.
- Base model mode: Strips PEFT prefixes, merges split weights
(q/k/v_proj -> qkv_proj, gate/up_proj -> gate_up_proj) into vLLM's
stacked format, normalizes prefixes, then loads via direct param copy.

* **LoRA mode** (``peft_config`` set and ``base_sync_done=True``):
loads weights as a tensor-based LoRA adapter via ``add_lora()``.
* **Base model mode** (all other cases): delegates to
``model.load_weights()`` which handles stacked-parameter merging
(q/k/v → qkv, gate/up → gate_up) and prefix mapping internally.

Weight names are expected to arrive **already normalised** by the
sender (``TransformersModel.send_weights`` /
``MegatronModel.send_weights``), so no name transformation is done
here.
"""
if peft_config and base_sync_done:
# Remove existing LoRA before replacing
Expand All @@ -412,51 +418,9 @@ def _load_weights(
)
self.add_lora(lora_request)
else:
# Base model mode — strip PEFT prefixes and delegate to
# vLLM's model.load_weights() which handles stacked params,
# prefix normalization, and weight_loader internally.
vllm_has_lora = getattr(
getattr(self, 'vllm_config', None),
'lora_config',
None,
) is not None

# When vLLM LoRA is enabled, some LinearBase modules are
# replaced by *WithLoRA wrappers. Their parameters shift
# from e.g. ``gate.weight`` to ``gate.base_layer.weight``.
# HF checkpoint names do NOT contain ``.base_layer.``, so
# vLLM's own ``load_weights`` will KeyError on them.
#
# Build a set of base-layer prefixes that need rewriting.
lora_base_prefixes: set = set()
if vllm_has_lora:
from vllm.lora.layers import BaseLayerWithLoRA
for mod_name, mod in self.model_runner.model.named_modules():
if isinstance(mod, BaseLayerWithLoRA):
# mod_name is e.g. "model.layers.0.mlp.gate"
lora_base_prefixes.add(mod_name + '.')

converted = []
for name, tensor in weights:
if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
continue
name = name.removeprefix('model.base_model.model.')
name = name.removeprefix('base_model.model.')
if not vllm_has_lora:
name = name.replace('.base_layer.', '.')
else:
# Insert ``.base_layer.`` for weights whose module
# has been wrapped by LoRA and whose name does NOT
# already contain it.
if '.base_layer.' not in name:
for pfx in lora_base_prefixes:
if name.startswith(pfx):
# e.g. "model.layers.0.mlp.gate.weight"
# → "model.layers.0.mlp.gate.base_layer.weight"
suffix = name[len(pfx):]
name = pfx + 'base_layer.' + suffix
break
converted.append((name, tensor))
# Base model mode — weights arrive in canonical HF format
converted = [(n, t) for n, t in weights
if 'lora_A' not in n and 'lora_B' not in n and 'lora_embedding' not in n]

if not converted:
return
Expand Down
Loading