Trl 0.27.0 update (#3965)

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl.py

* Update rl.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update rl_replacements.py

* Update rl.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update rl_replacements.py, remove chat template from codexes commits

* Update rl.py, got rid of gradient checkpointing code that did not work

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
pluesclues 2026-02-05 02:01:16 -05:00 committed by GitHub
parent e1c682e6d2
commit 9b34982509
2 changed files with 97 additions and 97 deletions

View file

@ -1150,6 +1150,33 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
pattern, new_options, RLTrainer_source, flags = re.DOTALL
)
if trl_version >= Version("0.27.0"):
peft_pattern = (
r"\s*if is_peft_available\(\) and is_peft_model\(model\) and args\.beta != 0\.0:"
r".*?"
r"param\.data = param\.data\.to\(torch\.bfloat16\)"
)
replacement_comment = "\n # PEFT initialization logic removed via script for trl >= 0.27.0\n"
RLTrainer_source = re.sub(
peft_pattern, replacement_comment, RLTrainer_source, flags = re.DOTALL
)
elif trl_version >= Version("0.26.0"):
peft_block_pattern = (
r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:"
r".*?"
r"param\.data = param\.data\.to\(torch\.bfloat16\)"
)
RLTrainer_source = re.sub(
peft_block_pattern,
"\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n",
RLTrainer_source,
flags = re.DOTALL,
)
if RLTrainer_name == "SFTTrainer":
original_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]'
new_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask","labels"]'

View file

