mirror of
https://github.com/unslothai/unsloth.git
synced 2026-05-22 02:50:03 +00:00
Trl 0.27.0 update (#3965)
* Update rl_replacements.py
* Update rl_replacements.py
* Update rl.py
* Update rl_replacements.py
* Update rl_replacements.py
* Update rl.py
* Update rl.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update rl_replacements.py
* Update rl.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update rl_replacements.py, remove chat template from codexes commits
* Update rl.py, got rid of gradient checkpointing code that did not work
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e1c682e6d2
commit
9b34982509
2 changed files with 97 additions and 97 deletions
|
|
@ -1150,6 +1150,33 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
|
|||
pattern, new_options, RLTrainer_source, flags = re.DOTALL
|
||||
)
|
||||
|
||||
if trl_version >= Version("0.27.0"):
|
||||
peft_pattern = (
|
||||
r"\s*if is_peft_available\(\) and is_peft_model\(model\) and args\.beta != 0\.0:"
|
||||
r".*?"
|
||||
r"param\.data = param\.data\.to\(torch\.bfloat16\)"
|
||||
)
|
||||
|
||||
replacement_comment = "\n # PEFT initialization logic removed via script for trl >= 0.27.0\n"
|
||||
|
||||
RLTrainer_source = re.sub(
|
||||
peft_pattern, replacement_comment, RLTrainer_source, flags = re.DOTALL
|
||||
)
|
||||
|
||||
elif trl_version >= Version("0.26.0"):
|
||||
peft_block_pattern = (
|
||||
r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:"
|
||||
r".*?"
|
||||
r"param\.data = param\.data\.to\(torch\.bfloat16\)"
|
||||
)
|
||||
|
||||
RLTrainer_source = re.sub(
|
||||
peft_block_pattern,
|
||||
"\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n",
|
||||
RLTrainer_source,
|
||||
flags = re.DOTALL,
|
||||
)
|
||||
|
||||
if RLTrainer_name == "SFTTrainer":
|
||||
original_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]'
|
||||
new_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask","labels"]'
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ import inspect
|
|||
from collections import defaultdict
|
||||
from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding
|
||||
from unsloth_zoo.utils import Version
|
||||
from trl import __version__ as trl_version_raw
|
||||
from importlib.metadata import version as importlib_version
|
||||
from unsloth_zoo.log import logger
|
||||
from unsloth_zoo.device_type import device_synchronize
|
||||
|
|
@ -57,6 +58,14 @@ torch_compile_options = {
|
|||
"triton.cudagraphs": False,
|
||||
}
|
||||
|
||||
# Resolve the installed TRL version once, at import time.
# Candidates are tried in order and the first one `Version` accepts wins:
#   1. the `__version__` string imported from `trl` (trl_version_raw)
#   2. the installed package metadata for "trl"
_trl_version_sources = (
    lambda: trl_version_raw,
    lambda: importlib_version("trl"),
)
for _read_trl_version in _trl_version_sources:
    try:
        trl_version = Version(_read_trl_version())
    except Exception:
        # Deliberate best-effort: any failure just moves on to the next source.
        continue
    break
else:
    # No source produced a parseable version; fall back to "0.0.0".
    trl_version = Version("0.0.0")
