Mirror of https://github.com/unslothai/unsloth.git (synced 2026-04-28 03:19:57 +00:00)
Add Gemma-4 to FORCE_FLOAT32 list to fix fp16 NaN in RL training
Gemma-4 (gemma4, gemma4text) produces NaN loss during GRPO/RL training when loaded in float16, matching the behavior already seen with Gemma-3 and Gemma-3n. The NaN appears stochastically around steps 4-5 and originates from the torch.compile backward path under fp16 autocast. Adding gemma4 and gemma4text to the FORCE_FLOAT32 list makes Unsloth automatically switch to bfloat16 loading with float32 mixed precision when a user requests float16, preventing the overflow. This is the same fix already applied for gemma3, gemma3text, and gemma3n. Tested with the Gemma-4 E2B Sudoku RL notebook (GRPO, ga=4, num_generations=2, max_steps=10) on both fp16 and bf16; all 10 steps completed NaN-free after the fix.
Parent: c5be8b1cd2
Commit: 06fbe97b54
1 changed file with 2 additions and 0 deletions
@@ -105,6 +105,8 @@ FORCE_FLOAT32 = [
     "gemma3,", # Add comma bc gemma3 will match gemma3n
     "gemma3text", # Gemma3TextModel (EmbeddingGemma, standalone text-only Gemma3)
     "gemma3n",
+    "gemma4,", # Gemma-4 fp16 produces NaN in RL training (same issue as Gemma-3)
+    "gemma4text", # Gemma4TextModel (standalone text-only Gemma4)
     "gpt_oss",
     "qwen3_5", # Qwen3.5 GDN layers produce NaN grad norms in float16 training
 ]
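To illustrate the mechanism the commit message describes, below is a minimal Python sketch, not Unsloth's actual implementation: the helper name resolve_dtype and the exact substring-matching rule are assumptions, but it shows how a FORCE_FLOAT32-style check could turn a float16 request into bfloat16 loading with float32 mixed precision, and why the "gemma3," and "gemma4," entries carry a trailing comma.

# Minimal sketch (assumed names and matching rule, not Unsloth's real code).
import torch

FORCE_FLOAT32 = [
    "gemma3,",      # trailing comma so this entry does not also match "gemma3n"
    "gemma3text",
    "gemma3n",
    "gemma4,",      # trailing comma, mirroring the gemma3 entry
    "gemma4text",
    "gpt_oss",
    "qwen3_5",
]

def resolve_dtype(model_type: str, requested_dtype: torch.dtype):
    """Return (load_dtype, float32_mixed_precision) for a model type."""
    # Append a comma so "gemma3," matches "gemma3" but not "gemma3n".
    key = model_type.lower() + ","
    forced = any(entry in key for entry in FORCE_FLOAT32)
    if forced and requested_dtype == torch.float16:
        # fp16 overflows in the compiled backward pass for these models,
        # so load weights in bf16 and keep mixed-precision accumulation in fp32.
        return torch.bfloat16, True
    return requested_dtype, False

print(resolve_dtype("gemma4", torch.float16))   # -> (torch.bfloat16, True)
print(resolve_dtype("gemma3n", torch.float16))  # -> (torch.bfloat16, True)
print(resolve_dtype("llama", torch.float16))    # -> (torch.float16, False)

Under these assumptions, a float16 request for Gemma-4 is transparently upgraded, which is what prevents the step-4/5 NaN described above.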