# Mirror of https://github.com/kvcache-ai/ktransformers.git
# (synced 2025-09-10 06:14:58 +00:00; Python, 298 lines, 14 KiB, no trailing newline)
from abc import abstractmethod
|
|
|
|
import torch
|
|
import torch_npu
|
|
import torch.distributed as dist
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
|
|
from ktransformers.operators.base_operator import BaseInjectedModule
|
|
from ktransformers.operators.linear import KLinearBase, LINEAR_MAP
|
|
from ktransformers.util.ascend.ascend_utils import (
|
|
get_safetensors_cut_weight,
|
|
get_tensor_parallel_size,
|
|
get_tensor_parallel_group
|
|
)
|
|
from ktransformers.util import utils
|
|
from ktransformers.util.custom_gguf import GGUFLoader
|
|
from ktransformers.util.utils import InferenceState
|
|
|
|
|
|
class KLinearW8A8(KLinearBase):
    """Base class for Ascend W8A8 linear operators.

    Implements :meth:`load_weight`, which reads the tensors for ``self.key`` from
    ``gguf_loader.safetensor_loader``. Subclasses implement :meth:`load` and
    :meth:`unload`.
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cuda",
        **kwargs,
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)

    def load_weight(self, override_key: str | list[str] | None = None, device: str | None = None):
        """Read the raw tensors backing this linear layer.

        Args:
            override_key: Tensor-name prefix, or list of prefixes, to read instead
                of ``self.key``. A single string is treated as one key.
            device: Accepted for interface compatibility and not used here; the
                tensors are returned as produced by ``load_tensor`` and moved to a
                device by the subclass ``load``.

        Returns:
            For the first key, depending on which tensors exist for it:

            * ``(weight, deq_scale, quant_bias, input_scale, input_offset)`` when
              ``<key>.deq_scale`` exists;
            * when ``<key>.weight_scale`` exists:
              - for ``...ffn_gate_shexp`` keys, the fused gate/up shared-expert
                tensors ``(gate_up_weight, gate_up_scale, gate_up_offset)``;
              - for ``...ffn_up_shexp`` keys, a placeholder ``torch.tensor([1])``;
              - otherwise ``(weight, weight_scale, weight_offset)``;
            * otherwise ``<key>.weight`` passed through ``get_safetensors_cut_weight``.

            ``None`` if the key list is empty.

        Raises:
            FileNotFoundError: if ``<key>.weight`` is not in the safetensors map.
        """
        if override_key is None:
            keys = [self.key]
        elif isinstance(override_key, str):
            # A bare string must not be iterated character by character.
            keys = [override_key]
        else:
            keys = override_key

        loader = self.gguf_loader.safetensor_loader
        available = loader.tensor_file_map
        for key in keys:
            # Every path below returns or raises, so only the first key is consulted.
            if key + ".weight" not in available:
                raise FileNotFoundError(f"Weight file not found for key {key}")

            if key + ".deq_scale" in available:
                return (
                    loader.load_tensor(f"{key}.weight"),
                    loader.load_tensor(f"{key}.deq_scale"),
                    loader.load_tensor(f"{key}.quant_bias"),
                    loader.load_tensor(f"{key}.input_scale"),
                    loader.load_tensor(f"{key}.input_offset"),
                )

            if key + ".weight_scale" in available:
                if key.endswith("ffn_gate_shexp"):
                    return self._load_fused_shexp_gate_up(key)
                if key.endswith("ffn_up_shexp"):
                    # The up projection is already loaded together with the gate
                    # projection for the matching ffn_gate_shexp key, so only a
                    # placeholder is returned here.
                    return torch.tensor([1])
                return (
                    loader.load_tensor(f"{key}.weight"),
                    loader.load_tensor(f"{key}.weight_scale"),
                    loader.load_tensor(f"{key}.weight_offset"),
                )

            weight = loader.load_tensor(f"{key}.weight")
            return get_safetensors_cut_weight(self.key, weight)
        return None

    def _load_fused_shexp_gate_up(self, key: str):
        """Fuse the shared-expert gate and up projections of ``key``'s layer.

        The layer index is the second dot-separated field of ``key``; tensors are
        read from ``blk.<layer>.ffn_gate_shexp.*`` and ``blk.<layer>.ffn_up_shexp.*``.
        Weights are cut, transposed and concatenated along dim 1; scales are cut
        and concatenated along dim 0; offsets are concatenated along dim 0.

        Returns:
            ``(gate_up_weight, gate_up_scale, gate_up_offset)``.
        """
        loader = self.gguf_loader.safetensor_loader
        prefix = f"blk.{key.split('.')[1]}"

        def cut(name: str) -> torch.Tensor:
            # Slicing is decided by get_safetensors_cut_weight, keyed on self.key.
            return get_safetensors_cut_weight(self.key, loader.load_tensor(f"{prefix}.{name}"))

        gate_weight = cut("ffn_gate_shexp.weight").t()
        up_weight = cut("ffn_up_shexp.weight").t()
        gate_scale = cut("ffn_gate_shexp.weight_scale")
        up_scale = cut("ffn_up_shexp.weight_scale")
        gate_up_weight = torch.cat((gate_weight, up_weight), 1)
        gate_up_scale = torch.cat((gate_scale, up_scale), 0)

        # NOTE(review): offsets are concatenated without get_safetensors_cut_weight,
        # unlike the weights and scales above — confirm this is intended.
        gate_offset = loader.load_tensor(f"{prefix}.ffn_gate_shexp.weight_offset")
        up_offset = loader.load_tensor(f"{prefix}.ffn_up_shexp.weight_offset")
        gate_up_offset = torch.cat((gate_offset, up_offset), 0)
        return gate_up_weight, gate_up_scale, gate_up_offset

    @abstractmethod
    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = "cuda"):
        """Load ``w`` (or the result of :meth:`load_weight`) onto ``device``."""

    @abstractmethod
    def unload(self):
        """Release the tensors held by this operator."""
|
|
|
|
|
|
class KLinearTorchW8A8A2(KLinearW8A8):
    """W8A8 linear operator for Ascend NPUs built on ``torch_npu`` kernels.

    ``load`` accepts either a quantized tuple produced by :meth:`load_weight` or a
    plain weight tensor. ``forward`` multiplies by ``self.weight`` with
    ``torch_npu._npu_matmul_pp`` and, when the tensor-parallel size is > 1,
    redistributes the results over the tensor-parallel group.
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cuda",
        **kwargs,
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        self.has_bias = False
        self.dtype = torch.get_default_dtype()
        # Tensors populated by load(); which ones are set depends on the weight layout.
        self.weight = None
        self.input_scale = None
        self.input_offset = None
        self.quant_bias = None
        self.deq_scale = None
        self.weight_scale = None
        self.weight_offset = None

    def forward(self, x: torch.Tensor, bsz_tensor) -> torch.Tensor:
        """Compute ``x @ self.weight`` with ``torch_npu._npu_matmul_pp``.

        With a tensor-parallel size of 1 the result is a float16 tensor of shape
        ``(x.shape[0], x.shape[1], self.weight.shape[-1])``. Otherwise every rank's
        ``x`` is all-gathered and multiplied by the local ``self.weight``. The
        per-rank results are then exchanged with ``all_to_all_single``, so each
        rank receives every rank's output columns for its own input. The result
        is reshaped to ``(batch, seq, -1)``.

        ``bsz_tensor`` is accepted for interface compatibility and not used.
        """
        tp_size = get_tensor_parallel_size()
        if tp_size == 1:
            out = torch.zeros((x.shape[0], x.shape[1], self.weight.shape[-1]), dtype=torch.float16, device=x.device)
            torch_npu._npu_matmul_pp(x, self.weight, out)
            return out

        tp_group = get_tensor_parallel_group()
        batch_size, seq_length = x.shape[0], x.shape[1]

        gathered = [torch.empty_like(x) for _ in range(tp_size)]
        dist.all_gather(gathered, x, group=tp_group)
        # Stacked shape is (tp, batch, seq, hidden); dim 1 (batch) is squeezed.
        # NOTE(review): squeeze(dim=1) is a no-op when batch_size > 1, which would
        # leave a 4-D input below — this path appears to assume batch_size == 1.
        input_full = torch.stack(gathered, dim=0).squeeze(dim=1)
        # Format id 2 is presumably ACL ND — confirm against torch_npu docs.
        torch_npu.npu_format_cast_(input_full, 2)

        local_logits = torch.zeros((input_full.shape[0], input_full.shape[1], self.weight.shape[-1]),
                                   dtype=torch.float16, device=input_full.device)
        torch_npu._npu_matmul_pp(input_full, self.weight, local_logits)

        # Rows grouped by source rank so all_to_all_single sends each rank its own slice.
        send = local_logits.transpose(2, 1).reshape(-1, batch_size * seq_length)
        del local_logits
        recv = torch.empty_like(send)
        dist.all_to_all_single(recv, send, group=tp_group)
        del send
        return recv.transpose(1, 0).view(batch_size, seq_length, -1)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        """Move weights onto the current NPU and store them on this operator.

        Args:
            w: ``None`` (call :meth:`load_weight`), an ``nn.Parameter`` or plain
                tensor holding an unquantized weight, or a tuple from
                :meth:`load_weight`: 3 items ``(weight, weight_scale, weight_offset)``
                or otherwise ``(weight, deq_scale, quant_bias, input_scale, input_offset)``.
            device: Ignored (see the note below).

        Raises:
            ValueError: if ``w`` has an unsupported type.
        """
        # NOTE: both ``device`` and ``self.device`` are ignored; every tensor is
        # placed on utils.CUR_DEVICE. Confirm intent before honoring ``device``.
        device = utils.CUR_DEVICE
        if w is None:
            w = self.load_weight()

        if isinstance(w, nn.Parameter):
            try:
                weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T.contiguous()
            except Exception:
                # Fall back when the stored tensor cannot be viewed as (out_features, in_features).
                weight = w.to(dtype=self.dtype).T.contiguous()
            self.weight = weight.to(device)
            if self.has_bias:
                self.bias = self.bias.to(device)
        elif isinstance(w, tuple):
            parts = list(w)
            if len(parts) == 3:
                weight = parts[0].to(device)
                weight_scale = parts[1].view(-1).to(device)
                if not self.key.endswith("ffn_gate_shexp"):
                    # ffn_gate_shexp tensors were already cut (and the weight
                    # transposed) by load_weight's gate/up fusion.
                    weight = get_safetensors_cut_weight(self.key, weight).t()
                    weight_scale = get_safetensors_cut_weight(self.key, weight_scale).clone()
                self.weight = weight
                self.weight_scale = weight_scale
                self.weight_offset = parts[2].to(device)
            else:
                cut = [get_safetensors_cut_weight(self.key, t).to(device) for t in parts]
                self.weight = cut[0]
                self.deq_scale = cut[1]
                self.quant_bias = cut[2]
                if "attn_output" in self.key or "ffn_down" in self.key:
                    # Only TP rank 0 keeps the bias for these layers — presumably
                    # because their outputs are summed across ranks; confirm.
                    if dist.get_rank(get_tensor_parallel_group()) != 0:
                        self.quant_bias = torch.zeros_like(self.quant_bias)
                self.input_scale = cut[3]
                self.input_offset = cut[4]
        elif isinstance(w, torch.Tensor):
            self.weight = w.T.contiguous()
            # NOTE(review): kv_b weights are left on the device they were loaded on
            # and are not format-cast — confirm this is intended.
            if "kv_b" not in self.key:
                self.weight = self.weight.to(device)
                # Format id 29 is presumably ACL FRACTAL_NZ — confirm against torch_npu docs.
                torch_npu.npu_format_cast_(self.weight, 29)
        else:
            raise ValueError(f"Invalid weight type {self.key=} {type(w)=}")

    def unload(self):
        """Drop references to every tensor loaded by :meth:`load`."""
        if self.weight is not None:
            self.weight = None
        if self.has_bias:
            self.bias = None
        self.input_scale = None
        self.input_offset = None
        self.quant_bias = None
        self.deq_scale = None
        self.weight_scale = None
        self.weight_offset = None
