unsloth/tests/python/test_dpo_vision_processor_passthrough.py
Datta Nimmaturi 4f9c8321a2
Fix DPO trainer multi process hang (#5199)
* Fix DPO trainer multi process hang

* Fix datacollator error

* further dpo vision changes

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Harden DPO vision row processing and source rewrites

- dpo_trainer_vision_signature_columns: also match TRL 0.22.x layout
  (image_sizes followed by ref_chosen_logps), so vision keys are not
  stripped via remove_unused_columns on the originally-affected version.
- dpo_trainer_concatenated_inputs: fall back to inserting after the
  image_sizes block when no token_type_ids anchor follows it.
- Apply the same vision model_kwargs forwarding rewrite to
  _compute_loss_liger via dpo_trainer_compute_loss_liger so the Liger DPO
  path does not drop pixel_position_ids/image_position_ids/
  mm_token_type_ids when args.use_liger_loss is true.
- dpo_trainer_vision_process_row:
  - guard the chosen/rejected EOS append with a
    tokenizer.eos_token_id is not None check
  - use features.get("images") and features.get("prompt") to match the
    existing .get() call on line 164 and avoid KeyError on rows without
    those keys
  - drop the torch.is_tensor gate so list-form pixel_position_ids/
    image_position_ids returned without return_tensors are still aliased
  - skip the loop entry for image_position_ids when it was already
    promoted to pixel_position_ids, so the output dict no longer carries
    both keys with identical data
- dpo_trainer_data_collator_vision_keys: switch from pad_sequence to
  trl.trainer.utils.pad with padding_side='left' (matches the DPO
  collator's prompt left-pad) and padding_value=-1 for *_position_ids
  keys (sentinel for padded patches), 0 otherwise. Skip the key when not
  every example carries it. Falls back to pad_sequence if trl.pad is
  unavailable or the tensor rank is too high.
- dpo_trainer_prepare_dataset: keep TRL's writer_batch_size=10 when
  popping num_proc; removing it defaults to 1000 and reintroduces the
  vision OOM risk that writer_batch_size=10 was set to avoid.
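A minimal sketch of the two row-level guards above (the EOS append guard and the .get() reads), in plain Python. `tokenize_completion`, `read_optional_fields` and `fake_tokenize` are hypothetical names for illustration, not functions in rl_replacements.py. The tokenizer is assumed to be a callable returning a dict with "input_ids", like the `_Tok` test double in this file.

```python
from typing import Callable, Dict, List, Optional, Tuple


def tokenize_completion(
    tokenize: Callable[[str], Dict[str, List[int]]],
    text: str,
    eos_token_id: Optional[int],
) -> List[int]:
    """Tokenize a chosen/rejected completion; append EOS only when it is defined."""
    ids = list(tokenize(text)["input_ids"])
    # Without this guard, a tokenizer whose eos_token_id is None would put
    # None into the id list.
    if eos_token_id is not None:
        ids.append(eos_token_id)
    return ids


def read_optional_fields(features: Dict[str, object]) -> Tuple[object, object]:
    """Read prompt/images with .get() so rows lacking them give None, not KeyError."""
    return features.get("prompt"), features.get("images")


def fake_tokenize(text: str) -> Dict[str, List[int]]:
    # Hypothetical stand-in tokenizer: every text maps to the same two ids.
    return {"input_ids": [10, 11]}


print(tokenize_completion(fake_tokenize, "c", 99))    # [10, 11, 99]
print(tokenize_completion(fake_tokenize, "c", None))  # [10, 11]
print(read_optional_fields({"chosen": "c", "rejected": "r"}))  # (None, None)
```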

* DPO vision row: keep upstream-facing keys and fix patch padding

- dpo_trainer_vision_process_row: no longer aliases image_position_ids
  to pixel_position_ids. Each upstream-emitted vision key is forwarded
  under its own name. Gemma4 ForConditionalGeneration.forward accepts
  image_position_ids directly and renames it to pixel_position_ids only
  at the vision-tower call site, so aliasing in the row helper hid the
  kwarg the model actually consumes.
- dpo_trainer_vision_process_row: extract pixel_values via "in"
  membership instead of unconditional indexing. With the missing-images
  path returning [] to the processor, modern processors no longer emit
  a pixel_values key, and the previous indexing raised KeyError.
- dpo_trainer_data_collator_vision_keys: pick padding_side per key
  family. *_position_ids tensors are patch-aligned to pixel_values
  (TRL's DataCollatorForPreference right-pads pixel_values), so pad
  them right with the -1 sentinel; mm_token_type_ids is token-aligned
  to prompt_input_ids (left-padded by TRL), so pad it left with 0.
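The per-key-family padding rule can be sketched on plain 1-D lists. This is a simplified illustration, not the collator itself: real *_position_ids tensors can have more than one dimension, the actual code calls trl.trainer.utils.pad, and `pad_spec`, `pad_1d` and `collate_vision_keys` are hypothetical names. It also folds in the earlier rule of skipping a key that not every example carries.

```python
from typing import Dict, List, Tuple


def pad_spec(key: str) -> Tuple[str, int]:
    """Return (padding_side, padding_value) for one vision key."""
    if key.endswith("_position_ids"):
        # Patch-aligned with pixel_values, which TRL right-pads:
        # pad right with -1 so padded patches are recognizable.
        return "right", -1
    # Token-aligned keys such as mm_token_type_ids follow the left-padded
    # prompt_input_ids: pad left with 0.
    return "left", 0


def pad_1d(rows: List[List[int]], side: str, value: int) -> List[List[int]]:
    """Pad every row to the longest row's length on the given side."""
    width = max(len(r) for r in rows)
    padded = []
    for r in rows:
        fill = [value] * (width - len(r))
        padded.append(fill + r if side == "left" else r + fill)
    return padded


def collate_vision_keys(
    examples: List[Dict[str, List[int]]], keys: List[str]
) -> Dict[str, List[List[int]]]:
    batch: Dict[str, List[List[int]]] = {}
    for key in keys:
        # A key is only batched when every example carries it.
        if not all(key in ex for ex in examples):
            continue
        side, value = pad_spec(key)
        batch[key] = pad_1d([ex[key] for ex in examples], side, value)
    return batch


examples = [
    {"mm_token_type_ids": [0, 1, 1], "pixel_position_ids": [5, 6]},
    {"mm_token_type_ids": [1], "pixel_position_ids": [7, 8, 9], "image_position_ids": [0]},
]
batch = collate_vision_keys(
    examples, ["mm_token_type_ids", "pixel_position_ids", "image_position_ids"]
)
print(batch["mm_token_type_ids"])     # [[0, 1, 1], [0, 0, 1]]
print(batch["pixel_position_ids"])    # [[5, 6, -1], [7, 8, 9]]
print("image_position_ids" in batch)  # False
```

Padding the two families on opposite sides keeps each key aligned with the tensor it indexes: mm_token_type_ids lines up with left-padded prompt tokens, while position ids line up with right-padded patches.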

* DPO vision: handle multi-image prompts and arbitrary-rank collator pad

- dpo_trainer_vision_process_row: when a prompt is missing vision
  placeholders, insert one placeholder per missing image instead of
  always inserting a single token. Multi-image rows now satisfy the
  processor's token-vs-image count check rather than under-inserting
  and tripping the placeholder/feature mismatch.
- dpo_trainer_data_collator_vision_keys: drop the dim()<=2 gate around
  trl.trainer.utils.pad. trl.pad handles arbitrary rank correctly,
  while the previous fallback to torch.nn.utils.rnn.pad_sequence
  raised RuntimeError on rank-3 patch-position tensors with mismatched
  non-leading dimensions. The pad_sequence path remains as a degraded
  fallback only when trl.pad is unavailable or raises.
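torch.nn.utils.rnn.pad_sequence pads only the leading dimension and expects all trailing dimensions to match, which is why rank-3 patch-position tensors with different non-leading sizes fail there. The sketch below is a NumPy re-implementation of the approach trl.trainer.utils.pad takes, as we understand it: every dimension is padded to its per-batch maximum, and padding_side only affects the leading dimension. `pad_any_rank` is a hypothetical name; this is not TRL's code, and it assumes all inputs share the same rank and dtype.

```python
from typing import List

import numpy as np


def pad_any_rank(
    arrays: List[np.ndarray], padding_value: int = 0, padding_side: str = "right"
) -> np.ndarray:
    """Pad same-rank arrays to a common shape and stack them on a new batch axis."""
    # Per-dimension maximum across the batch.
    target = np.max([a.shape for a in arrays], axis=0)
    out = np.full((len(arrays), *target), padding_value, dtype=arrays[0].dtype)
    for i, a in enumerate(arrays):
        n = a.shape[0]
        # padding_side decides where the leading dimension's padding goes.
        if padding_side == "left":
            lead = slice(target[0] - n, target[0])
        else:
            lead = slice(0, n)
        # Trailing dimensions are always filled from index 0 (padded at the end).
        rest = tuple(slice(0, s) for s in a.shape[1:])
        out[(i, lead) + rest] = a
    return out


a = np.ones((2, 3, 2), dtype=np.int64)     # 2 patches, inner dim 3
b = np.full((4, 1, 2), 7, dtype=np.int64)  # 4 patches, inner dim 1
print(pad_any_rank([a, b], padding_value=-1).shape)  # (2, 4, 3, 2)
```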

* DPO vision row: support scalar images and align prompt-aligned aux ids

- dpo_trainer_vision_process_row: type-aware normalization of the
  features['images'] column instead of a truthiness/len check that
  raised on single image objects (PIL.Image has no __len__) and on
  numpy ndarrays (truthiness ambiguous). Lists/tuples count as their
  length, scalar image objects count as one, None counts as zero, and
  the original value is forwarded to the processor.
- dpo_trainer_vision_process_row: when max_prompt_length truncates
  prompt_input_ids, also slice token_type_ids and mm_token_type_ids
  by the same [-max_prompt_length:] suffix. Those keys are 1:1 token
  aligned to prompt_input_ids (Gemma 4 vision attention keys off
  mm_token_type_ids per modular_gemma4.py), so leaving them at the
  original length silently misaligned the multimodal mask.
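Both changes above can be sketched in a few lines. `count_images`, `truncate_prompt` and `ALIGNED_KEYS` are hypothetical names; the key list is an assumption taken from the bullet above (prompt_input_ids plus the 1:1 aligned token_type_ids and mm_token_type_ids), not copied from rl_replacements.py.

```python
from typing import Any, Dict, List

ALIGNED_KEYS = ("prompt_input_ids", "token_type_ids", "mm_token_type_ids")


def count_images(images: Any) -> int:
    """Count images by type instead of truthiness or len()."""
    if images is None:
        return 0
    if isinstance(images, (list, tuple)):
        return len(images)
    # bool() of a multi-element numpy array raises and PIL images define no
    # __len__, so neither is consulted: any other object counts as one image.
    return 1


def truncate_prompt(row: Dict[str, List[int]], max_prompt_length: int) -> Dict[str, List[int]]:
    """Keep the same trailing max_prompt_length entries of every token-aligned key."""
    out = dict(row)
    for key in ALIGNED_KEYS:
        if key in out:
            out[key] = out[key][-max_prompt_length:]
    return out


print(count_images(None), count_images(["a", "b"]), count_images(object()))  # 0 2 1
row = {"prompt_input_ids": [1, 2, 3, 4], "mm_token_type_ids": [0, 1, 1, 0]}
print(truncate_prompt(row, 2))  # {'prompt_input_ids': [3, 4], 'mm_token_type_ids': [1, 0]}
```

Slicing every aligned key with the same suffix keeps position i of mm_token_type_ids describing the same token as position i of prompt_input_ids after truncation.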

* DPO vision row: stop synthesizing vision-token placeholders

Pass features['prompt'] and features['images'] straight to the
processor without inserting any extra placeholder tokens. The previous
helper used processing_class.image_token, which is the right prompt
placeholder for Gemma 4 but the wrong one for Gemma 3 (whose prompt
placeholder is boi_token while image_token is the inner expansion
target). Synthesizing that token also broke multi-image rows: text
ended up with N placeholders while the row helper only forwarded the
first image's pixel_values via the standard [0] indexing that mirrors
upstream TRL process_row, so token vs image-feature counts diverged.
Removing the synthesis matches stock TRL behavior; users provide the
correct placeholders for their processor in the prompt.
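Because the row helper now forwards prompt and images untouched, the dataset itself has to carry the right placeholders. A hypothetical pre-flight check such as the one below could catch mismatches before training. `processor_inputs` and `placeholder_gap` are illustrative names, and the "<boi>" and "<img>" strings mirror the test doubles in this file rather than any real processor's token strings.

```python
from typing import Any, Dict, Optional, Tuple


def processor_inputs(features: Dict[str, Any]) -> Tuple[Optional[str], Any]:
    """Hand prompt and images to the processor exactly as stored in the row."""
    return features.get("prompt"), features.get("images")


def placeholder_gap(prompt: str, n_images: int, placeholder: str) -> int:
    """Images minus placeholders found in the prompt (0 means they agree).

    `placeholder` should be the token the processor expects in the prompt,
    e.g. boi_token for a Gemma 3-style processor rather than its image_token.
    """
    return n_images - prompt.count(placeholder)


row = {"prompt": "<boi> compare <boi>", "images": ["a", "b"]}
prompt, images = processor_inputs(row)
print(prompt is row["prompt"], images is row["images"])  # True True
print(placeholder_gap(prompt, len(images), "<boi>"))     # 0
print(placeholder_gap(prompt, len(images), "<img>"))     # 2
```

The last line shows the old failure mode in miniature: keyed on image_token, a prompt that is already correct for a boi_token processor looks two placeholders short, which is what led the removed synthesis to insert the wrong tokens.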

* Add tests for DPO vision row processor passthrough

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lee Jackson <130007945+Imagineer99@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-29 04:15:34 -07:00

"""Verify dpo_trainer_vision_process_row forwards prompt and images verbatim."""

import ast
import os

import numpy as np

REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
RL_PATH = os.path.join(REPO_ROOT, "unsloth", "models", "rl_replacements.py")


def _load_helpers():
    # Parse rl_replacements.py and exec only the pieces under test, so the
    # helpers can be exercised without importing the full unsloth package.
    with open(RL_PATH) as f:
        src = f.read()
    tree = ast.parse(src)
    import torch as _torch

    ns = {"torch": _torch}
    # Module-level constant the helpers reference.
    for node in tree.body:
        if isinstance(node, ast.Assign) and any(
            isinstance(t, ast.Name) and t.id == "_DPO_VISION_KEYS" for t in node.targets
        ):
            exec(ast.get_source_segment(src, node), ns)
    # Every top-level function named dpo_trainer_* or _dpo_trainer_*.
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name.startswith(
            ("dpo_trainer_", "_dpo_trainer_")
        ):
            exec(ast.get_source_segment(src, node), ns)
    return ns


class _Tok:
    eos_token_id = 99
    bos_token_id = None

    def __call__(self, t, add_special_tokens = False):
        return {"input_ids": [10]}


class _Capture:
    image_token = "<img>"
    boi_token = "<boi>"

    def __init__(self):
        self.tokenizer = _Tok()
        self.last_text = None
        self.last_images = "__sentinel__"

    def __call__(self, images = None, text = None, add_special_tokens = False):
        self.last_text = text
        self.last_images = images
        out = {"input_ids": [[1, 2]]}
        if images is not None:
            out["pixel_values"] = [object()]
        return out


def test_prompt_passes_through_without_image_token_synthesis():
    ns = _load_helpers()
    proc = _Capture()
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "describe", "chosen": "c", "rejected": "r", "images": ["i"]},
        proc,
    )
    assert proc.last_text == "describe"


def test_prompt_with_existing_image_token_unchanged():
    ns = _load_helpers()
    proc = _Capture()
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "<img> describe", "chosen": "c", "rejected": "r", "images": ["i"]},
        proc,
    )
    assert proc.last_text == "<img> describe"


def test_gemma3_style_boi_token_prompt_not_corrupted():
    ns = _load_helpers()
    proc = _Capture()
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "<boi> describe", "chosen": "c", "rejected": "r", "images": ["i"]},
        proc,
    )
    assert proc.last_text == "<boi> describe"
    assert "<img>" not in proc.last_text


def test_multi_image_prompt_unchanged_no_extra_placeholders():
    ns = _load_helpers()
    proc = _Capture()
    ns["dpo_trainer_vision_process_row"](
        {
            "prompt": "compare",
            "chosen": "c",
            "rejected": "r",
            "images": ["a", "b", "c"],
        },
        proc,
    )
    assert proc.last_text == "compare"


def test_list_images_forwarded_verbatim():
    ns = _load_helpers()
    proc = _Capture()
    payload = ["a", "b"]
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "p", "chosen": "c", "rejected": "r", "images": payload},
        proc,
    )
    assert proc.last_images is payload


def test_single_pil_like_image_forwarded_verbatim():
    ns = _load_helpers()

    class PIL:
        def __bool__(self):
            return True

    proc = _Capture()
    pil = PIL()
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "p", "chosen": "c", "rejected": "r", "images": pil},
        proc,
    )
    assert proc.last_images is pil


def test_numpy_ndarray_image_forwarded_verbatim():
    ns = _load_helpers()
    proc = _Capture()
    arr = np.zeros((2, 3, 3), dtype = np.uint8)
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "p", "chosen": "c", "rejected": "r", "images": arr},
        proc,
    )
    assert proc.last_images is arr


def test_missing_images_key_passes_none_to_processor():
    ns = _load_helpers()
    proc = _Capture()
    ns["dpo_trainer_vision_process_row"](
        {"prompt": "p", "chosen": "c", "rejected": "r"},
        proc,
    )
    assert proc.last_images is None
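`_load_helpers` relies on a general trick: parse a module with `ast`, then `exec` only selected top-level definitions into a namespace, so the module's own imports never run. A self-contained sketch of the same idea, using hypothetical names (`load_functions`, `helper_pick`, `SOURCE`) on an in-memory source string:

```python
import ast
from typing import Any, Dict, Iterable, Optional


def load_functions(
    src: str,
    prefixes: Iterable[str],
    constants: Iterable[str] = (),
    globals_: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Exec selected top-level assignments and functions from src into one namespace."""
    tree = ast.parse(src)
    ns: Dict[str, Any] = dict(globals_ or {})
    prefix_tuple = tuple(prefixes)
    wanted = set(constants)
    for node in tree.body:
        # Module-level constants the functions read as globals.
        if isinstance(node, ast.Assign) and any(
            isinstance(t, ast.Name) and t.id in wanted for t in node.targets
        ):
            exec(ast.get_source_segment(src, node), ns)
    for node in tree.body:
        # Plain top-level `def`s whose names start with one of the prefixes.
        if isinstance(node, ast.FunctionDef) and node.name.startswith(prefix_tuple):
            exec(ast.get_source_segment(src, node), ns)
    return ns


SOURCE = '''
import some_heavy_dependency  # parsed, never executed

_KEYS = ("a", "b")


def helper_pick(d):
    return {k: d[k] for k in _KEYS if k in d}


def unrelated():
    raise RuntimeError("not loaded")
'''

ns = load_functions(SOURCE, prefixes=("helper_",), constants=("_KEYS",))
print(ns["helper_pick"]({"a": 1, "c": 3}))  # {'a': 1}
print("unrelated" in ns)                    # False
```

Only plain `def` nodes match, so `async def` functions and classes are skipped, and any module the extracted code uses must be supplied up front through the namespace, which is what `ns = {"torch": _torch}` does in the test.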