From 9b349825090400a40fbcea4572de12feb20bdc2b Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Thu, 5 Feb 2026 02:01:16 -0500 Subject: [PATCH] Trl 0.27.0 update (#3965) * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update rl_replacements.py * Update rl.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update rl_replacements.py, remove chat template from codexes commits * Update rl.py, got rid of gradient checkpointing code that did not work --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- unsloth/models/rl.py | 27 +++++ unsloth/models/rl_replacements.py | 167 +++++++++++++----------------- 2 files changed, 97 insertions(+), 97 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 647c7e5f0..3776752f4 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1150,6 +1150,33 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pattern, new_options, RLTrainer_source, flags = re.DOTALL ) + if trl_version >= Version("0.27.0"): + peft_pattern = ( + r"\s*if is_peft_available\(\) and is_peft_model\(model\) and args\.beta != 0\.0:" + r".*?" + r"param\.data = param\.data\.to\(torch\.bfloat16\)" + ) + + replacement_comment = "\n # PEFT initialization logic removed via script for trl >= 0.27.0\n" + + RLTrainer_source = re.sub( + peft_pattern, replacement_comment, RLTrainer_source, flags = re.DOTALL + ) + + elif trl_version >= Version("0.26.0"): + peft_block_pattern = ( + r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:" + r".*?" + r"param\.data = param\.data\.to\(torch\.bfloat16\)" + ) + + RLTrainer_source = re.sub( + peft_block_pattern, + "\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n", + RLTrainer_source, + flags = re.DOTALL, + ) + if RLTrainer_name == "SFTTrainer": original_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]' new_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask","labels"]' diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ce8339696..8208dc922 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -27,6 +27,7 @@ import inspect from collections import defaultdict from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding from unsloth_zoo.utils import Version +from trl import __version__ as trl_version_raw from importlib.metadata import version as importlib_version from unsloth_zoo.log import logger from unsloth_zoo.device_type import device_synchronize @@ -57,6 +58,14 @@ torch_compile_options = { "triton.cudagraphs": False, } +try: + trl_version = Version(trl_version_raw) +except Exception: + try: + trl_version = Version(importlib_version("trl")) + except Exception: + trl_version = Version("0.0.0") + # Check untrained tokens def sft_trainer_fix_untrained_tokens(call_args, extra_args): @@ -434,99 +443,6 @@ def grpo_trainer__generate_and_score_completions(function_name, function): _target_line + _metadata_extraction, ) - # Unsloth: Skip prepare_multimodal_messages when prompts are pre-templated strings. - # When notebooks pre-apply apply_chat_template(), prompts become strings with image tokens - # already embedded. Calling prepare_multimodal_messages on strings crashes with TypeError. - # Skipping it keeps prompts as strings so TRL uses the non-conversational path, which - # ensures completions are strings and reward functions work correctly. - string_to_find_vision = """ if images is not None: - prompts = [ - prepare_multimodal_messages(prompt, image_list) - for prompt, image_list in zip(prompts, images, strict=True) - ]""" - - replacement_string_vision = """ if images is not None: - # Unsloth: skip prepare_multimodal_messages for pre-templated string prompts - if not prompts or not isinstance(prompts[0], str): - prompts = [ - prepare_multimodal_messages(prompt, image_list) - for prompt, image_list in zip(prompts, images, strict=True) - ]""" - - function = function.replace(string_to_find_vision, replacement_string_vision) - - # Unsloth: Skip apply_chat_template in the forward_kwargs block for pre-templated - # string prompts. When prompts are already strings (from notebooks that pre-applied - # apply_chat_template), calling it again crashes because strings aren't dicts. - # We use prompts directly as prompts_text instead. - - # TRL 0.26.2+ variant (has tools=self.tools) - string_to_find_fwd = """ if images is not None: - prompts_text = [ - apply_chat_template( - {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs - )["prompt"] - for prompt in prompts - ]""" - - replacement_string_fwd = """ if images is not None: - # Unsloth: skip apply_chat_template for pre-templated string prompts - if prompts and isinstance(prompts[0], str): - prompts_text = prompts - else: - prompts_text = [ - apply_chat_template( - {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs - )["prompt"] - for prompt in prompts - ]""" - - function = function.replace(string_to_find_fwd, replacement_string_fwd) - - # TRL 0.25.x variant (no tools parameter) - string_to_find_fwd_old = """ if images is not None: - prompts_text = [ - apply_chat_template( - {"prompt": prompt}, self.processing_class, **self.chat_template_kwargs - )["prompt"] - for prompt in prompts - ]""" - - replacement_string_fwd_old = """ if images is not None: - # Unsloth: skip apply_chat_template for pre-templated string prompts - if prompts and isinstance(prompts[0], str): - prompts_text = prompts - else: - prompts_text = [ - apply_chat_template( - {"prompt": prompt}, self.processing_class, **self.chat_template_kwargs - )["prompt"] - for prompt in prompts - ]""" - - function = function.replace(string_to_find_fwd_old, replacement_string_fwd_old) - - # TRL 0.25.1 single-line variant (no tools, single-line apply_chat_template call) - string_to_find_fwd_single = """ if images is not None: - prompts_text = [ - apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] - for prompt in prompts - ]""" - - replacement_string_fwd_single = """ if images is not None: - # Unsloth: skip apply_chat_template for pre-templated string prompts - if prompts and isinstance(prompts[0], str): - prompts_text = prompts - else: - prompts_text = [ - apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] - for prompt in prompts - ]""" - - function = function.replace( - string_to_find_fwd_single, replacement_string_fwd_single - ) - # This path is for TRL 0.24.0 images is a variable exclusive to this version string_to_find = """ if images is not None: output["num_images"] = num_images""" @@ -543,6 +459,17 @@ def grpo_trainer__generate_and_score_completions(function_name, function): function = function.replace(string_to_find, replacement_string) + if trl_version >= Version("0.25.0"): + # We replace the call using 'completions' with one using 'completions_text' + string_to_find = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)" + replacement_string = ( + " if images is not None:\n" + " rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)\n" + " else:\n" + " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)" + ) + function = function.replace(string_to_find, replacement_string) + if "wake_up()" not in function: # Sleep functionality has been added to trl in v0.23.0. We do not want to redo this. # https://github.com/huggingface/trl/commit/edbe8234bc7e528f72ac76607de9d3e4753e2709 @@ -1072,7 +999,7 @@ def grpo_trainer_compute_loss(function_name, function): max_left_pad = inputs.get("max_left_pad", 0) if per_token_logps is not None: - loss, completion_length, mean_kl, delta, flat_is_ratio = ( + loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( grpo_compute_loss_slow( ref_logps, per_token_logps, @@ -1102,7 +1029,7 @@ def grpo_trainer_compute_loss(function_name, function): ) else: if hasattr(self.args, "loss_type"): - loss, completion_length, mean_kl, delta, flat_is_ratio = ( + loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( grpo_accumulated_loss( trainer = self, input_ids = _input_ids, @@ -1134,7 +1061,7 @@ def grpo_trainer_compute_loss(function_name, function): ) else: # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 - loss, completion_length, mean_kl = grpo_accumulated_loss( + loss, completion_length, mean_kl, coef_1 = grpo_accumulated_loss( trainer = self, input_ids = _input_ids, logits_to_keep = logits_to_keep, @@ -1149,7 +1076,6 @@ def grpo_trainer_compute_loss(function_name, function): logit_scale_divide = logit_scale_divide, attention_mask = attention_mask, ) - if "train" in self._metrics: mode = "eval" if self.control.should_evaluate else "train" self._metrics[mode]["completion_length"].append(completion_length.item()) @@ -1211,6 +1137,53 @@ def grpo_trainer_compute_loss(function_name, function): .item() ) + completion_token_count = completion_mask.sum().clamp(min = 1.0) + + def masked_batch_mean(x): + if x.shape[1] == 1: # when importance_sampling_level == "sequence" + return x.mean() + else: + return (x * completion_mask).sum() / completion_token_count + + if advantages.dim() == 1: + advantages = advantages.unsqueeze(1) + + if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]: + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = masked_batch_mean(is_low_clipped.float()) + high_clip = masked_batch_mean(is_high_clipped.float()) + clip_ratio = masked_batch_mean(is_region_clipped.float()) + + gathered_low_clip = self.accelerator.gather(low_clip) + self._metrics[mode]["clip_ratio/low_mean"].append( + gathered_low_clip.nanmean().item() + ) + self._metrics[mode]["clip_ratio/low_min"].append( + nanmin(gathered_low_clip).item() + ) + gathered_high_clip = self.accelerator.gather(high_clip) + self._metrics[mode]["clip_ratio/high_mean"].append( + gathered_high_clip.nanmean().item() + ) + self._metrics[mode]["clip_ratio/high_max"].append( + nanmax(gathered_high_clip).item() + ) + gathered_clip_ratio = self.accelerator.gather(clip_ratio) + self._metrics[mode]["clip_ratio/region_mean"].append( + gathered_clip_ratio.nanmean().item() + ) + elif self.loss_type == "cispo": + is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0) + cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) + gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) + self._metrics[mode]["cispo_clip_ratio"].append( + gathered_cispo_clip_ratio.nanmean().item() + ) + return loss function = inspect.getsource(compute_loss)