update gema4 chat templates (#5116)

* update gema4 chat templates

* udpate template

* update template for gemma4

* Add gemma4 chat template tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Datta Nimmaturi 2026-04-22 21:34:08 +05:30 committed by GitHub
parent 77756faa46
commit f9682e656c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 255 additions and 23 deletions

View file

@ -0,0 +1,183 @@
import os
import re
import pytest
from jinja2 import Environment, StrictUndefined
from jinja2.exceptions import TemplateError
CHAT_TEMPLATES_PATH = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"unsloth",
"chat_templates.py",
)
def _extract_template(name):
src = open(CHAT_TEMPLATES_PATH).read()
pattern = rf'{re.escape(name)}\s*=\s*\\\n"""(.*?)"""'
m = re.search(pattern, src, flags = re.DOTALL)
assert m, f"Could not extract {name} from chat_templates.py"
return m.group(1)
def _env():
env = Environment(undefined = StrictUndefined, trim_blocks = False, lstrip_blocks = False)
env.globals["raise_exception"] = lambda msg: (_ for _ in ()).throw(
TemplateError(msg)
)
return env
def _render(template_name, messages, **kwargs):
src = _extract_template(template_name)
tmpl = _env().from_string(src)
ctx = {"messages": messages, "add_generation_prompt": False}
ctx.update(kwargs)
return tmpl.render(**ctx)
# ---------- system turn and <|think|> placement ----------
def test_system_message_emits_dedicated_system_turn():
msgs = [
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hi"},
]
out = _render("gemma4_template", msgs)
assert "<|turn>system\nYou are helpful<turn|>" in out
assert "<|turn>user\nHi<turn|>" in out
assert "You are helpful\n\nHi" not in out
def test_developer_role_treated_as_system():
msgs = [
{"role": "developer", "content": "Internal instructions"},
{"role": "user", "content": "Hi"},
]
out = _render("gemma4_template", msgs)
assert "<|turn>system\nInternal instructions<turn|>" in out
def test_no_system_no_thinking_unchanged():
msgs = [{"role": "user", "content": "Hi"}]
out = _render("gemma4_template", msgs)
assert "<|turn>user\nHi<turn|>" in out
assert "<|turn>system" not in out
def test_assistant_role_renders_as_model_turn():
msgs = [{"role": "user", "content": "Q"}, {"role": "assistant", "content": "A"}]
out = _render("gemma4_template", msgs)
assert "<|turn>model\nA<turn|>" in out
assert "<|turn>assistant" not in out
def test_thinking_template_defaults_to_thinking_off_when_unset():
msgs = [{"role": "user", "content": "Hi"}]
out = _render("gemma4_thinking_template", msgs)
assert "<|think|>" not in out
assert "<|turn>system" not in out
def test_thinking_template_emits_think_with_newline_when_enabled():
msgs = [{"role": "system", "content": "Sys"}, {"role": "user", "content": "Hi"}]
out = _render("gemma4_thinking_template", msgs, enable_thinking = True)
assert "<|turn>system\n<|think|>\nSys<turn|>" in out
def test_alternation_violation_raises_template_error():
msgs = [{"role": "user", "content": "A"}, {"role": "user", "content": "B"}]
with pytest.raises(TemplateError):
_render("gemma4_template", msgs)
# ---------- strip_thinking macro semantics ----------
def test_strip_thinking_strips_matched_pair():
msgs = [
{"role": "user", "content": "Q"},
{
"role": "assistant",
"content": "<|channel>thought\n2+2=4<channel|>The answer is 4.",
},
]
out = _render("gemma4_template", msgs)
assert "thought" not in out
assert "2+2=4" not in out
assert "The answer is 4." in out
def test_strip_thinking_applied_unconditionally_to_model_turn():
msgs = [
{"role": "user", "content": "Q"},
{"role": "assistant", "content": "<|channel>reasoning<channel|>final"},
]
for agp in (True, False):
out = _render("gemma4_template", msgs, add_generation_prompt = agp)
assert "reasoning" not in out
assert "final" in out
def test_strip_thinking_applies_to_iterable_text():
msgs = [
{"role": "user", "content": [{"type": "text", "text": "Q"}]},
{
"role": "assistant",
"content": [{"type": "text", "text": "<|channel>r<channel|>final"}],
},
]
out = _render("gemma4_thinking_template", msgs)
assert "final" in out
assert "<|channel>" not in out
def test_strip_thinking_preserves_plain_text():
msgs = [
{"role": "user", "content": "Q"},
{"role": "assistant", "content": "plain answer with no markup"},
]
out = _render("gemma4_template", msgs, add_generation_prompt = True)
assert "plain answer with no markup" in out
def test_multi_turn_strips_all_historical_model_turns():
msgs = [
{"role": "user", "content": "Q1"},
{"role": "assistant", "content": "<|channel>r1<channel|>A1"},
{"role": "user", "content": "Q2"},
{"role": "assistant", "content": "<|channel>r2<channel|>A2"},
]
out = _render("gemma4_thinking_template", msgs, add_generation_prompt = True)
assert "r1" not in out and "r2" not in out
assert "A1" in out and "A2" in out
# ---------- thinking-template gen-prompt injection ----------
def test_thinking_template_injects_empty_thought_channel_by_default():
# Author defaults enable_thinking=False, so the gen-prompt injection fires.
msgs = [{"role": "user", "content": "Hi"}]
out = _render("gemma4_thinking_template", msgs, add_generation_prompt = True)
assert out.endswith("<|turn>model\n<|channel>thought\n<channel|>")
def test_thinking_template_no_injection_when_thinking_enabled():
msgs = [{"role": "user", "content": "Hi"}]
out = _render(
"gemma4_thinking_template",
msgs,
add_generation_prompt = True,
enable_thinking = True,
)
assert "<|channel>thought" not in out
def test_base_template_has_no_channel_thought_injection():
msgs = [{"role": "user", "content": "Hi"}]
out = _render("gemma4_template", msgs, add_generation_prompt = True)
assert out.endswith("<|turn>model\n")
assert "<|channel>thought" not in out

