mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 11:49:51 +00:00
Fix kt-kernel for new wrapper (#1588)
Some checks are pending
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Book-CI / test-2 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
Some checks are pending
Book-CI / test (push) Waiting to run
Book-CI / test-1 (push) Waiting to run
Book-CI / test-2 (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run
* update README for kt-kernel * style: format C++ and Python code in kt-kernel - Format C++ files: task_queue, ext_bindings, and MoE operators - Format Python utility modules: amx, llamafile, and loader - Improve code readability and consistency
This commit is contained in:
parent
9bc00e587b
commit
94c25626dc
10 changed files with 219 additions and 179 deletions
|
|
@ -6,15 +6,17 @@ import ctypes
|
|||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import SafeTensorLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
|
||||
try:
|
||||
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE
|
||||
|
||||
_HAS_AMX_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_AMX_SUPPORT = False
|
||||
AMXInt4_MOE, AMXInt8_MOE = None, None
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
|
||||
class AMXMoEWrapper(BaseMoEWrapper):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
import torch
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
# Use relative imports for package structure
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import GGUFLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
|
||||
try:
|
||||
from kt_kernel_ext.moe import MOE
|
||||
|
||||
_HAS_LLAMAFILE_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_LLAMAFILE_SUPPORT = False
|
||||
|
|
@ -14,6 +17,7 @@ except (ImportError, AttributeError):
|
|||
|
||||
from kt_kernel_ext.kvcache import ggml_type
|
||||
|
||||
|
||||
class LlamafileMoEWrapper(BaseMoEWrapper):
|
||||
"""
|
||||
Llamafile-based MoE wrapper implementation.
|
||||
|
|
@ -162,27 +166,17 @@ class LlamafileMoEWrapper(BaseMoEWrapper):
|
|||
)
|
||||
|
||||
if physical_to_logical_map_cpu is None:
|
||||
physical_to_logical_map_cpu = torch.arange(
|
||||
self.num_experts,
|
||||
dtype=torch.int32,
|
||||
device="cpu"
|
||||
)
|
||||
physical_to_logical_map_cpu = torch.arange(self.num_experts, dtype=torch.int32, device="cpu")
|
||||
print(f" Using default identity mapping for {self.num_experts} experts")
|
||||
|
||||
base_key = f"blk.{self.layer_idx}"
|
||||
|
||||
# Load quantized tensors from GGUF
|
||||
gate_data, gate_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
|
||||
f"{base_key}.ffn_gate_exps.weight"
|
||||
)
|
||||
gate_data, gate_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(f"{base_key}.ffn_gate_exps.weight")
|
||||
|
||||
up_data, up_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
|
||||
f"{base_key}.ffn_up_exps.weight"
|
||||
)
|
||||
up_data, up_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(f"{base_key}.ffn_up_exps.weight")
|
||||
|
||||
down_data, down_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(
|
||||
f"{base_key}.ffn_down_exps.weight"
|
||||
)
|
||||
down_data, down_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(f"{base_key}.ffn_down_exps.weight")
|
||||
|
||||
# Keep tensors alive
|
||||
self.weights_to_keep = (gate_data, up_data, down_data)
|
||||
|
|
|
|||
|
|
@ -18,35 +18,36 @@ from gguf.gguf_reader import GGUFReader
|
|||
|
||||
class GGMLQuantizationType(IntEnum):
|
||||
"""GGML quantization type enumeration"""
|
||||
F32 = 0
|
||||
F16 = 1
|
||||
Q4_0 = 2
|
||||
Q4_1 = 3
|
||||
Q5_0 = 6
|
||||
Q5_1 = 7
|
||||
Q8_0 = 8
|
||||
Q8_1 = 9
|
||||
Q2_K = 10
|
||||
Q3_K = 11
|
||||
Q4_K = 12
|
||||
Q5_K = 13
|
||||
Q6_K = 14
|
||||
Q8_K = 15
|
||||
|
||||
F32 = 0
|
||||
F16 = 1
|
||||
Q4_0 = 2
|
||||
Q4_1 = 3
|
||||
Q5_0 = 6
|
||||
Q5_1 = 7
|
||||
Q8_0 = 8
|
||||
Q8_1 = 9
|
||||
Q2_K = 10
|
||||
Q3_K = 11
|
||||
Q4_K = 12
|
||||
Q5_K = 13
|
||||
Q6_K = 14
|
||||
Q8_K = 15
|
||||
IQ2_XXS = 16
|
||||
IQ2_XS = 17
|
||||
IQ2_XS = 17
|
||||
IQ3_XXS = 18
|
||||
IQ1_S = 19
|
||||
IQ4_NL = 20
|
||||
IQ3_S = 21
|
||||
IQ2_S = 22
|
||||
IQ4_XS = 23
|
||||
I8 = 24
|
||||
I16 = 25
|
||||
I32 = 26
|
||||
I64 = 27
|
||||
F64 = 28
|
||||
IQ1_M = 29
|
||||
BF16 = 30
|
||||
IQ1_S = 19
|
||||
IQ4_NL = 20
|
||||
IQ3_S = 21
|
||||
IQ2_S = 22
|
||||
IQ4_XS = 23
|
||||
I8 = 24
|
||||
I16 = 25
|
||||
I32 = 26
|
||||
I64 = 27
|
||||
F64 = 28
|
||||
IQ1_M = 29
|
||||
BF16 = 30
|
||||
|
||||
|
||||
def translate_name_to_gguf(name):
|
||||
|
|
@ -104,6 +105,7 @@ class SafeTensorLoader:
|
|||
|
||||
Supports loading tensors from .safetensors files with NUMA-sharded expert weights.
|
||||
"""
|
||||
|
||||
tensor_file_map: dict
|
||||
tensor_type_map: dict
|
||||
file_handle_map: dict
|
||||
|
|
@ -257,7 +259,7 @@ class GGUFLoader:
|
|||
self.tensor_file_map = {}
|
||||
self.file_data_map = {}
|
||||
|
||||
if os.path.isfile(gguf_path) and gguf_path.endswith('.gguf'):
|
||||
if os.path.isfile(gguf_path) and gguf_path.endswith(".gguf"):
|
||||
print(f"\n[GGUFLoader] Loading single GGUF file : {os.path.basename(gguf_path)}")
|
||||
self._load_single_file(gguf_path)
|
||||
elif os.path.isdir(gguf_path):
|
||||
|
|
@ -283,24 +285,24 @@ class GGUFLoader:
|
|||
for key, field in reader.fields.items():
|
||||
value = field.parts[field.data[0]]
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf-8')
|
||||
value = value.decode("utf-8")
|
||||
elif isinstance(value, np.ndarray) and value.dtype == np.uint8:
|
||||
try:
|
||||
value = bytes(value).decode('utf-8')
|
||||
value = bytes(value).decode("utf-8")
|
||||
except:
|
||||
pass
|
||||
self.metadata[key] = value
|
||||
|
||||
for tensor in reader.tensors:
|
||||
self.tensor_info[tensor.name] = {
|
||||
'shape': list(reversed(tensor.shape)), # Reverse to match PyTorch order
|
||||
'dtype': tensor.tensor_type,
|
||||
'offset': tensor.data_offset,
|
||||
'n_elements': tensor.n_elements,
|
||||
"shape": list(reversed(tensor.shape)), # Reverse to match PyTorch order
|
||||
"dtype": tensor.tensor_type,
|
||||
"offset": tensor.data_offset,
|
||||
"n_elements": tensor.n_elements,
|
||||
}
|
||||
self.tensor_file_map[tensor.name] = file_path
|
||||
|
||||
self.file_data_map[file_path] = np.memmap(file_path, mode='r')
|
||||
self.file_data_map[file_path] = np.memmap(file_path, mode="r")
|
||||
|
||||
def _load_directory(self, dir_path: str):
|
||||
"""Load all GGUF files from a directory (non-recursive)"""
|
||||
|
|
@ -317,24 +319,24 @@ class GGUFLoader:
|
|||
for key, field in reader.fields.items():
|
||||
value = field.parts[field.data[0]]
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf-8')
|
||||
value = value.decode("utf-8")
|
||||
elif isinstance(value, np.ndarray) and value.dtype == np.uint8:
|
||||
try:
|
||||
value = bytes(value).decode('utf-8')
|
||||
value = bytes(value).decode("utf-8")
|
||||
except:
|
||||
pass
|
||||
self.metadata[key] = value
|
||||
|
||||
for tensor in reader.tensors:
|
||||
self.tensor_info[tensor.name] = {
|
||||
'shape': list(reversed(tensor.shape)),
|
||||
'dtype': tensor.tensor_type,
|
||||
'offset': tensor.data_offset,
|
||||
'n_elements': tensor.n_elements,
|
||||
"shape": list(reversed(tensor.shape)),
|
||||
"dtype": tensor.tensor_type,
|
||||
"offset": tensor.data_offset,
|
||||
"n_elements": tensor.n_elements,
|
||||
}
|
||||
self.tensor_file_map[tensor.name] = file_path
|
||||
|
||||
self.file_data_map[file_path] = np.memmap(file_path, mode='r')
|
||||
self.file_data_map[file_path] = np.memmap(file_path, mode="r")
|
||||
|
||||
if not found_gguf:
|
||||
raise FileNotFoundError(f"No .gguf files found in directory: {dir_path}")
|
||||
|
|
@ -407,7 +409,7 @@ class GGUFLoader:
|
|||
|
||||
base_key = f"blk.{layer_idx}.ffn_gate_exps.weight"
|
||||
if base_key in self.tensor_info:
|
||||
gate_shape = self.tensor_info[base_key]['shape']
|
||||
gate_shape = self.tensor_info[base_key]["shape"]
|
||||
print(f" Found tensor '{base_key}' with shape: {gate_shape}")
|
||||
|
||||
if len(gate_shape) >= 3:
|
||||
|
|
@ -438,8 +440,9 @@ class GGUFLoader:
|
|||
print(f" Total metadata entries: {len(self.metadata)}")
|
||||
|
||||
if filter_keywords:
|
||||
filtered = {k: v for k, v in self.metadata.items()
|
||||
if any(kw.lower() in k.lower() for kw in filter_keywords)}
|
||||
filtered = {
|
||||
k: v for k, v in self.metadata.items() if any(kw.lower() in k.lower() for kw in filter_keywords)
|
||||
}
|
||||
for k, v in sorted(filtered.items()):
|
||||
print(f" {k}: {v}")
|
||||
else:
|
||||
|
|
@ -477,40 +480,40 @@ class GGUFLoader:
|
|||
file_path = self.tensor_file_map[name]
|
||||
mmap_data = self.file_data_map[file_path]
|
||||
|
||||
offset = info['offset']
|
||||
n_elements = info['n_elements']
|
||||
ggml_type = info['dtype']
|
||||
offset = info["offset"]
|
||||
n_elements = info["n_elements"]
|
||||
ggml_type = info["dtype"]
|
||||
|
||||
GGML_QUANT_SIZES = {
|
||||
GGMLQuantizationType.F32: (1, 4),
|
||||
GGMLQuantizationType.F16: (1, 2),
|
||||
GGMLQuantizationType.BF16: (1, 2),
|
||||
GGMLQuantizationType.Q4_0: (32, 2 + 16),
|
||||
GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
|
||||
GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
|
||||
GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
|
||||
GGMLQuantizationType.Q8_0: (32, 2 + 32),
|
||||
GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
|
||||
GGMLQuantizationType.Q2_K: (256, 2 + 2 + 256 // 16 + 256 // 4),
|
||||
GGMLQuantizationType.Q3_K: (256, 2 + 256 // 4 + 256 // 8 + 12),
|
||||
GGMLQuantizationType.Q4_K: (256, 2 + 2 + 256 // 2 + 12),
|
||||
GGMLQuantizationType.Q5_K: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),
|
||||
GGMLQuantizationType.Q6_K: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),
|
||||
GGMLQuantizationType.Q8_K: (256, 4 + 256 + 256 // 8),
|
||||
GGMLQuantizationType.F32: (1, 4),
|
||||
GGMLQuantizationType.F16: (1, 2),
|
||||
GGMLQuantizationType.BF16: (1, 2),
|
||||
GGMLQuantizationType.Q4_0: (32, 2 + 16),
|
||||
GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
|
||||
GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
|
||||
GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
|
||||
GGMLQuantizationType.Q8_0: (32, 2 + 32),
|
||||
GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
|
||||
GGMLQuantizationType.Q2_K: (256, 2 + 2 + 256 // 16 + 256 // 4),
|
||||
GGMLQuantizationType.Q3_K: (256, 2 + 256 // 4 + 256 // 8 + 12),
|
||||
GGMLQuantizationType.Q4_K: (256, 2 + 2 + 256 // 2 + 12),
|
||||
GGMLQuantizationType.Q5_K: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),
|
||||
GGMLQuantizationType.Q6_K: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),
|
||||
GGMLQuantizationType.Q8_K: (256, 4 + 256 + 256 // 8),
|
||||
GGMLQuantizationType.IQ2_XXS: (256, 2 + 256 // 4),
|
||||
GGMLQuantizationType.IQ2_XS: (256, 2 + 256 // 4 + 256 // 32),
|
||||
GGMLQuantizationType.IQ2_XS: (256, 2 + 256 // 4 + 256 // 32),
|
||||
GGMLQuantizationType.IQ3_XXS: (256, 2 + 256 // 4 + 256 // 8),
|
||||
GGMLQuantizationType.IQ1_S: (256, 2 + 256 // 8 + 256 // 16),
|
||||
GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
|
||||
GGMLQuantizationType.IQ3_S: (256, 2 + 256 // 4 + 256 // 8 + 256 // 32 + 4),
|
||||
GGMLQuantizationType.IQ2_S: (256, 2 + 256 // 4 + 256 // 16),
|
||||
GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + 256 // 2 + 256 // 64),
|
||||
GGMLQuantizationType.I8: (1, 1),
|
||||
GGMLQuantizationType.I16: (1, 2),
|
||||
GGMLQuantizationType.I32: (1, 4),
|
||||
GGMLQuantizationType.I64: (1, 8),
|
||||
GGMLQuantizationType.F64: (1, 8),
|
||||
GGMLQuantizationType.IQ1_M: (256, 256 // 8 + 256 // 16 + 256 // 32),
|
||||
GGMLQuantizationType.IQ1_S: (256, 2 + 256 // 8 + 256 // 16),
|
||||
GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
|
||||
GGMLQuantizationType.IQ3_S: (256, 2 + 256 // 4 + 256 // 8 + 256 // 32 + 4),
|
||||
GGMLQuantizationType.IQ2_S: (256, 2 + 256 // 4 + 256 // 16),
|
||||
GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + 256 // 2 + 256 // 64),
|
||||
GGMLQuantizationType.I8: (1, 1),
|
||||
GGMLQuantizationType.I16: (1, 2),
|
||||
GGMLQuantizationType.I32: (1, 4),
|
||||
GGMLQuantizationType.I64: (1, 8),
|
||||
GGMLQuantizationType.F64: (1, 8),
|
||||
GGMLQuantizationType.IQ1_M: (256, 256 // 8 + 256 // 16 + 256 // 32),
|
||||
}
|
||||
|
||||
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue