From 57ba4bb95db382b6c3a8327a7730076670142f98 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 21 Mar 2026 20:19:53 +0800 Subject: [PATCH 1/2] fix --- cookbook/rl/grpo.py | 14 ++-- src/twinkle/model/megatron/megatron.py | 6 +- .../model/transformers/transformers.py | 35 ++++++---- src/twinkle/preprocessor/llm.py | 6 +- src/twinkle/reward/gsm8k.py | 17 ++--- .../vllm_sampler/vllm_worker_extension.py | 64 ++++--------------- 6 files changed, 58 insertions(+), 84 deletions(-) diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py index d7d5df21..e465d503 100644 --- a/cookbook/rl/grpo.py +++ b/cookbook/rl/grpo.py @@ -21,7 +21,7 @@ logger = get_logger() MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') -USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0'))) +USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS',4)) @@ -31,15 +31,16 @@ MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) LEARNING_RATE = float(os.environ.get('LR', 1e-5)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size -MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 16)) # global completion-level mini-batch-size +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size +MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) # global completion-level mini-batch-size MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # per-device-micro-batch-size (completion-level), batch_size in forward_backward GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) ADAPTER_NAME = 'default' +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50)) def create_gsm8k_dataset(): dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) - dataset.set_template('Template', model_id=MODEL_ID, max_length=2048) + dataset.set_template('Template', model_id=MODEL_ID, max_length=400) dataset.map(GSM8KProcessor()) dataset.encode(add_generation_prompt=True) return dataset @@ -91,8 +92,9 @@ def main(): model_id=MODEL_ID, engine_args={ 'gpu_memory_utilization': 0.8, - 'max_model_len': 4096, + 'max_model_len': 4496, 'max_lora_rank': 32, # save as lora_config + # NOTE: To use enable_lora, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976 or later 'enable_lora': True, }, device_mesh=sampler_mesh, @@ -172,6 +174,8 @@ def main(): if optim_step >= MAX_STEPS: break + if optim_step % SAVE_STEPS == 0: + model.save(f'grpo-gsm8k-checkpoint-{optim_step}') log_dict = metrics.calculate() log_dict.update(model.calculate_metric(is_training=True)) metrics.reset() diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index ddddf41c..60cc30a8 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1587,7 +1587,7 @@ def _trim_vocab(name, tensor): if base_sync_done and adapter_name: if merge_and_sync: - + # LoRA Training and sync full model(merge_adapter) def weight_generator(): for _model in self.strategy.unwrap_model(self.model): if isinstance(_model, PeftModel): @@ -1616,7 +1616,7 @@ def weight_generator(): yield name, tensor else: - + # First full base-model sync. 
def _raw_weights(): for name, tensor in self.get_hf_state_dict(adapter_name=''): if name is None or tensor is None: @@ -1627,7 +1627,7 @@ def _raw_weights(): yield _trim_vocab(name, tensor) def weight_generator(): - if is_peft_format: + if is_peft_format and not merge_and_sync: yield from _add_base_layer_suffix(_raw_weights()) else: yield from _raw_weights() diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index d00e80ed..48f08039 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1159,21 +1159,28 @@ def send_weights( # Get state dict from unwrapped model model = self.strategy.unwrap_model(self.model) + def _normalize(name: str, keep_base_layer: bool) -> str: + name = name.replace('base_model.model.', '') + if not keep_base_layer: + name = name.replace('.base_layer', '') + return name + + def _is_lora_key(name: str) -> bool: + return 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name + if base_sync_done and adapter_name: if merge_and_sync: - + # LoRA Training and sync full model(merge_adapter) + # merge and skip lora weigts(already merged) + # trim prefix(base_model.model.) and suffix(.base_layer) def weight_generator(): if isinstance(model, PeftModel): model.merge_adapter() for name, tensor in model.state_dict().items(): - # Skip LoRA-specific weights for base model sync - if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: + if _is_lora_key(name): continue tensor = Torch.to_local_tensor(tensor) - # Keep original names (including .base_layer for PEFT models). - # The sampler side will strip .base_layer based on whether - # vLLM has enable_lora=True/False. - yield name, tensor + yield _normalize(name, keep_base_layer=False), tensor if isinstance(model, PeftModel): model.unmerge_adapter() else: @@ -1188,19 +1195,19 @@ def weight_generator(): yield name, tensor else: - # Full model mode: send all weights (base model sync). + # First full base-model sync. Whether to keep ``.base_layer.`` + # depends on whether the sampler uses ``enable_lora``: + # merge_and_sync=True → enable_lora=False → strip .base_layer + # merge_and_sync=False → enable_lora=True → keep .base_layer + keep_base_layer = not merge_and_sync state_dict = model.state_dict() def weight_generator(): for name, tensor in state_dict.items(): - # Skip LoRA-specific weights for base model sync - if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: + if _is_lora_key(name): continue tensor = Torch.to_local_tensor(tensor) - # Keep original names (including .base_layer for PEFT models). - # The sampler side will strip .base_layer based on whether - # vLLM has enable_lora=True/False. - yield name, tensor + yield _normalize(name, keep_base_layer=keep_base_layer), tensor # Run async send_weights in a dedicated event loop thread. # We cannot use the Ray worker's event loop because it may already diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py index a451e90c..97065fba 100644 --- a/src/twinkle/preprocessor/llm.py +++ b/src/twinkle/preprocessor/llm.py @@ -122,10 +122,8 @@ class GSM8KProcessor(Preprocessor): Extracts the ground truth number and stores it in user_data for reward. Only includes system + user messages; assistant response is generated on-policy. """ - system_prompt = ('You are a helpful math assistant. Solve the problem step by step. 
'
-                     'Show your reasoning in <think></think> tags, then give the final '
-                     'numerical answer after ####.\n'
-                     'For example:\n<think> ... reasoning ... </think>\n#### 42')
+    system_prompt = ('You are a helpful math assistant. Solve the problem step by step '
+                     'and put your final answer within \\boxed{}.')
 
     def __init__(self, system=None, add_assistant=False):
         self.system = system
diff --git a/src/twinkle/reward/gsm8k.py b/src/twinkle/reward/gsm8k.py
index 1f0f14b9..eb439675 100644
--- a/src/twinkle/reward/gsm8k.py
+++ b/src/twinkle/reward/gsm8k.py
@@ -7,15 +7,17 @@
 class GSM8KAccuracyReward(Reward):
     """Accuracy reward for GSM8K: checks if the model's answer matches ground truth.
 
-    Extracts the last '#### ' from model output and compares with ground truth.
+    Extracts the answer from \\boxed{} (preferred) or #### format.
     Returns 1.0 for correct, 0.0 for incorrect.
     """
 
     @staticmethod
     def extract_answer(completion: str) -> str:
-        """Extract the last #### answer from model completion."""
-        # Only check last 500 chars for efficiency
+        """Extract the answer from model completion, preferring \\boxed{} over ####."""
         text = completion[-500:] if len(completion) > 500 else completion
+        boxed = re.findall(r'\\boxed\{([^}]+)\}', text)
+        if boxed:
+            return boxed[-1].replace(',', '').replace(' ', '').strip()
         matches = re.findall(r'####\s*([\-\d,\.\s]+)', text)
         if matches:
             return matches[-1].replace(',', '').replace(' ', '').strip()
@@ -54,9 +56,9 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
 
 
 class GSM8KFormatReward(Reward):
-    """Format reward: checks if output contains <think>...</think> tag.
+    """Format reward: checks if output contains \\boxed{} or #### answer format.
 
-    Returns 1.0 if format is correct, 0.0 otherwise.
+    Returns 1.0 if a valid answer format is present, 0.0 otherwise.
     """
 
     def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
@@ -68,7 +70,6 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
             if msg.get('role') == 'assistant':
                 completion = msg.get('content', '')
                 break
-        has_think = bool(re.search(r'<think>.*?</think>', completion, re.DOTALL))
-        has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion))
-        rewards.append(1.0 if (has_think and has_answer) else 0.0)
+        has_answer = bool(re.search(r'\\boxed\{[^}]+\}', completion) or re.search(r'####\s*[\-\d,\.]+', completion))
+        rewards.append(1.0 if has_answer else 0.0)
         return rewards
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
index 42be5095..61920cd9 100644
--- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
+++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
@@ -390,11 +390,17 @@ def _load_weights(
         """Load a batch of weights into vLLM.
 
         Two modes:
-        - LoRA mode (``peft_config`` and ``base_sync_done``): Loads weights as
-          a tensor-based LoRA adapter via ``add_lora()``.
-        - Base model mode: Strips PEFT prefixes, merges split weights
-          (q/k/v_proj -> qkv_proj, gate/up_proj -> gate_up_proj) into vLLM's
-          stacked format, normalizes prefixes, then loads via direct param copy.
+
+        * **LoRA mode** (``peft_config`` set and ``base_sync_done=True``):
+          loads weights as a tensor-based LoRA adapter via ``add_lora()``.
+        * **Base model mode** (all other cases): delegates to
+          ``model.load_weights()`` which handles stacked-parameter merging
+          (q/k/v → qkv, gate/up → gate_up) and prefix mapping internally.
+
+        Weight names are expected to arrive **already normalised** by the
+        sender (``TransformersModel.send_weights`` /
+        ``MegatronModel.send_weights``), so no name transformation is done
+        here.
         """
         if peft_config and base_sync_done:
             # Remove existing LoRA before replacing
@@ -412,51 +418,9 @@ def _load_weights(
             )
             self.add_lora(lora_request)
         else:
-            # Base model mode — strip PEFT prefixes and delegate to
-            # vLLM's model.load_weights() which handles stacked params,
-            # prefix normalization, and weight_loader internally.
-            vllm_has_lora = getattr(
-                getattr(self, 'vllm_config', None),
-                'lora_config',
-                None,
-            ) is not None
-
-            # When vLLM LoRA is enabled, some LinearBase modules are
-            # replaced by *WithLoRA wrappers. Their parameters shift
-            # from e.g. ``gate.weight`` to ``gate.base_layer.weight``.
-            # HF checkpoint names do NOT contain ``.base_layer.``, so
-            # vLLM's own ``load_weights`` will KeyError on them.
-            #
-            # Build a set of base-layer prefixes that need rewriting.
-            lora_base_prefixes: set = set()
-            if vllm_has_lora:
-                from vllm.lora.layers import BaseLayerWithLoRA
-                for mod_name, mod in self.model_runner.model.named_modules():
-                    if isinstance(mod, BaseLayerWithLoRA):
-                        # mod_name is e.g. "model.layers.0.mlp.gate"
-                        lora_base_prefixes.add(mod_name + '.')
-
-            converted = []
-            for name, tensor in weights:
-                if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
-                    continue
-                name = name.removeprefix('model.base_model.model.')
-                name = name.removeprefix('base_model.model.')
-                if not vllm_has_lora:
-                    name = name.replace('.base_layer.', '.')
-                else:
-                    # Insert ``.base_layer.`` for weights whose module
-                    # has been wrapped by LoRA and whose name does NOT
-                    # already contain it.
-                    if '.base_layer.' not in name:
-                        for pfx in lora_base_prefixes:
-                            if name.startswith(pfx):
-                                # e.g. "model.layers.0.mlp.gate.weight"
-                                # → "model.layers.0.mlp.gate.base_layer.weight"
-                                suffix = name[len(pfx):]
-                                name = pfx + 'base_layer.' + suffix
-                                break
-                converted.append((name, tensor))
+            # Base model mode — weights arrive in canonical HF format
+            converted = [(n, t) for n, t in weights
+                         if 'lora_A' not in n and 'lora_B' not in n and 'lora_embedding' not in n]
 
             if not converted:
                 return

From 06e02bd08312d11d28288f228ca111ebe4a7d554 Mon Sep 17 00:00:00 2001
From: root
Date: Sun, 22 Mar 2026 11:36:06 +0800
Subject: [PATCH 2/2] fix transformers

---
 cookbook/rl/grpo.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py
index e465d503..9faa6f5b 100644
--- a/cookbook/rl/grpo.py
+++ b/cookbook/rl/grpo.py
@@ -69,13 +69,21 @@ def main():
     sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
     twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
 
-    lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)
-
+    # lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)
+    lora_config = LoraConfig(
+        target_modules=[
+            'q_proj', 'k_proj', 'v_proj', 'o_proj',
+            'gate_proj', 'up_proj', 'down_proj',
+            'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj',
+        ],
+        r=32, lora_alpha=64, lora_dropout=0.05,
+    )
     if USE_MEGATRON:
         from twinkle.model.megatron import MegatronModel
         model = MegatronModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', mixed_precision='bf16')
     else:
-        model = TransformersModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model')
+        from transformers import Qwen3_5ForConditionalGeneration
+        model = TransformersModel(model_id=MODEL_ID, model_cls=Qwen3_5ForConditionalGeneration, device_mesh=model_mesh, remote_group='model')
 
     model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
     if USE_MEGATRON:
@@ -94,7 +102,7 @@ def main():
             'gpu_memory_utilization': 0.8,
             'max_model_len': 4496,
             'max_lora_rank': 32, # save as lora_config
-            # NOTE: To use enable_lora, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976 or later
+            # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976
             'enable_lora': True,
         },
         device_mesh=sampler_mesh,
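
For reference, the answer-extraction behaviour introduced in PATCH 1/2 (prefer LaTeX-style \boxed{}, fall back to the legacy '#### <number>' format) can be exercised in isolation. The standalone Python sketch below mirrors the regexes added to src/twinkle/reward/gsm8k.py; the empty-string fallback and the sample completions are illustrative assumptions, not code taken from the repository.

    import re

    def extract_answer(completion: str) -> str:
        # Mirrors GSM8KAccuracyReward.extract_answer: only the tail of the
        # completion is scanned, and \boxed{...} wins over '####' when both appear.
        text = completion[-500:] if len(completion) > 500 else completion
        boxed = re.findall(r'\\boxed\{([^}]+)\}', text)
        if boxed:
            return boxed[-1].replace(',', '').replace(' ', '').strip()
        matches = re.findall(r'####\s*([\-\d,\.\s]+)', text)
        if matches:
            return matches[-1].replace(',', '').replace(' ', '').strip()
        return ''  # assumed fallback when neither format is present

    assert extract_answer('The total is \\boxed{1,234}.') == '1234'
    assert extract_answer('... reasoning ...\n#### 42') == '42'

Because both the accuracy and format rewards now accept either notation, the shortened system prompt in src/twinkle/preprocessor/llm.py only asks for \boxed{}, while completions that still end with '#### <answer>' keep receiving credit.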