Mirror of https://github.com/unslothai/unsloth.git
Synced 2026-05-17 21:14:06 +00:00
* Fix DPO trainer multi process hang
* Fix datacollator error
* further dpo vision changes
* cleanup
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Harden DPO vision row processing and source rewrites
  - dpo_trainer_vision_signature_columns: also match TRL 0.22.x layout (image_sizes followed by ref_chosen_logps), so vision keys are not stripped via remove_unused_columns on the originally-affected version.
  - dpo_trainer_concatenated_inputs: fall back to inserting after the image_sizes block when no token_type_ids anchor follows it.
  - Apply the same vision model_kwargs forwarding rewrite to _compute_loss_liger via dpo_trainer_compute_loss_liger so the Liger DPO path does not drop pixel_position_ids/image_position_ids/mm_token_type_ids when args.use_liger_loss is true.
  - dpo_trainer_vision_process_row:
    - guard chosen/rejected EOS append with tokenizer.eos_token_id is not None
    - use features.get("images") and features.get("prompt") to match the existing get on line 164 and avoid KeyError on rows without those keys
    - drop the torch.is_tensor gate so list-form pixel_position_ids/image_position_ids returned without return_tensors are still aliased
    - skip the loop entry for image_position_ids when it was already promoted to pixel_position_ids, so the output dict no longer carries both keys with identical data
  - dpo_trainer_data_collator_vision_keys: switch from pad_sequence to trl.trainer.utils.pad with padding_side='left' (matches the DPO collator's prompt left-pad) and padding_value=-1 for *_position_ids keys (sentinel for padded patches), 0 otherwise. Skip the key when not every example carries it. Falls back to pad_sequence if trl.pad is unavailable or the tensor rank is too high.
  - dpo_trainer_prepare_dataset: keep TRL's writer_batch_size=10 when popping num_proc; removing it defaults to 1000 and reintroduces the vision OOM risk that writer_batch_size=10 was set to avoid.
* DPO vision row: keep upstream-facing keys and fix patch padding
  - dpo_trainer_vision_process_row: no longer aliases image_position_ids to pixel_position_ids. Each upstream-emitted vision key is forwarded under its own name. Gemma4 ForConditionalGeneration.forward accepts image_position_ids directly and renames it to pixel_position_ids only at the vision-tower call site, so aliasing in the row helper hid the kwarg the model actually consumes.
  - dpo_trainer_vision_process_row: extract pixel_values via "in" membership instead of unconditional indexing. With the missing-images path returning [] to the processor, modern processors no longer emit a pixel_values key, and the previous indexing raised KeyError.
  - dpo_trainer_data_collator_vision_keys: pick padding_side per key family. *_position_ids tensors are patch-aligned to pixel_values (TRL's DataCollatorForPreference right-pads pixel_values), so pad them right with the -1 sentinel; mm_token_type_ids is token-aligned to prompt_input_ids (left-padded by TRL), so pad it left with 0.
* DPO vision: handle multi-image prompts and arbitrary-rank collator pad
  - dpo_trainer_vision_process_row: when a prompt is missing vision placeholders, insert one placeholder per missing image instead of always inserting a single token. Multi-image rows now satisfy the processor's token-vs-image count check rather than under-inserting and tripping the placeholder/feature mismatch.
  - dpo_trainer_data_collator_vision_keys: drop the dim()<=2 gate around trl.trainer.utils.pad. trl.pad handles arbitrary rank correctly, while the previous fallback to torch.nn.utils.rnn.pad_sequence raised RuntimeError on rank-3 patch-position tensors with mismatched non-leading dimensions. The pad_sequence path remains as a degraded fallback only when trl.pad is unavailable or raises.
* DPO vision row: support scalar images and align prompt-aligned aux ids
  - dpo_trainer_vision_process_row: type-aware normalization of the features['images'] column instead of a truthiness/len check that raised on single image objects (PIL.Image has no __len__) and on numpy ndarrays (truthiness ambiguous). Lists/tuples count as their length, scalar image objects count as one, None counts as zero, and the original value is forwarded to the processor.
  - dpo_trainer_vision_process_row: when max_prompt_length truncates prompt_input_ids, also slice token_type_ids and mm_token_type_ids by the same [-max_prompt_length:] suffix. Those keys are 1:1 token aligned to prompt_input_ids (Gemma 4 vision attention keys off mm_token_type_ids per modular_gemma4.py), so leaving them at the original length silently misaligned the multimodal mask.
* DPO vision row: stop synthesizing vision-token placeholders
  Pass features['prompt'] and features['images'] straight to the processor without inserting any extra placeholder tokens. The previous helper used processing_class.image_token, which is the right prompt placeholder for Gemma 4 but the wrong one for Gemma 3 (whose prompt placeholder is boi_token while image_token is the inner expansion target). Synthesizing that token also broke multi-image rows: text ended up with N placeholders while the row helper only forwarded the first image's pixel_values via the standard [0] indexing that mirrors upstream TRL process_row, so token vs image-feature counts diverged. Removing the synthesis matches stock TRL behavior; users provide the correct placeholders for their processor in the prompt.
* Add tests for DPO vision row processor passthrough
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
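The per-key padding rule and the suffix truncation described in the commit notes above can be sketched without any framework dependency. This is only an illustration on plain Python lists, not the code in `rl_replacements.py`: `pad_1d`, `pad_vision_key`, and `truncate_prompt_aligned` are hypothetical names, and the real collator operates on tensors through `trl.trainer.utils.pad`.

```python
def pad_1d(rows, padding_value, padding_side):
    """Pad variable-length lists to a common length (hypothetical helper)."""
    width = max(len(r) for r in rows)
    out = []
    for r in rows:
        fill = [padding_value] * (width - len(r))
        out.append(fill + list(r) if padding_side == "left" else list(r) + fill)
    return out


def pad_vision_key(key, rows):
    """Choose padding side and value by key family, as the commit notes describe."""
    if key.endswith("_position_ids"):
        # Patch-aligned to pixel_values, which are right-padded; -1 marks padding.
        return pad_1d(rows, padding_value = -1, padding_side = "right")
    # Token-aligned keys such as mm_token_type_ids follow the left-padded prompt.
    return pad_1d(rows, padding_value = 0, padding_side = "left")


def truncate_prompt_aligned(row, max_prompt_length):
    """Keep the same suffix of every token-aligned key (hypothetical helper)."""
    aligned = ("prompt_input_ids", "token_type_ids", "mm_token_type_ids")
    return {
        k: (v[-max_prompt_length:] if k in aligned else v)
        for k, v in row.items()
    }


print(pad_vision_key("image_position_ids", [[0, 1, 2], [0]]))
# [[0, 1, 2], [0, -1, -1]]
print(pad_vision_key("mm_token_type_ids", [[1, 1, 0], [1]]))
# [[1, 1, 0], [0, 0, 1]]
print(truncate_prompt_aligned(
    {"prompt_input_ids": [1, 2, 3, 4], "mm_token_type_ids": [0, 1, 1, 0]}, 2
))
# {'prompt_input_ids': [3, 4], 'mm_token_type_ids': [1, 0]}
```

Slicing every token-aligned key with the same suffix keeps each mask entry attached to the token it describes, which is the misalignment the truncation fix addresses.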
149 lines
3.9 KiB
Python
"""Verify dpo_trainer_vision_process_row forwards prompt and images verbatim."""

import ast
import os

import numpy as np


REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
RL_PATH = os.path.join(REPO_ROOT, "unsloth", "models", "rl_replacements.py")


def _load_helpers():
    """Exec only the DPO helpers from rl_replacements.py, without importing unsloth."""
    with open(RL_PATH, encoding = "utf-8") as f:
        src = f.read()
    tree = ast.parse(src)
    import torch as _torch

    # Seed the namespace with torch for the extracted helpers.
    ns = {"torch": _torch}
    # First pass: the module-level _DPO_VISION_KEYS constant.
    for node in tree.body:
        if isinstance(node, ast.Assign) and any(
            isinstance(t, ast.Name) and t.id == "_DPO_VISION_KEYS" for t in node.targets
        ):
            exec(ast.get_source_segment(src, node), ns)
    # Second pass: every top-level dpo_trainer_* / _dpo_trainer_* function.
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name.startswith(
            ("dpo_trainer_", "_dpo_trainer_")
        ):
            exec(ast.get_source_segment(src, node), ns)
    return ns


class _Tok:
    """Minimal tokenizer stub: fixed EOS id, no BOS, constant input_ids."""

    eos_token_id = 99
    bos_token_id = None

    def __call__(self, t, add_special_tokens = False):
        return {"input_ids": [10]}


class _Capture:
    """Processor stub that records the text and images it is called with."""

    image_token = "<img>"
    boi_token = "<boi>"

    def __init__(self):
        self.tokenizer = _Tok()
        self.last_text = None
        # Sentinel distinguishes "never called" from "called with images=None".
        self.last_images = "__sentinel__"

    def __call__(self, images = None, text = None, add_special_tokens = False):
        self.last_text = text
        self.last_images = images
        out = {"input_ids": [[1, 2]]}
        if images is not None:
            out["pixel_values"] = [object()]
        return out


def test_prompt_passes_through_without_image_token_synthesis():
    ns = _load_helpers()
    proc = _Capture()
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "describe", "chosen": "c", "rejected": "r", "images": ["i"]},
        proc,
    )
    assert proc.last_text == "describe"


def test_prompt_with_existing_image_token_unchanged():
    ns = _load_helpers()
    proc = _Capture()
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "<img> describe", "chosen": "c", "rejected": "r", "images": ["i"]},
        proc,
    )
    assert proc.last_text == "<img> describe"


def test_gemma3_style_boi_token_prompt_not_corrupted():
    ns = _load_helpers()
    proc = _Capture()
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "<boi> describe", "chosen": "c", "rejected": "r", "images": ["i"]},
        proc,
    )
    assert proc.last_text == "<boi> describe"
    assert "<img>" not in proc.last_text


def test_multi_image_prompt_unchanged_no_extra_placeholders():
    ns = _load_helpers()
    proc = _Capture()
    ns["dpo_trainer_vision_process_row"](
        {
            "prompt": "compare",
            "chosen": "c",
            "rejected": "r",
            "images": ["a", "b", "c"],
        },
        proc,
    )
    assert proc.last_text == "compare"


def test_list_images_forwarded_verbatim():
    ns = _load_helpers()
    proc = _Capture()
    payload = ["a", "b"]
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "p", "chosen": "c", "rejected": "r", "images": payload},
        proc,
    )
    assert proc.last_images is payload


def test_single_pil_like_image_forwarded_verbatim():
    ns = _load_helpers()

    class PIL:
        def __bool__(self):
            return True

    proc = _Capture()
    pil = PIL()
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "p", "chosen": "c", "rejected": "r", "images": pil},
        proc,
    )
    assert proc.last_images is pil


def test_numpy_ndarray_image_forwarded_verbatim():
    ns = _load_helpers()
    proc = _Capture()
    arr = np.zeros((2, 3, 3), dtype = np.uint8)
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "p", "chosen": "c", "rejected": "r", "images": arr},
        proc,
    )
    assert proc.last_images is arr


def test_missing_images_key_passes_none_to_processor():
    ns = _load_helpers()
    proc = _Capture()
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "p", "chosen": "c", "rejected": "r"},
        proc,
    )
    assert proc.last_images is None
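The loading strategy used by `_load_helpers` above, parsing a module's source and executing only selected top-level nodes, can be tried in isolation. The sketch below is a stand-alone illustration under assumed names (`SOURCE`, `load_selected`, and `helper_greet` are hypothetical and do not exist in the repository); it shows that an expensive or unavailable import in the source is never executed.

```python
import ast

# Hypothetical module source; the import line is parsed but never executed.
SOURCE = '''
import some_heavy_dependency

_PREFIX = "hello"

def helper_greet(name):
    return f"{_PREFIX}, {name}"

def unrelated():
    raise RuntimeError("not loaded")
'''


def load_selected(source, constant_names, function_prefix):
    """Exec only chosen assignments and functions from source text."""
    tree = ast.parse(source)
    ns = {}
    # Constants are loaded first, mirroring the two-pass order of the test helper.
    for node in tree.body:
        if isinstance(node, ast.Assign) and any(
            isinstance(t, ast.Name) and t.id in constant_names for t in node.targets
        ):
            exec(ast.get_source_segment(source, node), ns)
    # Then every top-level function whose name starts with the given prefix.
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name.startswith(function_prefix):
            exec(ast.get_source_segment(source, node), ns)
    return ns


ns = load_selected(SOURCE, {"_PREFIX"}, "helper_")
print(ns["helper_greet"]("world"))  # hello, world
print("unrelated" in ns)            # False
```

Because only the selected nodes are executed, the tests can exercise a single helper without paying for, or depending on, the rest of the package's import chain.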