From f9682e656c2f0cb9c6c301e6e523d6da8613ed01 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Wed, 22 Apr 2026 21:34:08 +0530 Subject: [PATCH] update gema4 chat templates (#5116) * update gema4 chat templates * udpate template * update template for gemma4 * Add gemma4 chat template tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Daniel Han Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/test_gemma4_chat_template.py | 183 +++++++++++++++++++++++++++++ unsloth/chat_templates.py | 95 +++++++++++---- 2 files changed, 255 insertions(+), 23 deletions(-) create mode 100644 tests/test_gemma4_chat_template.py diff --git a/tests/test_gemma4_chat_template.py b/tests/test_gemma4_chat_template.py new file mode 100644 index 000000000..6211b4d43 --- /dev/null +++ b/tests/test_gemma4_chat_template.py @@ -0,0 +1,183 @@ +import os +import re + +import pytest +from jinja2 import Environment, StrictUndefined +from jinja2.exceptions import TemplateError + + +CHAT_TEMPLATES_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "unsloth", + "chat_templates.py", +) + + +def _extract_template(name): + src = open(CHAT_TEMPLATES_PATH).read() + pattern = rf'{re.escape(name)}\s*=\s*\\\n"""(.*?)"""' + m = re.search(pattern, src, flags = re.DOTALL) + assert m, f"Could not extract {name} from chat_templates.py" + return m.group(1) + + +def _env(): + env = Environment(undefined = StrictUndefined, trim_blocks = False, lstrip_blocks = False) + env.globals["raise_exception"] = lambda msg: (_ for _ in ()).throw( + TemplateError(msg) + ) + return env + + +def _render(template_name, messages, **kwargs): + src = _extract_template(template_name) + tmpl = _env().from_string(src) + ctx = {"messages": messages, "add_generation_prompt": False} + ctx.update(kwargs) + return tmpl.render(**ctx) + + +# ---------- system turn and <|think|> placement ---------- + + +def test_system_message_emits_dedicated_system_turn(): + msgs = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hi"}, + ] + out = _render("gemma4_template", msgs) + assert "<|turn>system\nYou are helpful" in out + assert "<|turn>user\nHi" in out + assert "You are helpful\n\nHi" not in out + + +def test_developer_role_treated_as_system(): + msgs = [ + {"role": "developer", "content": "Internal instructions"}, + {"role": "user", "content": "Hi"}, + ] + out = _render("gemma4_template", msgs) + assert "<|turn>system\nInternal instructions" in out + + +def test_no_system_no_thinking_unchanged(): + msgs = [{"role": "user", "content": "Hi"}] + out = _render("gemma4_template", msgs) + assert "<|turn>user\nHi" in out + assert "<|turn>system" not in out + + +def test_assistant_role_renders_as_model_turn(): + msgs = [{"role": "user", "content": "Q"}, {"role": "assistant", "content": "A"}] + out = _render("gemma4_template", msgs) + assert "<|turn>model\nA" in out + assert "<|turn>assistant" not in out + + +def test_thinking_template_defaults_to_thinking_off_when_unset(): + msgs = [{"role": "user", "content": "Hi"}] + out = _render("gemma4_thinking_template", msgs) + assert "<|think|>" not in out + assert "<|turn>system" not in out + + +def test_thinking_template_emits_think_with_newline_when_enabled(): + msgs = [{"role": "system", "content": "Sys"}, {"role": "user", "content": "Hi"}] + out = _render("gemma4_thinking_template", msgs, enable_thinking = True) + assert "<|turn>system\n<|think|>\nSys" in out + + +def test_alternation_violation_raises_template_error(): + msgs = [{"role": "user", "content": "A"}, {"role": "user", "content": "B"}] + with pytest.raises(TemplateError): + _render("gemma4_template", msgs) + + +# ---------- strip_thinking macro semantics ---------- + + +def test_strip_thinking_strips_matched_pair(): + msgs = [ + {"role": "user", "content": "Q"}, + { + "role": "assistant", + "content": "<|channel>thought\n2+2=4The answer is 4.", + }, + ] + out = _render("gemma4_template", msgs) + assert "thought" not in out + assert "2+2=4" not in out + assert "The answer is 4." in out + + +def test_strip_thinking_applied_unconditionally_to_model_turn(): + msgs = [ + {"role": "user", "content": "Q"}, + {"role": "assistant", "content": "<|channel>reasoningfinal"}, + ] + for agp in (True, False): + out = _render("gemma4_template", msgs, add_generation_prompt = agp) + assert "reasoning" not in out + assert "final" in out + + +def test_strip_thinking_applies_to_iterable_text(): + msgs = [ + {"role": "user", "content": [{"type": "text", "text": "Q"}]}, + { + "role": "assistant", + "content": [{"type": "text", "text": "<|channel>rfinal"}], + }, + ] + out = _render("gemma4_thinking_template", msgs) + assert "final" in out + assert "<|channel>" not in out + + +def test_strip_thinking_preserves_plain_text(): + msgs = [ + {"role": "user", "content": "Q"}, + {"role": "assistant", "content": "plain answer with no markup"}, + ] + out = _render("gemma4_template", msgs, add_generation_prompt = True) + assert "plain answer with no markup" in out + + +def test_multi_turn_strips_all_historical_model_turns(): + msgs = [ + {"role": "user", "content": "Q1"}, + {"role": "assistant", "content": "<|channel>r1A1"}, + {"role": "user", "content": "Q2"}, + {"role": "assistant", "content": "<|channel>r2A2"}, + ] + out = _render("gemma4_thinking_template", msgs, add_generation_prompt = True) + assert "r1" not in out and "r2" not in out + assert "A1" in out and "A2" in out + + +# ---------- thinking-template gen-prompt injection ---------- + + +def test_thinking_template_injects_empty_thought_channel_by_default(): + # Author defaults enable_thinking=False, so the gen-prompt injection fires. + msgs = [{"role": "user", "content": "Hi"}] + out = _render("gemma4_thinking_template", msgs, add_generation_prompt = True) + assert out.endswith("<|turn>model\n<|channel>thought\n") + + +def test_thinking_template_no_injection_when_thinking_enabled(): + msgs = [{"role": "user", "content": "Hi"}] + out = _render( + "gemma4_thinking_template", + msgs, + add_generation_prompt = True, + enable_thinking = True, + ) + assert "<|channel>thought" not in out + + +def test_base_template_has_no_channel_thought_injection(): + msgs = [{"role": "user", "content": "Hi"}] + out = _render("gemma4_template", msgs, add_generation_prompt = True) + assert out.endswith("<|turn>model\n") + assert "<|channel>thought" not in out diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 326fd5928..8376dc7e3 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -866,12 +866,29 @@ DEFAULT_SYSTEM_MESSAGE["gemma3n"] = None # No system message in Gemma-3n # =========================================== Gemma-4 # Gemma-4 uses <|turn>role\n...\n format gemma4_template = \ -"""{%- if messages[0]['role'] == 'system' -%} - {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%} - {%- set loop_messages = messages[1:] -%} -{%- else -%} - {%- set first_user_prefix = "" -%} - {%- set loop_messages = messages -%} +"""{%- macro strip_thinking(text) -%} + {%- set ns = namespace(result='') -%} + {%- for part in text.split('') -%} + {%- if '<|channel>' in part -%} + {%- set ns.result = ns.result + part.split('<|channel>')[0] -%} + {%- else -%} + {%- set ns.result = ns.result + part -%} + {%- endif -%} + {%- endfor -%} + {{- ns.result | trim -}} +{%- endmacro -%} +{%- set thinking = enable_thinking is defined and enable_thinking -%} +{%- set loop_messages = messages -%} +{%- if messages[0]['role'] in ['system', 'developer'] or thinking -%} + {{ '<|turn>system\n' }} + {%- if thinking -%} + {{ '<|think|>\n' }} + {%- endif -%} + {%- if messages[0]['role'] in ['system', 'developer'] -%} + {{ messages[0]['content'] | trim }} + {%- set loop_messages = messages[1:] -%} + {%- endif -%} + {{ '\n' }} {%- endif -%} {%- for message in loop_messages -%} {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} @@ -882,9 +899,13 @@ gemma4_template = \ {%- else -%} {%- set role = message['role'] -%} {%- endif -%} - {{ '<|turn>' + role + '\n' + (first_user_prefix if loop.first else "") }} + {{ '<|turn>' + role + '\n' }} {%- if message['content'] is string -%} - {{ message['content'] | trim }} + {%- if role == "model" -%} + {{ strip_thinking(message['content']) }} + {%- else -%} + {{ message['content'] | trim }} + {%- endif -%} {%- elif message['content'] is iterable -%} {%- for item in message['content'] -%} {%- if item['type'] == 'audio' -%} @@ -894,7 +915,11 @@ gemma4_template = \ {%- elif item['type'] == 'video' -%} {{ '<|video|>' }} {%- elif item['type'] == 'text' -%} - {{ item['text'] | trim }} + {%- if role == "model" -%} + {{ strip_thinking(item['text']) }} + {%- else -%} + {{ item['text'] | trim }} + {%- endif -%} {%- endif -%} {%- endfor -%} {%- else -%} @@ -918,15 +943,31 @@ DEFAULT_SYSTEM_MESSAGE["gemma-4"] = None CHAT_TEMPLATES["gemma4"] = (gemma4_template, gemma4_template_eos_token, False, gemma4_ollama,) DEFAULT_SYSTEM_MESSAGE["gemma4"] = None -# Gemma-4 with empty thought channel (required for larger models like 31B, 26B-A4B) -# Injects <|channel>thought\n at the start of each model response during training +# Gemma-4 thinking template gemma4_thinking_template = \ -"""{%- if messages[0]['role'] == 'system' -%} - {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%} - {%- set loop_messages = messages[1:] -%} -{%- else -%} - {%- set first_user_prefix = "" -%} - {%- set loop_messages = messages -%} +"""{%- macro strip_thinking(text) -%} + {%- set ns = namespace(result='') -%} + {%- for part in text.split('') -%} + {%- if '<|channel>' in part -%} + {%- set ns.result = ns.result + part.split('<|channel>')[0] -%} + {%- else -%} + {%- set ns.result = ns.result + part -%} + {%- endif -%} + {%- endfor -%} + {{- ns.result | trim -}} +{%- endmacro -%} +{%- set thinking = enable_thinking is defined and enable_thinking -%} +{%- set loop_messages = messages -%} +{%- if messages[0]['role'] in ['system', 'developer'] or thinking -%} + {{ '<|turn>system\n' }} + {%- if thinking -%} + {{ '<|think|>\n' }} + {%- endif -%} + {%- if messages[0]['role'] in ['system', 'developer'] -%} + {{ messages[0]['content'] | trim }} + {%- set loop_messages = messages[1:] -%} + {%- endif -%} + {{ '\n' }} {%- endif -%} {%- for message in loop_messages -%} {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} @@ -937,12 +978,13 @@ gemma4_thinking_template = \ {%- else -%} {%- set role = message['role'] -%} {%- endif -%} - {{ '<|turn>' + role + '\n' + (first_user_prefix if loop.first else "") }} - {%- if role == "model" -%} - {{ '<|channel>thought\n' }} - {%- endif -%} + {{ '<|turn>' + role + '\n' }} {%- if message['content'] is string -%} - {{ message['content'] | trim }} + {%- if role == "model" -%} + {{ strip_thinking(message['content']) }} + {%- else -%} + {{ message['content'] | trim }} + {%- endif -%} {%- elif message['content'] is iterable -%} {%- for item in message['content'] -%} {%- if item['type'] == 'audio' -%} @@ -952,7 +994,11 @@ gemma4_thinking_template = \ {%- elif item['type'] == 'video' -%} {{ '<|video|>' }} {%- elif item['type'] == 'text' -%} - {{ item['text'] | trim }} + {%- if role == "model" -%} + {{ strip_thinking(item['text']) }} + {%- else -%} + {{ item['text'] | trim }} + {%- endif -%} {%- endif -%} {%- endfor -%} {%- else -%} @@ -962,6 +1008,9 @@ gemma4_thinking_template = \ {%- endfor -%} {%- if add_generation_prompt -%} {{'<|turn>model\n'}} + {%- if not thinking -%} + {{ '<|channel>thought\n' }} + {%- endif -%} {%- endif -%} """