# Keep the module namespace free of the temporary helpers.
del _trl_version_sources, _read_trl_version
|
||||
|
||||
|
||||
# Check untrained tokens
|
||||
def sft_trainer_fix_untrained_tokens(call_args, extra_args):
|
||||
|
|
@ -434,99 +443,6 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
|
|||
_target_line + _metadata_extraction,
|
||||
)
|
||||
|
||||
# Unsloth: Skip prepare_multimodal_messages when prompts are pre-templated strings.
|
||||
# When notebooks pre-apply apply_chat_template(), prompts become strings with image tokens
|
||||
# already embedded. Calling prepare_multimodal_messages on strings crashes with TypeError.
|
||||
# Skipping it keeps prompts as strings so TRL uses the non-conversational path, which
|
||||
# ensures completions are strings and reward functions work correctly.
|
||||
string_to_find_vision = """ if images is not None:
|
||||
prompts = [
|
||||
prepare_multimodal_messages(prompt, image_list)
|
||||
for prompt, image_list in zip(prompts, images, strict=True)
|
||||
]"""
|
||||
|
||||
replacement_string_vision = """ if images is not None:
|
||||
# Unsloth: skip prepare_multimodal_messages for pre-templated string prompts
|
||||
if not prompts or not isinstance(prompts[0], str):
|
||||
prompts = [
|
||||
prepare_multimodal_messages(prompt, image_list)
|
||||
for prompt, image_list in zip(prompts, images, strict=True)
|
||||
]"""
|
||||
|
||||
function = function.replace(string_to_find_vision, replacement_string_vision)
|
||||
|
||||
# Unsloth: Skip apply_chat_template in the forward_kwargs block for pre-templated
|
||||
# string prompts. When prompts are already strings (from notebooks that pre-applied
|
||||
# apply_chat_template), calling it again crashes because strings aren't dicts.
|
||||
# We use prompts directly as prompts_text instead.
|
||||
|
||||
# TRL 0.26.2+ variant (has tools=self.tools)
|
||||
string_to_find_fwd = """ if images is not None:
|
||||
prompts_text = [
|
||||
apply_chat_template(
|
||||
{"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs
|
||||
)["prompt"]
|
||||
for prompt in prompts
|
||||
]"""
|
||||
|
||||
replacement_string_fwd = """ if images is not None:
|
||||
# Unsloth: skip apply_chat_template for pre-templated string prompts
|
||||
if prompts and isinstance(prompts[0], str):
|
||||
prompts_text = prompts
|
||||
else:
|
||||
prompts_text = [
|
||||
apply_chat_template(
|
||||
{"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs
|
||||
)["prompt"]
|
||||
for prompt in prompts
|
||||
]"""
|
||||
|
||||
function = function.replace(string_to_find_fwd, replacement_string_fwd)
|
||||
|
||||
# TRL 0.25.x variant (no tools parameter)
|
||||
string_to_find_fwd_old = """ if images is not None:
|
||||
prompts_text = [
|
||||
apply_chat_template(
|
||||
{"prompt": prompt}, self.processing_class, **self.chat_template_kwargs
|
||||
)["prompt"]
|
||||
for prompt in prompts
|
||||
]"""
|
||||
|
||||
replacement_string_fwd_old = """ if images is not None:
|
||||
# Unsloth: skip apply_chat_template for pre-templated string prompts
|
||||
if prompts and isinstance(prompts[0], str):
|
||||
prompts_text = prompts
|
||||
else:
|
||||
prompts_text = [
|
||||
apply_chat_template(
|
||||
{"prompt": prompt}, self.processing_class, **self.chat_template_kwargs
|
||||
)["prompt"]
|
||||
for prompt in prompts
|
||||
]"""
|
||||
|
||||
function = function.replace(string_to_find_fwd_old, replacement_string_fwd_old)
|
||||
|
||||
# TRL 0.25.1 single-line variant (no tools, single-line apply_chat_template call)
|
||||
string_to_find_fwd_single = """ if images is not None:
|
||||
prompts_text = [
|
||||
apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
|
||||
for prompt in prompts
|
||||
]"""
|
||||
|
||||
replacement_string_fwd_single = """ if images is not None:
|
||||
# Unsloth: skip apply_chat_template for pre-templated string prompts
|
||||
if prompts and isinstance(prompts[0], str):
|
||||
prompts_text = prompts
|
||||
else:
|
||||
prompts_text = [
|
||||
apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
|
||||
for prompt in prompts
|
||||
]"""
|
||||
|
||||
function = function.replace(
|
||||
string_to_find_fwd_single, replacement_string_fwd_single
|
||||
)
|
||||
|
||||
# This path is for TRL 0.24.0; `images` is a variable exclusive to this version.
|
||||
string_to_find = """ if images is not None:
|
||||
output["num_images"] = num_images"""
|
||||
|
|
@ -543,6 +459,17 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
|
|||
|
||||
function = function.replace(string_to_find, replacement_string)
|
||||
|
||||
if trl_version >= Version("0.25.0"):
|
||||
# When images are present, call _calculate_rewards with 'prompts_text'/'completions_text'; otherwise keep the original call using 'prompts'/'completions'.
|
||||
string_to_find = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)"
|
||||
replacement_string = (
|
||||
" if images is not None:\n"
|
||||
" rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)\n"
|
||||
" else:\n"
|
||||
" rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)"
|
||||
)
|
||||
function = function.replace(string_to_find, replacement_string)
|
||||
|
||||
if "wake_up()" not in function:
|
||||
# Sleep functionality has been added to trl in v0.23.0. We do not want to redo this.
|
||||
# https://github.com/huggingface/trl/commit/edbe8234bc7e528f72ac76607de9d3e4753e2709
|
||||
|
|
@ -1072,7 +999,7 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
|
||||
max_left_pad = inputs.get("max_left_pad", 0)
|
||||
if per_token_logps is not None:
|
||||
loss, completion_length, mean_kl, delta, flat_is_ratio = (
|
||||
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = (
|
||||
grpo_compute_loss_slow(
|
||||
ref_logps,
|
||||
per_token_logps,
|
||||
|
|
@ -1102,7 +1029,7 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
)
|
||||
else:
|
||||
if hasattr(self.args, "loss_type"):
|
||||
loss, completion_length, mean_kl, delta, flat_is_ratio = (
|
||||
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = (
|
||||
grpo_accumulated_loss(
|
||||
trainer = self,
|
||||
input_ids = _input_ids,
|
||||
|
|
@ -1134,7 +1061,7 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
)
|
||||
else:
|
||||
# to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
|
||||
loss, completion_length, mean_kl = grpo_accumulated_loss(
|
||||
loss, completion_length, mean_kl, coef_1 = grpo_accumulated_loss(
|
||||
trainer = self,
|
||||
input_ids = _input_ids,
|
||||
logits_to_keep = logits_to_keep,
|
||||
|
|
@ -1149,7 +1076,6 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
logit_scale_divide = logit_scale_divide,
|
||||
attention_mask = attention_mask,
|
||||
)
|
||||
|
||||
if "train" in self._metrics:
|
||||
mode = "eval" if self.control.should_evaluate else "train"
|
||||
self._metrics[mode]["completion_length"].append(completion_length.item())
|
||||
|
|
@ -1211,6 +1137,53 @@ def grpo_trainer_compute_loss(function_name, function):
|
|||
.item()
|
||||
)
|
||||
|
||||
completion_token_count = completion_mask.sum().clamp(min = 1.0)
|
||||
|
||||
def masked_batch_mean(x):
|
||||
if x.shape[1] == 1: # when importance_sampling_level == "sequence"
|
||||
return x.mean()
|
||||
else:
|
||||
return (x * completion_mask).sum() / completion_token_count
|
||||
|
||||
if advantages.dim() == 1:
|
||||
advantages = advantages.unsqueeze(1)
|
||||
|
||||
if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
|
||||
# Compute the clipped probability ratios
|
||||
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
|
||||
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
|
||||
is_region_clipped = is_low_clipped | is_high_clipped
|
||||
|
||||
low_clip = masked_batch_mean(is_low_clipped.float())
|
||||
high_clip = masked_batch_mean(is_high_clipped.float())
|
||||
clip_ratio = masked_batch_mean(is_region_clipped.float())
|
||||
|
||||
gathered_low_clip = self.accelerator.gather(low_clip)
|
||||
self._metrics[mode]["clip_ratio/low_mean"].append(
|
||||
gathered_low_clip.nanmean().item()
|
||||
)
|
||||
self._metrics[mode]["clip_ratio/low_min"].append(
|
||||
nanmin(gathered_low_clip).item()
|
||||
)
|
||||
gathered_high_clip = self.accelerator.gather(high_clip)
|
||||
self._metrics[mode]["clip_ratio/high_mean"].append(
|
||||
gathered_high_clip.nanmean().item()
|
||||
)
|
||||
self._metrics[mode]["clip_ratio/high_max"].append(
|
||||
nanmax(gathered_high_clip).item()
|
||||
)
|
||||
gathered_clip_ratio = self.accelerator.gather(clip_ratio)
|
||||
self._metrics[mode]["clip_ratio/region_mean"].append(
|
||||
gathered_clip_ratio.nanmean().item()
|
||||
)
|
||||
elif self.loss_type == "cispo":
|
||||
is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0)
|
||||
cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
|
||||
gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
|
||||
self._metrics[mode]["cispo_clip_ratio"].append(
|
||||
gathered_cispo_clip_ratio.nanmean().item()
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
function = inspect.getsource(compute_loss)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue