# SPDX-License-Identifier: GNU Affero General Public License v3.0
# Copyright 2023-present the Unsloth team. All rights reserved.

import argparse
from contextlib import contextmanager

import pytest
import torch
from transformers import AutoConfig
from transformers.models.qwen3_moe import Qwen3MoeConfig
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

from grouped_gemm.kernels.tuning import (
    KernelConfigBackward_dW,
    KernelConfigBackward_dX,
    KernelConfigForward,
)
from grouped_gemm.reference.layers.qwen3_moe import Qwen3MoeGroupedGEMMBlock

from .moe_utils import (
    Qwen3MoeFusedGroupedGEMMBlock,
    check_fwd,
    check_grads,
    check_grouped_gemm_results,
    run_backward,
    run_forward,
)

"""
|
|
Qwen3 MoE tests
|
|
|
|
NOTE: Test this as a module and NOT with pytest as running with pytest results in random numerical errors: python -m tests.test_qwen3_moe --permute_x --permute_y --autotune NOT pytest -sv tests/test_qwen3_moe.py
|
|
More specifically, all tests pass when run individually, but some will fail randomly (even with the same seed) when the entire test is run as a parametrized test suite using pytest, likely due to how pytest interacts with triton / autotuning.
|
|
|
|
See tests/run_qwen3_moe_tests.sh for a script that runs all the tests
|
|
|
|
The tests run the following:
|
|
Huggingface's Qwen3 MoE block (Qwen3MoeSparseMoeBlock)
|
|
Torch-native grouped gemm version of MoE block (Qwen3MoeGroupedGEMMBlock), which is the HF block with the expert computation replaced with a torch-native grouped gemm
|
|
Triton kernel grouped gemm version of MoE block (Qwen3MoeFusedGroupedGEMMBlock), which is the HF block with the expert computation replaced with the fused triton grouped gemm kernel
|
|
|
|
The tests check the following:
|
|
- HF MoE block vs torch grouped gemm MoE block (sanity check)
|
|
- torch grouped gemm MoE block vs fused grouped gemm MoE block -- this allows us to test each of the intermediate results for easier debugging
|
|
- HF MoE block vs fused grouped gemm MoE block -- this is the actual test
|
|
|
|
Both forward and backward passes are tests:
|
|
- forward: output of the moe block
|
|
- backwards:
|
|
- X: gradient of the input to the moe block
|
|
- gate.weight: gradient of the gate weights (router weights)
|
|
- gate_proj: gradient of concatenated gate projections
|
|
- up_proj: gradient of the concatenated up projections
|
|
- down_proj: gradient of the concatenated down projections
|
|
|
|
Additionally, for the torch grouped gemm and triton grouped gemm versions, the intermediate outputs of the forward pass are checked:
|
|
- first_gemm: output of the first grouped gemm (X @ fused_gate_proj)
|
|
- intermediate: output of silu_mul(first_gemm)
|
|
- second_gemm: output of the second grouped gemm (intermediate @ down_proj)
|
|
- hidden_states_unpermute: output of the second_gemm after unpermuting back to token order (from expert grouped order); in the case where the permutation is fused in the triton kernel, this is the same as second_gemm
|
|
- hidden_states: output with the topk_weights applied
|
|
"""
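
# (atol, rtol) pairs per dtype, passed to check_fwd / check_grads / check_grouped_gemm_results below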
TOLERANCES = {
    torch.bfloat16: (1e-2, 1e-2),
    torch.float16: (1e-3, 1e-3),
    torch.float: (1e-5, 1e-5),
}


@pytest.fixture(scope = "module")
def model_id():
    return "Qwen/Qwen3-30B-A3B"


@pytest.fixture(scope = "module")
def config(model_id: str):
    return AutoConfig.from_pretrained(model_id)


@contextmanager
def annotated_context(prelude, epilogue = "Passed!", char = "-", num_chars = 80):
    print(char * num_chars)
    print(prelude)
    yield
    print(epilogue)
    print(char * num_chars)


SEED = 42
SEQ_LENS = [1024]
DTYPES = [torch.bfloat16]

# Reduce the number of autotuning configs to prevent excessive runtime
NUM_AUTOTUNE_CONFIGS = 50


@pytest.mark.parametrize(
    "permute_y", [True], ids = lambda x: "permute_y" if x else "no_permute_y"
)
@pytest.mark.parametrize(
    "permute_x", [True], ids = lambda x: "permute_x" if x else "no_permute_x"
)
@pytest.mark.parametrize(
    "autotune", [True], ids = lambda x: "autotune" if x else "manual"
)
@pytest.mark.parametrize("seqlen", SEQ_LENS, ids = lambda x: f"seqlen={x}")
@pytest.mark.parametrize("dtype", DTYPES, ids = str)
def test_qwen3_moe(
    config: Qwen3MoeConfig,
    seqlen: int,
    dtype: torch.dtype,
    permute_x: bool,
    permute_y: bool,
    autotune: bool,
):
    torch.manual_seed(
        SEED
    )  # Should not be needed when running using pytest -- autouse fixture in conftest.py
    device = "cuda"
    hidden_size = config.hidden_size
    bs = 1
    atol, rtol = TOLERANCES[dtype]
    # Reference op -- HF
    moe_block = Qwen3MoeSparseMoeBlock(config).to(device, dtype)

    # Torch-native grouped gemm version of MoE Block -- for sanity checking
    grouped_gemm_block = Qwen3MoeGroupedGEMMBlock.from_hf(moe_block).to(device, dtype)
    grouped_gemm_block.check_weights(moe_block)

    if not autotune:
        kernel_config_fwd = KernelConfigForward()
        kernel_config_bwd_dW = KernelConfigBackward_dW()
        kernel_config_bwd_dX = KernelConfigBackward_dX()
    else:
        from grouped_gemm.kernels.backward import (
            _autotuned_grouped_gemm_dW_kernel,
            _autotuned_grouped_gemm_dX_kernel,
        )
        from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel

        # Hack to reduce number of autotuning configs
        _autotuned_grouped_gemm_forward_kernel.configs = (
            _autotuned_grouped_gemm_forward_kernel.configs[:NUM_AUTOTUNE_CONFIGS]
        )
        _autotuned_grouped_gemm_dW_kernel.configs = (
            _autotuned_grouped_gemm_dW_kernel.configs[:NUM_AUTOTUNE_CONFIGS]
        )
        _autotuned_grouped_gemm_dX_kernel.configs = (
            _autotuned_grouped_gemm_dX_kernel.configs[:NUM_AUTOTUNE_CONFIGS]
        )

        kernel_config_fwd = None
        kernel_config_bwd_dW = None
        kernel_config_bwd_dX = None

    # Triton kernel grouped gemm version of MoE Block -- this is what we're testing
    fused_gemm_block = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
        moe_block,
        permute_x = permute_x,
        permute_y = permute_y,
        autotune = autotune,
        kernel_config_fwd = kernel_config_fwd,
        kernel_config_bwd_dW = kernel_config_bwd_dW,
        kernel_config_bwd_dX = kernel_config_bwd_dX,
    ).to(device, dtype)
    fused_gemm_block.check_weights(moe_block)

    X = torch.randn(
        bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
    )

    # Forward
    ref_result = run_forward(moe_block, X, is_grouped_gemm = False)
    grouped_result = run_forward(grouped_gemm_block, X, is_grouped_gemm = True)
    fused_result = run_forward(fused_gemm_block, X, is_grouped_gemm = True)

    with annotated_context(
        "Testing forward pass",
        epilogue = "Passed forward tests!",
        char = "=",
        num_chars = 100,
    ):
        # Sanity checks

        with annotated_context(
            "Checking HF vs torch grouped gemm MoE forward outputs..."
        ):
            check_fwd(ref_result, grouped_result, atol, rtol, verbose = False)

        with annotated_context(
            "Checking torch grouped gemm MoE vs fused grouped gemm MoE forward outputs..."
        ):
            # We implement a custom check for grouped gemm results to test each of the intermediate results for easier debugging
            check_grouped_gemm_results(
                grouped_result.grouped_gemm_result,
                fused_result.grouped_gemm_result,
                permute_y = permute_y,
                atol = atol,
                rtol = rtol,
                verbose = False,
            )
        # Actual test
        with annotated_context(
            "Checking HF vs fused grouped gemm MoE forward outputs..."
        ):
            check_fwd(ref_result, fused_result, atol, rtol, verbose = True)

    # Backward
    grad_output = torch.randn_like(ref_result.output)
    ref_backward_result = run_backward(
        moe_block, grad_output, output = ref_result.output, X = ref_result.X
    )
    grouped_backward_result = run_backward(
        grouped_gemm_block,
        grad_output,
        output = grouped_result.output,
        X = grouped_result.X,
    )
    fused_backward_result = run_backward(
        fused_gemm_block, grad_output, output = fused_result.output, X = fused_result.X
    )

    with annotated_context(
        "Testing backward pass",
        epilogue = "Passed backward tests!",
        char = "=",
        num_chars = 100,
    ):
        # Sanity checks
        with annotated_context("Checking HF vs torch grouped gemm MoE grads..."):
            check_grads(
                ref_backward_result, grouped_backward_result, atol, rtol, verbose = False
            )
        with annotated_context(
            "Checking torch grouped gemm MoE vs fused grouped gemm MoE grads..."
        ):
            check_grads(
                grouped_backward_result,
                fused_backward_result,
                atol,
                rtol,
                verbose = False,
            )

        # Actual test
        with annotated_context("Checking HF vs fused grouped gemm MoE grads..."):
            check_grads(
                ref_backward_result, fused_backward_result, atol, rtol, verbose = True
            )
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--seqlen", type = int, default = 1024)
|
|
parser.add_argument(
|
|
"--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
|
|
)
|
|
parser.add_argument("--permute_x", action = "store_true")
|
|
parser.add_argument("--permute_y", action = "store_true")
|
|
parser.add_argument("--autotune", action = "store_true")
|
|
args = parser.parse_args()
|
|
args.dtype = getattr(torch, args.dtype)
|
|
args_dict = vars(args)
|
|
|
|
model_id = "Qwen/Qwen3-30B-A3B"
|
|
config = AutoConfig.from_pretrained(model_id)
|
|
atol, rtol = TOLERANCES[args.dtype]
|
|
|
|
print(
|
|
f"Testing {model_id} with seqlen={args.seqlen}, dtype={args.dtype}, permute_x={args.permute_x}, permute_y={args.permute_y}, autotune={args.autotune}, atol={atol}, rtol={rtol}"
|
|
)
|
|
test_qwen3_moe(config, **args_dict)
|