mirror of
https://github.com/unslothai/unsloth.git
synced 2026-04-26 10:31:03 +00:00
update gema4 chat templates (#5116)
* update gema4 chat templates * udpate template * update template for gemma4 * Add gemma4 chat template tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
77756faa46
commit
f9682e656c
2 changed files with 255 additions and 23 deletions
183
tests/test_gemma4_chat_template.py
Normal file
183
tests/test_gemma4_chat_template.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
from jinja2.exceptions import TemplateError
|
||||
|
||||
|
||||
CHAT_TEMPLATES_PATH = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"unsloth",
|
||||
"chat_templates.py",
|
||||
)
|
||||
|
||||
|
||||
def _extract_template(name):
|
||||
src = open(CHAT_TEMPLATES_PATH).read()
|
||||
pattern = rf'{re.escape(name)}\s*=\s*\\\n"""(.*?)"""'
|
||||
m = re.search(pattern, src, flags = re.DOTALL)
|
||||
assert m, f"Could not extract {name} from chat_templates.py"
|
||||
return m.group(1)
|
||||
|
||||
|
||||
def _env():
|
||||
env = Environment(undefined = StrictUndefined, trim_blocks = False, lstrip_blocks = False)
|
||||
env.globals["raise_exception"] = lambda msg: (_ for _ in ()).throw(
|
||||
TemplateError(msg)
|
||||
)
|
||||
return env
|
||||
|
||||
|
||||
def _render(template_name, messages, **kwargs):
|
||||
src = _extract_template(template_name)
|
||||
tmpl = _env().from_string(src)
|
||||
ctx = {"messages": messages, "add_generation_prompt": False}
|
||||
ctx.update(kwargs)
|
||||
return tmpl.render(**ctx)
|
||||
|
||||
|
||||
# ---------- system turn and <|think|> placement ----------
|
||||
|
||||
|
||||
def test_system_message_emits_dedicated_system_turn():
|
||||
msgs = [
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "user", "content": "Hi"},
|
||||
]
|
||||
out = _render("gemma4_template", msgs)
|
||||
assert "<|turn>system\nYou are helpful<turn|>" in out
|
||||
assert "<|turn>user\nHi<turn|>" in out
|
||||
assert "You are helpful\n\nHi" not in out
|
||||
|
||||
|
||||
def test_developer_role_treated_as_system():
|
||||
msgs = [
|
||||
{"role": "developer", "content": "Internal instructions"},
|
||||
{"role": "user", "content": "Hi"},
|
||||
]
|
||||
out = _render("gemma4_template", msgs)
|
||||
assert "<|turn>system\nInternal instructions<turn|>" in out
|
||||
|
||||
|
||||
def test_no_system_no_thinking_unchanged():
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
out = _render("gemma4_template", msgs)
|
||||
assert "<|turn>user\nHi<turn|>" in out
|
||||
assert "<|turn>system" not in out
|
||||
|
||||
|
||||
def test_assistant_role_renders_as_model_turn():
|
||||
msgs = [{"role": "user", "content": "Q"}, {"role": "assistant", "content": "A"}]
|
||||
out = _render("gemma4_template", msgs)
|
||||
assert "<|turn>model\nA<turn|>" in out
|
||||
assert "<|turn>assistant" not in out
|
||||
|
||||
|
||||
def test_thinking_template_defaults_to_thinking_off_when_unset():
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
out = _render("gemma4_thinking_template", msgs)
|
||||
assert "<|think|>" not in out
|
||||
assert "<|turn>system" not in out
|
||||
|
||||
|
||||
def test_thinking_template_emits_think_with_newline_when_enabled():
|
||||
msgs = [{"role": "system", "content": "Sys"}, {"role": "user", "content": "Hi"}]
|
||||
out = _render("gemma4_thinking_template", msgs, enable_thinking = True)
|
||||
assert "<|turn>system\n<|think|>\nSys<turn|>" in out
|
||||
|
||||
|
||||
def test_alternation_violation_raises_template_error():
|
||||
msgs = [{"role": "user", "content": "A"}, {"role": "user", "content": "B"}]
|
||||
with pytest.raises(TemplateError):
|
||||
_render("gemma4_template", msgs)
|
||||
|
||||
|
||||
# ---------- strip_thinking macro semantics ----------
|
||||
|
||||
|
||||
def test_strip_thinking_strips_matched_pair():
|
||||
msgs = [
|
||||
{"role": "user", "content": "Q"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "<|channel>thought\n2+2=4<channel|>The answer is 4.",
|
||||
},
|
||||
]
|
||||
out = _render("gemma4_template", msgs)
|
||||
assert "thought" not in out
|
||||
assert "2+2=4" not in out
|
||||
assert "The answer is 4." in out
|
||||
|
||||
|
||||
def test_strip_thinking_applied_unconditionally_to_model_turn():
|
||||
msgs = [
|
||||
{"role": "user", "content": "Q"},
|
||||
{"role": "assistant", "content": "<|channel>reasoning<channel|>final"},
|
||||
]
|
||||
for agp in (True, False):
|
||||
out = _render("gemma4_template", msgs, add_generation_prompt = agp)
|
||||
assert "reasoning" not in out
|
||||
assert "final" in out
|
||||
|
||||
|
||||
def test_strip_thinking_applies_to_iterable_text():
|
||||
msgs = [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Q"}]},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "<|channel>r<channel|>final"}],
|
||||
},
|
||||
]
|
||||
out = _render("gemma4_thinking_template", msgs)
|
||||
assert "final" in out
|
||||
assert "<|channel>" not in out
|
||||
|
||||
|
||||
def test_strip_thinking_preserves_plain_text():
|
||||
msgs = [
|
||||
{"role": "user", "content": "Q"},
|
||||
{"role": "assistant", "content": "plain answer with no markup"},
|
||||
]
|
||||
out = _render("gemma4_template", msgs, add_generation_prompt = True)
|
||||
assert "plain answer with no markup" in out
|
||||
|
||||
|
||||
def test_multi_turn_strips_all_historical_model_turns():
|
||||
msgs = [
|
||||
{"role": "user", "content": "Q1"},
|
||||
{"role": "assistant", "content": "<|channel>r1<channel|>A1"},
|
||||
{"role": "user", "content": "Q2"},
|
||||
{"role": "assistant", "content": "<|channel>r2<channel|>A2"},
|
||||
]
|
||||
out = _render("gemma4_thinking_template", msgs, add_generation_prompt = True)
|
||||
assert "r1" not in out and "r2" not in out
|
||||
assert "A1" in out and "A2" in out
|
||||
|
||||
|
||||
# ---------- thinking-template gen-prompt injection ----------
|
||||
|
||||
|
||||
def test_thinking_template_injects_empty_thought_channel_by_default():
|
||||
# Author defaults enable_thinking=False, so the gen-prompt injection fires.
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
out = _render("gemma4_thinking_template", msgs, add_generation_prompt = True)
|
||||
assert out.endswith("<|turn>model\n<|channel>thought\n<channel|>")
|
||||
|
||||
|
||||
def test_thinking_template_no_injection_when_thinking_enabled():
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
out = _render(
|
||||
"gemma4_thinking_template",
|
||||
msgs,
|
||||
add_generation_prompt = True,
|
||||
enable_thinking = True,
|
||||
)
|
||||
assert "<|channel>thought" not in out
|
||||
|
||||
|
||||
def test_base_template_has_no_channel_thought_injection():
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
out = _render("gemma4_template", msgs, add_generation_prompt = True)
|
||||
assert out.endswith("<|turn>model\n")
|
||||
assert "<|channel>thought" not in out
|
||||
|
|
@ -866,12 +866,29 @@ DEFAULT_SYSTEM_MESSAGE["gemma3n"] = None # No system message in Gemma-3n
|
|||
# =========================================== Gemma-4
|
||||
# Gemma-4 uses <|turn>role\n...<turn|>\n format
|
||||
gemma4_template = \
|
||||
"""{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set first_user_prefix = "" -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
"""{%- macro strip_thinking(text) -%}
|
||||
{%- set ns = namespace(result='') -%}
|
||||
{%- for part in text.split('<channel|>') -%}
|
||||
{%- if '<|channel>' in part -%}
|
||||
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
|
||||
{%- else -%}
|
||||
{%- set ns.result = ns.result + part -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ns.result | trim -}}
|
||||
{%- endmacro -%}
|
||||
{%- set thinking = enable_thinking is defined and enable_thinking -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{%- if messages[0]['role'] in ['system', 'developer'] or thinking -%}
|
||||
{{ '<|turn>system\n' }}
|
||||
{%- if thinking -%}
|
||||
{{ '<|think|>\n' }}
|
||||
{%- endif -%}
|
||||
{%- if messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{ messages[0]['content'] | trim }}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{{ '<turn|>\n' }}
|
||||
{%- endif -%}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
|
||||
|
|
@ -882,9 +899,13 @@ gemma4_template = \
|
|||
{%- else -%}
|
||||
{%- set role = message['role'] -%}
|
||||
{%- endif -%}
|
||||
{{ '<|turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
|
||||
{{ '<|turn>' + role + '\n' }}
|
||||
{%- if message['content'] is string -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- if role == "model" -%}
|
||||
{{ strip_thinking(message['content']) }}
|
||||
{%- else -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'audio' -%}
|
||||
|
|
@ -894,7 +915,11 @@ gemma4_template = \
|
|||
{%- elif item['type'] == 'video' -%}
|
||||
{{ '<|video|>' }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- if role == "model" -%}
|
||||
{{ strip_thinking(item['text']) }}
|
||||
{%- else -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
|
|
@ -918,15 +943,31 @@ DEFAULT_SYSTEM_MESSAGE["gemma-4"] = None
|
|||
CHAT_TEMPLATES["gemma4"] = (gemma4_template, gemma4_template_eos_token, False, gemma4_ollama,)
|
||||
DEFAULT_SYSTEM_MESSAGE["gemma4"] = None
|
||||
|
||||
# Gemma-4 with empty thought channel (required for larger models like 31B, 26B-A4B)
|
||||
# Injects <|channel>thought\n<channel|> at the start of each model response during training
|
||||
# Gemma-4 thinking template
|
||||
gemma4_thinking_template = \
|
||||
"""{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set first_user_prefix = "" -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
"""{%- macro strip_thinking(text) -%}
|
||||
{%- set ns = namespace(result='') -%}
|
||||
{%- for part in text.split('<channel|>') -%}
|
||||
{%- if '<|channel>' in part -%}
|
||||
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
|
||||
{%- else -%}
|
||||
{%- set ns.result = ns.result + part -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ns.result | trim -}}
|
||||
{%- endmacro -%}
|
||||
{%- set thinking = enable_thinking is defined and enable_thinking -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{%- if messages[0]['role'] in ['system', 'developer'] or thinking -%}
|
||||
{{ '<|turn>system\n' }}
|
||||
{%- if thinking -%}
|
||||
{{ '<|think|>\n' }}
|
||||
{%- endif -%}
|
||||
{%- if messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{ messages[0]['content'] | trim }}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{{ '<turn|>\n' }}
|
||||
{%- endif -%}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
|
||||
|
|
@ -937,12 +978,13 @@ gemma4_thinking_template = \
|
|||
{%- else -%}
|
||||
{%- set role = message['role'] -%}
|
||||
{%- endif -%}
|
||||
{{ '<|turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
|
||||
{%- if role == "model" -%}
|
||||
{{ '<|channel>thought\n<channel|>' }}
|
||||
{%- endif -%}
|
||||
{{ '<|turn>' + role + '\n' }}
|
||||
{%- if message['content'] is string -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- if role == "model" -%}
|
||||
{{ strip_thinking(message['content']) }}
|
||||
{%- else -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'audio' -%}
|
||||
|
|
@ -952,7 +994,11 @@ gemma4_thinking_template = \
|
|||
{%- elif item['type'] == 'video' -%}
|
||||
{{ '<|video|>' }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- if role == "model" -%}
|
||||
{{ strip_thinking(item['text']) }}
|
||||
{%- else -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
|
|
@ -962,6 +1008,9 @@ gemma4_thinking_template = \
|
|||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{'<|turn>model\n'}}
|
||||
{%- if not thinking -%}
|
||||
{{ '<|channel>thought\n<channel|>' }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
"""
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue