mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 03:39:48 +00:00
[docs]: add contributing guide and add hooks install (#1613)
* [feat]: update kt-kernel hooks and add contribution guide * [docs]: add contributing guide * [style]: format the python file and cpp file in kt-kernel
This commit is contained in:
parent
c32fefb1cd
commit
aef6672dd8
11 changed files with 289 additions and 164 deletions
|
|
@ -22,7 +22,6 @@ import triton
|
|||
import triton.language as tl
|
||||
|
||||
|
||||
|
||||
Q_BITS = 4
|
||||
STORAGE_BITS = 32
|
||||
PACK_NUM = STORAGE_BITS // Q_BITS
|
||||
|
|
@ -31,6 +30,7 @@ NUMA_NUM = 2
|
|||
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
@triton.jit
|
||||
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
|
|
@ -51,10 +51,11 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
|
|||
assert x.dim() == 2 and s.dim() == 2
|
||||
M, N = x.size()
|
||||
y = torch.empty_like(x, dtype=torch.get_default_dtype())
|
||||
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
|
||||
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
|
||||
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
|
||||
return y
|
||||
|
||||
|
||||
def load_model_config(input_path: str, input_type: str = None) -> Dict:
|
||||
"""Load model configuration from config.json
|
||||
|
||||
|
|
@ -297,7 +298,6 @@ class ConverterBase:
|
|||
handle = self.file_handle_map[file]
|
||||
return handle.get_tensor(key)
|
||||
|
||||
|
||||
# layers_id -> list[experts_id]
|
||||
def _find_expert_layers(self) -> Dict[int, List[int]]:
|
||||
"""Find all layers and experts in the model"""
|
||||
|
|
@ -517,7 +517,9 @@ class OnlineQuantConverter(ConverterBase):
|
|||
quant_method: str = "int4",
|
||||
merge_to_safetensor: bool = True,
|
||||
):
|
||||
super().__init__(input_path, output_path, model_config, cpuinfer_threads, threadpool_count, input_type, merge_to_safetensor)
|
||||
super().__init__(
|
||||
input_path, output_path, model_config, cpuinfer_threads, threadpool_count, input_type, merge_to_safetensor
|
||||
)
|
||||
self.quant_method = quant_method
|
||||
|
||||
# For FP8, get block size from model_config
|
||||
|
|
@ -569,11 +571,11 @@ class OnlineQuantConverter(ConverterBase):
|
|||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
binary_data = f.read()
|
||||
|
||||
# Determine dtype based on file name
|
||||
if 'scale' in file_path:
|
||||
if "scale" in file_path:
|
||||
# Scale tensors are typically float32
|
||||
np_array = np.frombuffer(binary_data, dtype=np.float32)
|
||||
else:
|
||||
|
|
@ -616,22 +618,12 @@ class OnlineQuantConverter(ConverterBase):
|
|||
# Iterate through all experts
|
||||
for expert_id in range(self.num_experts):
|
||||
# For each projection (down, gate, up)
|
||||
proj_mappings = [
|
||||
('down', 'ffn_down_exps'),
|
||||
('gate', 'ffn_gate_exps'),
|
||||
('up', 'ffn_up_exps')
|
||||
]
|
||||
proj_mappings = [("down", "ffn_down_exps"), ("gate", "ffn_gate_exps"), ("up", "ffn_up_exps")]
|
||||
|
||||
for proj_name, proj_key in proj_mappings:
|
||||
# Build file patterns
|
||||
quant_pattern = os.path.join(
|
||||
numa_folder,
|
||||
f'{amx_method}_{proj_name}_{expert_id}_*Byte_quant_.kt'
|
||||
)
|
||||
scale_pattern = os.path.join(
|
||||
numa_folder,
|
||||
f'{amx_method}_{proj_name}_{expert_id}_*Byte_scale_.kt'
|
||||
)
|
||||
quant_pattern = os.path.join(numa_folder, f"{amx_method}_{proj_name}_{expert_id}_*Byte_quant_.kt")
|
||||
scale_pattern = os.path.join(numa_folder, f"{amx_method}_{proj_name}_{expert_id}_*Byte_scale_.kt")
|
||||
|
||||
# Find files using glob
|
||||
quant_files = glob.glob(quant_pattern)
|
||||
|
|
@ -705,18 +697,18 @@ class OnlineQuantConverter(ConverterBase):
|
|||
raise KeyError(f"Missing down weight_scale_inv for layer {layer_idx}, expert {expert_id}")
|
||||
|
||||
# Load FP8 weights and scales
|
||||
gate_fp8 = self._load_tensor(gate_key).to('cuda')
|
||||
up_fp8 = self._load_tensor(up_key).to('cuda')
|
||||
down_fp8 = self._load_tensor(down_key).to('cuda')
|
||||
gate_fp8 = self._load_tensor(gate_key).to("cuda")
|
||||
up_fp8 = self._load_tensor(up_key).to("cuda")
|
||||
down_fp8 = self._load_tensor(down_key).to("cuda")
|
||||
|
||||
gate_scale_inv = self._load_tensor(gate_scale_key).to('cuda')
|
||||
up_scale_inv = self._load_tensor(up_scale_key).to('cuda')
|
||||
down_scale_inv = self._load_tensor(down_scale_key).to('cuda')
|
||||
gate_scale_inv = self._load_tensor(gate_scale_key).to("cuda")
|
||||
up_scale_inv = self._load_tensor(up_scale_key).to("cuda")
|
||||
down_scale_inv = self._load_tensor(down_scale_key).to("cuda")
|
||||
|
||||
# Dequantize FP8 to BF16 using block-wise scaling
|
||||
gate_weight = weight_dequant(gate_fp8, gate_scale_inv).to('cpu').to(torch.bfloat16).contiguous()
|
||||
up_weight = weight_dequant(up_fp8, up_scale_inv).to('cpu').to(torch.bfloat16).contiguous()
|
||||
down_weight = weight_dequant(down_fp8, down_scale_inv).to('cpu').to(torch.bfloat16).contiguous()
|
||||
gate_weight = weight_dequant(gate_fp8, gate_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
|
||||
up_weight = weight_dequant(up_fp8, up_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
|
||||
down_weight = weight_dequant(down_fp8, down_scale_inv).to("cpu").to(torch.bfloat16).contiguous()
|
||||
|
||||
elif self.input_type == "fp16":
|
||||
# Load FP16 and convert to BF16
|
||||
|
|
@ -804,6 +796,7 @@ class OnlineQuantConverter(ConverterBase):
|
|||
print(f" Keeping layer folder structure at {self.output_path}/_layer_{layer_idx}")
|
||||
return {}
|
||||
|
||||
|
||||
"""
|
||||
Example usage(test passed):
|
||||
python convert_cpu_weights.py --input-path /mnt/data3/models/DeepSeek-R1-0528/ --input-type fp8 --output /mnt/data3/models/DeepSeek-R1-0528-INT4-test --quant-method int4 --cpuinfer-threads 60 --threadpool-count 2
|
||||
|
|
@ -811,6 +804,7 @@ python convert_cpu_weights.py --input-path /mnt/data3/models/DeepSeek-R1-0528/ -
|
|||
python convert_cpu_weights.py --input-path /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct --input-type bf16 --output /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct-INT4-test --quant-method int4 --cpuinfer-threads 60 --threadpool-count 2
|
||||
"""
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert SafeTensors to column major 1D format")
|
||||
parser.add_argument("--input-path", "-i", required=True, help="Input directory with safetensors")
|
||||
|
|
@ -873,12 +867,25 @@ def main():
|
|||
|
||||
if quant_method == "awq":
|
||||
converter = AWQToColumnMajorConverter(
|
||||
args.input_path, args.output, model_config, args.cpuinfer_threads, args.threadpool_count, input_type=None, merge_to_safetensor=merge_to_safetensor
|
||||
args.input_path,
|
||||
args.output,
|
||||
model_config,
|
||||
args.cpuinfer_threads,
|
||||
args.threadpool_count,
|
||||
input_type=None,
|
||||
merge_to_safetensor=merge_to_safetensor,
|
||||
)
|
||||
elif quant_method in ["int4", "int8"] and args.input_type in ["fp8", "fp16", "bf16"]:
|
||||
# Use OnlineQuantConverter for both INT4 and INT8 quantization
|
||||
converter = OnlineQuantConverter(
|
||||
args.input_path, args.output, model_config, args.cpuinfer_threads, args.threadpool_count, args.input_type, quant_method, merge_to_safetensor
|
||||
args.input_path,
|
||||
args.output,
|
||||
model_config,
|
||||
args.cpuinfer_threads,
|
||||
args.threadpool_count,
|
||||
args.input_type,
|
||||
quant_method,
|
||||
merge_to_safetensor,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue