Add Gemma-4 to FORCE_FLOAT32 list to fix fp16 NaN in RL training

Gemma-4 (gemma4, gemma4text) produces NaN loss during GRPO/RL training
when loaded in float16, matching the behavior already seen with Gemma-3
and Gemma-3n. The NaN appears stochastically around steps 4-5 and
originates from the torch.compile backward path under fp16 autocast.
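
For context, the root cause is float16's narrow dynamic range rather than
anything Gemma-specific. A minimal, model-independent illustration:

```python
import torch

# float16 saturates past ~65504, so a large intermediate activation or
# gradient becomes inf, and inf - inf then becomes NaN downstream.
x = torch.tensor([70000.0], dtype=torch.float16)
print(x)        # tensor([inf], dtype=torch.float16)
print(x - x)    # tensor([nan], dtype=torch.float16)

# bfloat16 keeps float32's exponent range, so the same magnitude stays
# finite (at reduced mantissa precision), which is why the fix prefers it.
y = torch.tensor([70000.0], dtype=torch.bfloat16)
print(torch.isfinite(y))    # tensor([True])
```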

Adding gemma4 and gemma4text to the FORCE_FLOAT32 list causes Unsloth
to automatically switch to bfloat16 loading with float32 mixed precision
when a user requests float16, preventing the overflow. This is the same
fix already applied for gemma3, gemma3text, and gemma3n.
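
The matching itself is a substring check over the model type. A simplified
sketch of how the override behaves (resolve_dtype is a hypothetical helper
for illustration, not Unsloth's actual function, and the real logic lives
in Unsloth's model-loading path):

```python
import torch

FORCE_FLOAT32 = [
    "gemma3,", "gemma3text", "gemma3n",
    "gemma4,", "gemma4text",
    "gpt_oss", "qwen3_5",
]

def resolve_dtype(model_type: str, requested_dtype: torch.dtype) -> torch.dtype:
    # Append a delimiter so "gemma4," matches plain "gemma4" but not a
    # hypothetical "gemma4n" variant (same trick the list uses for gemma3).
    name = model_type.lower() + ","
    forced = any(pattern in name for pattern in FORCE_FLOAT32)
    if forced and requested_dtype == torch.float16:
        # Weights are loaded in bfloat16 and mixed precision runs in
        # float32 instead of honoring the float16 request.
        return torch.bfloat16
    return requested_dtype

print(resolve_dtype("gemma4", torch.float16))  # torch.bfloat16
print(resolve_dtype("llama", torch.float16))   # torch.float16
```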

Tested with the Gemma-4 E2B Sudoku RL notebook (GRPO,
gradient_accumulation_steps=4, num_generations=2, max_steps=10) on both
fp16 and bf16; all 10/10 steps were NaN-free after the fix.
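
For reference, the verification run corresponds roughly to a trl GRPO
setup like the one below; the model, dataset, and Sudoku reward functions
from the notebook are omitted, and the output_dir name is a placeholder:

```python
from trl import GRPOConfig

# Hyperparameters matching the verification run described above.
config = GRPOConfig(
    output_dir = "gemma4-grpo-nan-check",
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    num_generations = 2,
    max_steps = 10,
    logging_steps = 1,
    bf16 = True,   # flip to fp16=True for the float16 run
)
# GRPOTrainer(model=..., args=config, reward_funcs=[...], train_dataset=...)
```
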
Daniel Han 2026-04-16 13:22:47 +00:00
parent c5be8b1cd2
commit 06fbe97b54

@@ -105,6 +105,8 @@ FORCE_FLOAT32 = [
     "gemma3,", # Add comma bc gemma3 will match gemma3n
     "gemma3text", # Gemma3TextModel (EmbeddingGemma, standalone text-only Gemma3)
     "gemma3n",
+    "gemma4,", # Gemma-4 fp16 produces NaN in RL training (same issue as Gemma-3)
+    "gemma4text", # Gemma4TextModel (standalone text-only Gemma4)
     "gpt_oss",
     "qwen3_5", # Qwen3.5 GDN layers produce NaN grad norms in float16 training
 ]
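
From the user side the API is unchanged: requesting float16 for a Gemma-4
checkpoint now silently gets the safer dtype. A rough usage sketch (the
model name below is a placeholder, not a published checkpoint):

```python
import torch
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-4-e2b-it",  # placeholder name
    max_seq_length = 2048,
    dtype = torch.float16,   # fp16 request is upgraded for FORCE_FLOAT32
                             # models to avoid NaN losses in RL training
    load_in_4bit = False,
)
```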