|
|
|
|
|
|
# Register the operator under its class name so it can be selected through
# LINEAR_MAP, e.g. as the prefill_op / generate_op of KTransformersLinearW8A8A2.
LINEAR_MAP["KLinearTorchW8A8A2"] = KLinearTorchW8A8A2
|
|
|
|
|
|
class KTransformersLinearW8A8A2(BaseInjectedModule, KLinearW8A8):
    """Injected W8A8 linear that switches between a prefill and a generate operator.

    Both operators are looked up by name in ``LINEAR_MAP``; either may be disabled
    by passing ``None`` as its op name. ``load`` activates one operator, unloads the
    other, and mirrors the active operator's tensors onto this module so callers
    that read e.g. ``linear.weight`` keep working.
    """

    # Tensor attributes mirrored from the active operator in load().
    _MIRRORED_ATTRS = ("weight", "input_scale", "input_offset", "quant_bias", "deq_scale", "weight_scale", "weight_offset")

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        generate_device: str = "cuda",
        generate_op: str | None = "KLinearMarlin",
        prefill_device: str = "cuda",
        prefill_op: str | None = "KLinearTorch",
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KLinearW8A8.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        # build all the linear operators
        if prefill_op is not None:
            assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
            self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        else:
            self.prefill_linear = None

        if generate_op is not None:
            assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported"
            self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
        else:
            self.generate_linear = None
        self.mode = InferenceState.UNLOAD

    def forward(self, x, bsz_tensor=None):
        """Run the prefill operator in PREFILL mode, otherwise the generate operator."""
        if self.mode == InferenceState.PREFILL:
            assert self.prefill_linear is not None, "cpu linear is not initialized"
            y = self.prefill_linear.forward(x, bsz_tensor)
        else:
            assert self.generate_linear is not None, "gpu linear is not initialized"
            y = self.generate_linear.forward(x, bsz_tensor)
        return y

    def _configured_linears(self):
        """Return the operators that exist, prefill first."""
        return [op for op in (self.prefill_linear, self.generate_linear) if op is not None]

    def _release_mirrored_attrs(self):
        """Drop this module's references to the last active operator's tensors."""
        for name in self._MIRRORED_ATTRS:
            setattr(self, name, None)

    def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
        """Activate the operator for ``mode`` and load its weights.

        Args:
            w: Forwarded to the active operator's ``load``; ``None`` lets the
                operator read its own weights.
            mode: PREFILL or GENERATE activates that operator (unloading the
                other); UNLOAD unloads both. ``None`` is treated as GENERATE.

        Raises:
            RuntimeError: if the operator for ``mode`` was not configured.
            ValueError: for any other mode.
        """
        if mode is None:
            mode = InferenceState.GENERATE

        if mode == InferenceState.PREFILL:
            active, inactive = self.prefill_linear, self.generate_linear
        elif mode == InferenceState.GENERATE:
            active, inactive = self.generate_linear, self.prefill_linear
        elif mode == InferenceState.UNLOAD:
            for op in self._configured_linears():
                op.unload()
            self._release_mirrored_attrs()
            self.device = "cpu"
            self.mode = mode
            return
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")

        if active is None:
            raise RuntimeError(f"linear operator for {mode} is not initialized (key={self.key})")
        if inactive is not None:
            inactive.unload()
        active.load(w=w)
        self.device = active.device
        # modeling_xxx.py may use linear.weight; attributes the active operator
        # does not define are mirrored as None.
        for name in self._MIRRORED_ATTRS:
            setattr(self, name, getattr(active, name, None))
        self.mode = mode

    def unload(self):
        """Unload both operators and release the mirrored tensor references."""
        for op in self._configured_linears():
            op.unload()
        self._release_mirrored_attrs()
        if self.generate_linear is not None:
            self.device = self.generate_linear.device

    def set_inference_mode(self, mode: InferenceState):
        """Switch to ``mode`` (``None`` means GENERATE) by loading or unloading operators."""
        if mode is None:
            mode = InferenceState.GENERATE
        if mode in (InferenceState.GENERATE, InferenceState.PREFILL):
            self.load(mode=mode)
        elif mode == InferenceState.UNLOAD:
            self.unload()
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")