unsloth/tests/utils
Avaya Aggarwal 7c5464ad71
feat: Add cactus QAT scheme support (#4679)
* feat: Add cactus QAT scheme support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test(qat): add tests for cactus QAT scheme and fix missing import

* Fix cactus QAT scheme: correct MappingType import, tighten PerGroup filter

- Drop the broken `from torchao.dtypes import MappingType` import. `MappingType`
  lives in `torchao.quantization` (and `torchao.quantization.quant_primitives`);
  it is not exported from `torchao.dtypes` in any supported torchao release
  (verified on 0.14, 0.16, 0.17). The previous code raised `ImportError` on
  every cactus call, and that error was masked as a misleading
  'torchao not found' error.
- Since `IntxWeightOnlyConfig` already defaults `mapping_type` to
  `MappingType.SYMMETRIC`, drop the explicit kwarg entirely and remove the
  import. Behavior is unchanged.
- Introduce a named `group_size = 32` constant (matches the int4 / fp8-int4
  pattern in the surrounding branches) and add a `% group_size == 0`
  divisibility guard to the filter. `PerGroup(32)` requires
  `in_features % 32 == 0` at `quantize_()` time, otherwise torchao raises
  `ValueError: in_features (N) % group_size (32) must be == 0`. The old
  `in_features >= 32` filter would admit non-aligned widths (e.g. 33, 48, 65,
  127) and crash `_prepare_model_for_qat` for those shapes; see the sketch
  after this list.
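
A minimal sketch of what the fixed cactus branch amounts to, assuming torchao's public `IntxWeightOnlyConfig` / `PerGroup` API. The weight dtype, variable names, and surrounding wiring are illustrative, not the literal Unsloth source:

```python
import torch
import torch.nn as nn
from torchao.quantization import IntxWeightOnlyConfig
from torchao.quantization.granularity import PerGroup
# Note: no `from torchao.dtypes import MappingType` -- IntxWeightOnlyConfig
# already defaults mapping_type to MappingType.SYMMETRIC, so the kwarg is omitted.

group_size = 32  # matches the int4 / fp8-int4 branches

# weight_dtype shown here is an assumption; the point is the PerGroup granularity.
cactus_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(group_size),
)

def cactus_filter(module: nn.Module, fqn: str) -> bool:
    # PerGroup(group_size) requires in_features % group_size == 0 at
    # quantize_() time, so reject non-aligned widths up front instead of
    # letting torchao raise ValueError later.
    return (
        isinstance(module, nn.Linear)
        and module.in_features >= group_size
        and module.in_features % group_size == 0
    )
```

In the QAT prepare path, a config and filter like these would be handed to torchao's `quantize_` (as the config and `filter_fn`), mirroring the existing int4 / fp8-int4 branches.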

* Warn when cactus QAT skips non-divisible Linear layers

Multiple reviewers flagged that the divisibility guard added in the
previous commit can silently leave Linear layers in full precision when
their in_features is not a multiple of 32. For currently supported
Unsloth models (Qwen, Llama, Gemma, Mistral, Phi) every Linear width is
already a multiple of 32/64/128, so the guard never triggers; still,
surfacing the coverage gap is cheap and keeps users from assuming 100%
QAT coverage when they bring a custom model with unusual shapes.

Emit a UserWarning listing up to the first 8 skipped layers whenever
the cactus filter excludes any Linear due to the modulo guard. This
keeps the lenient skip behavior (consistent with int4 / fp8-int4) but no
longer does so silently; a sketch of the warning logic follows below.
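
A sketch of the warning described above, under the same caveats (the helper name and message text are illustrative):

```python
import warnings
import torch.nn as nn

def _warn_on_skipped_cactus_linears(model: nn.Module, group_size: int = 32) -> None:
    # Collect Linear layers that would have passed the old `in_features >= 32`
    # filter but are now skipped by the `% group_size == 0` guard.
    skipped = [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and module.in_features >= group_size
        and module.in_features % group_size != 0
    ]
    if skipped:
        shown = ", ".join(skipped[:8])  # list at most the first 8 layer names
        warnings.warn(
            f"cactus QAT skipped {len(skipped)} Linear layer(s) whose in_features is "
            f"not divisible by {group_size}; they remain in full precision: {shown}",
            UserWarning,
        )
```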

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2026-04-15 07:40:03 -07:00
__init__.py Revert "[FIX] Vllm guided decoding params (#3662)" 2025-12-01 05:43:45 -08:00
aime_eval.md reroute merge logic language models + comprehensive tests + eval kits (#2673) 2025-06-02 20:32:57 -07:00
aime_eval.py Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" 2025-12-01 07:24:58 -08:00
cleanup_utils.py Revert "[FIX] Vllm guided decoding params (#3662)" 2025-12-01 05:43:45 -08:00
data_utils.py Revert "[FIX] Vllm guided decoding params (#3662)" 2025-12-01 05:43:45 -08:00
hf_utils.py Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" 2025-12-01 07:24:58 -08:00
ocr_eval.md Fix Typos in Documentation and Comments (#2721) 2025-06-17 04:34:51 -07:00
ocr_eval.py Revert "[FIX] Vllm guided decoding params (#3662)" 2025-12-01 05:43:45 -08:00
os_utils.py Revert "[FIX] Vllm guided decoding params (#3662)" 2025-12-01 05:43:45 -08:00
perplexity_eval.md reroute merge logic language models + comprehensive tests + eval kits (#2673) 2025-06-02 20:32:57 -07:00
perplexity_eval.py Revert "[FIX] Vllm guided decoding params (#3662)" 2025-12-01 05:43:45 -08:00
test_attention_masks.py SFT sample packing (#3566) 2025-12-09 17:36:45 -08:00
test_packing.py Refactor Ollama template wiring and harden packing helpers (#3890) 2026-02-09 04:04:48 -08:00
test_q_galore.py feat: Implement Q-GaLore optimizer and custom embedding learning rate… (#4511) 2026-03-25 01:03:10 -07:00
test_qat.py feat: Add cactus QAT scheme support (#4679) 2026-04-15 07:40:03 -07:00
test_trunc_normal_patch.py Patch trunc_normal_ for low-precision stability (#4027) 2026-02-19 04:40:14 -08:00