[docs]: add contributing guide and add hooks install (#1613)

* [feat]: update kt-kernel hooks and add contribution guide

* [docs]: add contributing guide
* [style]: format the python file and cpp file in kt-kernel
This commit is contained in:
ZiWei Yuan 2025-11-15 18:26:49 +08:00 committed by GitHub
parent c32fefb1cd
commit aef6672dd8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 289 additions and 164 deletions

View file

@ -34,63 +34,42 @@ from datasets import load_dataset
def parse_args():
parser = argparse.ArgumentParser(description="Quantize MoE models with selective quantization")
# Required arguments
parser.add_argument(
"--model_id",
type=str,
required=True,
help="Path to the input model directory"
)
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Path to save the quantized model"
)
parser.add_argument("--model_id", type=str, required=True, help="Path to the input model directory")
parser.add_argument("--output_dir", type=str, required=True, help="Path to save the quantized model")
# Optional arguments
parser.add_argument(
"--quant_type",
type=str,
choices=["W4A16", "W8A16"],
default="W8A16",
help="Quantization type: W4A16 (GPTQ4) or W8A16 (GPTQ8). Default: W8A16"
help="Quantization type: W4A16 (GPTQ4) or W8A16 (GPTQ8). Default: W8A16",
)
parser.add_argument(
"--num_calibration_samples",
type=int,
default=512,
help="Number of calibration samples. Default: 512"
"--num_calibration_samples", type=int, default=512, help="Number of calibration samples. Default: 512"
)
parser.add_argument(
"--max_sequence_length",
type=int,
default=2048,
help="Maximum sequence length for calibration. Default: 2048"
"--max_sequence_length", type=int, default=2048, help="Maximum sequence length for calibration. Default: 2048"
)
parser.add_argument(
"--dampening_frac",
type=float,
default=0.1,
help="Dampening fraction to mitigate quantization noise. Default: 0.1"
help="Dampening fraction to mitigate quantization noise. Default: 0.1",
)
parser.add_argument(
"--dataset",
type=str,
default="HuggingFaceH4/ultrachat_200k",
help="Dataset for calibration. Default: HuggingFaceH4/ultrachat_200k"
help="Dataset for calibration. Default: HuggingFaceH4/ultrachat_200k",
)
parser.add_argument(
"--dataset_split",
type=str,
default="train_sft",
help="Dataset split to use. Default: train_sft"
"--dataset_split", type=str, default="train_sft", help="Dataset split to use. Default: train_sft"
)
parser.add_argument(
"--force_cpu",
action="store_true",
help="Force all computations to CPU (sets CUDA_VISIBLE_DEVICES='')"
"--force_cpu", action="store_true", help="Force all computations to CPU (sets CUDA_VISIBLE_DEVICES='')"
)
parser.add_argument(
"--ignore_patterns",
@ -103,29 +82,22 @@ def parse_args():
r"re:.*\.shared_expert\..*$",
r"re:.*\.shared_experts\..*$",
r"re:.*\.mlp\.shared_expert_gate$",
r"re:.*\.linear_attn\..*$"
r"re:.*\.linear_attn\..*$",
],
help="Regex patterns for layers to ignore during quantization"
help="Regex patterns for layers to ignore during quantization",
)
parser.add_argument(
"--torch_dtype",
type=str,
choices=["bfloat16", "float16", "float32"],
default="bfloat16",
help="PyTorch dtype for model loading. Default: bfloat16"
help="PyTorch dtype for model loading. Default: bfloat16",
)
parser.add_argument(
"--trust_remote_code",
action="store_true",
help="Allow loading of remote code (required for some models)"
"--trust_remote_code", action="store_true", help="Allow loading of remote code (required for some models)"
)
parser.add_argument(
"--random_seed",
type=int,
default=42,
help="Random seed for dataset shuffling. Default: 42"
)
parser.add_argument("--random_seed", type=int, default=42, help="Random seed for dataset shuffling. Default: 42")
return parser.parse_args()
@ -152,11 +124,7 @@ def get_torch_dtype(dtype_str):
Returns:
torch.dtype: Corresponding PyTorch dtype
"""
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32
}
dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
return dtype_map[dtype_str]
@ -176,18 +144,18 @@ def check_dense_layers_and_update_ignore(model_id, ignore_patterns, trust_remote
Updated ignore_patterns list with dense layer patterns added
"""
print("🔍 Checking model configuration for dense layers...")
try:
# Load model configuration
config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
# Check if the model has first_k_dense_replace parameter
first_k_dense_replace = getattr(config, 'first_k_dense_replace', None)
first_k_dense_replace = getattr(config, "first_k_dense_replace", None)
if first_k_dense_replace is not None and first_k_dense_replace > 0:
print(f"✅ Found dense layers configuration: first_k_dense_replace = {first_k_dense_replace}")
print(f" Adding first {first_k_dense_replace} layers to ignore list...")
# Create regex pattern for dense layers (layers 0 to first_k_dense_replace-1)
if first_k_dense_replace == 1:
dense_pattern = r"re:model\.layers\.0\.mlp\..*$"
@ -195,18 +163,18 @@ def check_dense_layers_and_update_ignore(model_id, ignore_patterns, trust_remote
# For multiple layers, use range pattern
layer_range = f"[0-{first_k_dense_replace-1}]"
dense_pattern = f"re:model\\.layers\\.{layer_range}\\.mlp\\..*$"
# Add the dense layer pattern to ignore list
updated_ignore_patterns = ignore_patterns + [dense_pattern]
print(f" Dense layer pattern added: {dense_pattern}")
print(f" This will ignore MLP components in layers 0-{first_k_dense_replace-1}")
return updated_ignore_patterns
else:
print(" No dense layers detected (first_k_dense_replace not found or is 0)")
return ignore_patterns
except Exception as e:
print(f"⚠️ Warning: Could not check model config for dense layers: {e}")
print(" Proceeding with original ignore patterns...")
@ -246,11 +214,7 @@ def load_and_prepare_dataset(dataset_name, dataset_split, num_samples, max_lengt
# Tokenize the data
def tokenize(sample):
return tokenizer(
sample["text"],
padding=False,
max_length=max_length,
truncation=True,
add_special_tokens=False
sample["text"], padding=False, max_length=max_length, truncation=True, add_special_tokens=False
)
ds = ds.map(tokenize, remove_columns=ds.column_names)
@ -291,9 +255,7 @@ def main():
# 0) Check for dense layers and update ignore patterns
# Dense layers in the first few layers should not be quantized
updated_ignore_patterns = check_dense_layers_and_update_ignore(
args.model_id,
args.ignore_patterns,
args.trust_remote_code
args.model_id, args.ignore_patterns, args.trust_remote_code
)
# --------------------------------------------------------------------
@ -302,13 +264,9 @@ def main():
print("🔍 Inferring device map...")
with init_empty_weights():
dummy = AutoModelForCausalLM.from_pretrained(
args.model_id,
torch_dtype=torch_dtype,
trust_remote_code=args.trust_remote_code
)
device_map = infer_auto_device_map(
dummy, no_split_module_classes=dummy._no_split_modules
args.model_id, torch_dtype=torch_dtype, trust_remote_code=args.trust_remote_code
)
device_map = infer_auto_device_map(dummy, no_split_module_classes=dummy._no_split_modules)
del dummy
# Force all modules to CPU for quantization
@ -335,7 +293,7 @@ def main():
args.num_calibration_samples,
args.max_sequence_length,
tokenizer,
args.random_seed
args.random_seed,
)
# --------------------------------------------------------------------
@ -373,4 +331,4 @@ def main():
if __name__ == "__main__":
main()
main()