@ -27,6 +27,7 @@ import inspect
from collections import defaultdict
from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding
from unsloth_zoo.utils import Version
from trl import __version__ as trl_version_raw
from importlib.metadata import version as importlib_version
from unsloth_zoo.log import logger
from unsloth_zoo.device_type import device_synchronize
@ -57,6 +58,14 @@ torch_compile_options = {
"triton.cudagraphs": False,
}
try:
trl_version = Version(trl_version_raw)
except Exception:
try:
trl_version = Version(importlib_version("trl"))
except Exception:
trl_version = Version("0.0.0")
# Check untrained tokens
def sft_trainer_fix_untrained_tokens(call_args, extra_args):
@ -434,99 +443,6 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
_target_line + _metadata_extraction,
)
# Unsloth: Skip prepare_multimodal_messages when prompts are pre-templated strings.
# When notebooks pre-apply apply_chat_template(), prompts become strings with image tokens
# already embedded. Calling prepare_multimodal_messages on strings crashes with TypeError.
# Skipping it keeps prompts as strings so TRL uses the non-conversational path, which
# ensures completions are strings and reward functions work correctly.
string_to_find_vision = """ if images is not None:
prompts = [
prepare_multimodal_messages(prompt, image_list)
for prompt, image_list in zip(prompts, images, strict=True)
]"""
replacement_string_vision = """ if images is not None:
# Unsloth: skip prepare_multimodal_messages for pre-templated string prompts
if not prompts or not isinstance(prompts[0], str):
prompts = [
prepare_multimodal_messages(prompt, image_list)
for prompt, image_list in zip(prompts, images, strict=True)
]"""
function = function.replace(string_to_find_vision, replacement_string_vision)
# Unsloth: Skip apply_chat_template in the forward_kwargs block for pre-templated
# string prompts. When prompts are already strings (from notebooks that pre-applied
# apply_chat_template), calling it again crashes because strings aren't dicts.
# We use prompts directly as prompts_text instead.
# TRL 0.26.2+ variant (has tools=self.tools)
string_to_find_fwd = """ if images is not None:
prompts_text = [
apply_chat_template(
{"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs
)["prompt"]
for prompt in prompts
]"""
replacement_string_fwd = """ if images is not None:
# Unsloth: skip apply_chat_template for pre-templated string prompts
if prompts and isinstance(prompts[0], str):
prompts_text = prompts
else:
prompts_text = [
apply_chat_template(
{"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs
)["prompt"]
for prompt in prompts
]"""
function = function.replace(string_to_find_fwd, replacement_string_fwd)
# TRL 0.25.x variant (no tools parameter)
string_to_find_fwd_old = """ if images is not None:
prompts_text = [
apply_chat_template(
{"prompt": prompt}, self.processing_class, **self.chat_template_kwargs
)["prompt"]
for prompt in prompts
]"""
replacement_string_fwd_old = """ if images is not None:
# Unsloth: skip apply_chat_template for pre-templated string prompts
if prompts and isinstance(prompts[0], str):
prompts_text = prompts
else:
prompts_text = [
apply_chat_template(
{"prompt": prompt}, self.processing_class, **self.chat_template_kwargs
)["prompt"]
for prompt in prompts
]"""
function = function.replace(string_to_find_fwd_old, replacement_string_fwd_old)
# TRL 0.25.1 single-line variant (no tools, single-line apply_chat_template call)
string_to_find_fwd_single = """ if images is not None:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
for prompt in prompts
]"""
replacement_string_fwd_single = """ if images is not None:
# Unsloth: skip apply_chat_template for pre-templated string prompts
if prompts and isinstance(prompts[0], str):
prompts_text = prompts
else:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
for prompt in prompts
]"""
function = function.replace(
string_to_find_fwd_single, replacement_string_fwd_single
)
# This path is for TRL 0.24.0 images is a variable exclusive to this version
string_to_find = """ if images is not None:
output["num_images"] = num_images"""
@ -543,6 +459,17 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
function = function.replace(string_to_find, replacement_string)
if trl_version >= Version("0.25.0"):
# We replace the call using 'completions' with one using 'completions_text'
string_to_find = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)"
replacement_string = (
" if images is not None:\n"
" rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)\n"
" else:\n"
" rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)"
)
function = function.replace(string_to_find, replacement_string)
if "wake_up()" not in function:
# Sleep functionality has been added to trl in v0.23.0. We do not want to redo this.
# https://github.com/huggingface/trl/commit/edbe8234bc7e528f72ac76607de9d3e4753e2709
@ -1072,7 +999,7 @@ def grpo_trainer_compute_loss(function_name, function):
max_left_pad = inputs.get("max_left_pad", 0)
if per_token_logps is not None:
loss, completion_length, mean_kl, delta, flat_is_ratio = (
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = (
grpo_compute_loss_slow(
ref_logps,
per_token_logps,
@ -1102,7 +1029,7 @@ def grpo_trainer_compute_loss(function_name, function):
)
else:
if hasattr(self.args, "loss_type"):
loss, completion_length, mean_kl, delta, flat_is_ratio = (
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = (
grpo_accumulated_loss(
trainer = self,
input_ids = _input_ids,
@ -1134,7 +1061,7 @@ def grpo_trainer_compute_loss(function_name, function):
)
else:
# to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
loss, completion_length, mean_kl = grpo_accumulated_loss(
loss, completion_length, mean_kl, coef_1 = grpo_accumulated_loss(
trainer = self,
input_ids = _input_ids,
logits_to_keep = logits_to_keep,
@ -1149,7 +1076,6 @@ def grpo_trainer_compute_loss(function_name, function):
logit_scale_divide = logit_scale_divide,
attention_mask = attention_mask,
)
if "train" in self._metrics:
mode = "eval" if self.control.should_evaluate else "train"
self._metrics[mode]["completion_length"].append(completion_length.item())
@ -1211,6 +1137,53 @@ def grpo_trainer_compute_loss(function_name, function):
.item()
)
completion_token_count = completion_mask.sum().clamp(min = 1.0)
def masked_batch_mean(x):
if x.shape[1] == 1: # when importance_sampling_level == "sequence"
return x.mean()
else:
return (x * completion_mask).sum() / completion_token_count
if advantages.dim() == 1:
advantages = advantages.unsqueeze(1)
if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
# Compute the clipped probability ratios
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
is_region_clipped = is_low_clipped | is_high_clipped
low_clip = masked_batch_mean(is_low_clipped.float())
high_clip = masked_batch_mean(is_high_clipped.float())
clip_ratio = masked_batch_mean(is_region_clipped.float())
gathered_low_clip = self.accelerator.gather(low_clip)
self._metrics[mode]["clip_ratio/low_mean"].append(
gathered_low_clip.nanmean().item()
)
self._metrics[mode]["clip_ratio/low_min"].append(
nanmin(gathered_low_clip).item()
)
gathered_high_clip = self.accelerator.gather(high_clip)
self._metrics[mode]["clip_ratio/high_mean"].append(
gathered_high_clip.nanmean().item()
)
self._metrics[mode]["clip_ratio/high_max"].append(
nanmax(gathered_high_clip).item()
)
gathered_clip_ratio = self.accelerator.gather(clip_ratio)
self._metrics[mode]["clip_ratio/region_mean"].append(
gathered_clip_ratio.nanmean().item()
)
elif self.loss_type == "cispo":
is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0)
cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
self._metrics[mode]["cispo_clip_ratio"].append(
gathered_cispo_clip_ratio.nanmean().item()
)
return loss
function = inspect.getsource(compute_loss)