View file

@ -866,12 +866,29 @@ DEFAULT_SYSTEM_MESSAGE["gemma3n"] = None # No system message in Gemma-3n
# =========================================== Gemma-4
# Gemma-4 uses <|turn>role\n...<turn|>\n format
gemma4_template = \
"""{%- if messages[0]['role'] == 'system' -%}
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
{%- set loop_messages = messages[1:] -%}
{%- else -%}
{%- set first_user_prefix = "" -%}
{%- set loop_messages = messages -%}
"""{%- macro strip_thinking(text) -%}
{%- set ns = namespace(result='') -%}
{%- for part in text.split('<channel|>') -%}
{%- if '<|channel>' in part -%}
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
{%- else -%}
{%- set ns.result = ns.result + part -%}
{%- endif -%}
{%- endfor -%}
{{- ns.result | trim -}}
{%- endmacro -%}
{%- set thinking = enable_thinking is defined and enable_thinking -%}
{%- set loop_messages = messages -%}
{%- if messages[0]['role'] in ['system', 'developer'] or thinking -%}
{{ '<|turn>system\n' }}
{%- if thinking -%}
{{ '<|think|>\n' }}
{%- endif -%}
{%- if messages[0]['role'] in ['system', 'developer'] -%}
{{ messages[0]['content'] | trim }}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}
{{ '<turn|>\n' }}
{%- endif -%}
{%- for message in loop_messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
@ -882,9 +899,13 @@ gemma4_template = \
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<|turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
{{ '<|turn>' + role + '\n' }}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- if role == "model" -%}
{{ strip_thinking(message['content']) }}
{%- else -%}
{{ message['content'] | trim }}
{%- endif -%}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'audio' -%}
@ -894,7 +915,11 @@ gemma4_template = \
{%- elif item['type'] == 'video' -%}
{{ '<|video|>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- if role == "model" -%}
{{ strip_thinking(item['text']) }}
{%- else -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- else -%}
@ -918,15 +943,31 @@ DEFAULT_SYSTEM_MESSAGE["gemma-4"] = None
CHAT_TEMPLATES["gemma4"] = (gemma4_template, gemma4_template_eos_token, False, gemma4_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma4"] = None
# Gemma-4 with empty thought channel (required for larger models like 31B, 26B-A4B)
# Injects <|channel>thought\n<channel|> at the start of each model response during training
# Gemma-4 thinking template
gemma4_thinking_template = \
"""{%- if messages[0]['role'] == 'system' -%}
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
{%- set loop_messages = messages[1:] -%}
{%- else -%}
{%- set first_user_prefix = "" -%}
{%- set loop_messages = messages -%}
"""{%- macro strip_thinking(text) -%}
{%- set ns = namespace(result='') -%}
{%- for part in text.split('<channel|>') -%}
{%- if '<|channel>' in part -%}
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
{%- else -%}
{%- set ns.result = ns.result + part -%}
{%- endif -%}
{%- endfor -%}
{{- ns.result | trim -}}
{%- endmacro -%}
{%- set thinking = enable_thinking is defined and enable_thinking -%}
{%- set loop_messages = messages -%}
{%- if messages[0]['role'] in ['system', 'developer'] or thinking -%}
{{ '<|turn>system\n' }}
{%- if thinking -%}
{{ '<|think|>\n' }}
{%- endif -%}
{%- if messages[0]['role'] in ['system', 'developer'] -%}
{{ messages[0]['content'] | trim }}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}
{{ '<turn|>\n' }}
{%- endif -%}
{%- for message in loop_messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
@ -937,12 +978,13 @@ gemma4_thinking_template = \
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<|turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
{%- if role == "model" -%}
{{ '<|channel>thought\n<channel|>' }}
{%- endif -%}
{{ '<|turn>' + role + '\n' }}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- if role == "model" -%}
{{ strip_thinking(message['content']) }}
{%- else -%}
{{ message['content'] | trim }}
{%- endif -%}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'audio' -%}
@ -952,7 +994,11 @@ gemma4_thinking_template = \
{%- elif item['type'] == 'video' -%}
{{ '<|video|>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- if role == "model" -%}
{{ strip_thinking(item['text']) }}
{%- else -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- else -%}
@ -962,6 +1008,9 @@ gemma4_thinking_template = \
{%- endfor -%}
{%- if add_generation_prompt -%}
{{'<|turn>model\n'}}
{%- if not thinking -%}
{{ '<|channel>thought\n<channel|>' }}
{%- endif -%}
{%- endif -%}
"""