support npu

djw 2025-07-21 12:26:14 +00:00
parent dd0e41b3b8
commit 7d51a13c9b
34 changed files with 14004 additions and 5626 deletions


@@ -7,10 +7,19 @@ from typing import Sequence
 import os
 from enum import IntEnum
 import torch
-if not torch.xpu.is_available():
+try:
+    import torch_npu
+    use_torch_npu = torch_npu.npu.is_available()
+except:
+    use_torch_npu = False
+if not torch.xpu.is_available() and not use_torch_npu:
     import KTransformersOps
 from safetensors import safe_open
-from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
+if not use_torch_npu:
+    from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
 from ktransformers.util.custom_gguf import *
 from safetensors.torch import save_file
 from abc import ABC, abstractmethod
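For context, here is a minimal, self-contained sketch (not part of the commit) of how this probe pattern is typically consumed downstream; `pick_device` is a hypothetical helper, not a function from ktransformers:

```python
import torch

# Probe for the Ascend NPU plugin first; torch_npu is absent on
# non-Ascend hosts, so the import itself can fail.
try:
    import torch_npu
    use_torch_npu = torch_npu.npu.is_available()
except ImportError:  # the commit uses a bare `except`, which also works
    use_torch_npu = False

def pick_device() -> str:
    """Hypothetical helper: map the probes above to a device string."""
    if use_torch_npu:
        return "npu"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

print(pick_device())  # e.g. "cpu" on a machine with no accelerator
```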
@@ -42,6 +51,7 @@ class SafeTensorLoader(ModelLoader):
     tensor_device_map: dict
 
     def __init__(self, file_path: str):
         self.__load_tensor_file_map(file_path)
 
     def __load_tensor_file_map(self, file_path: str):
@@ -84,6 +94,7 @@ class SafeTensorLoader(ModelLoader):
         # if not found_safetensor:
         #     raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
 
     def load_tensor(self, key: str, device: str="cpu"):
         if translate_name_to_gguf(key) in self.tensor_file_map:
             key = translate_name_to_gguf(key)
@@ -96,6 +107,7 @@ class SafeTensorLoader(ModelLoader):
         if f is None:
             raise FileNotFoundError(f"File {file} not found in Safetensor files")
         tensor = f.get_tensor(key)
         return tensor.to(device)
 
     def load_experts(self, key: str, device: str="cpu"):
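The two hunks above implement a name-fallback lookup: `load_tensor` first tries the key under its GGUF-translated alias, then under the raw safetensors name. A hedged standalone sketch of that resolution order, with `translate` standing in for `translate_name_to_gguf`:

```python
from typing import Callable, Dict

def resolve_key(key: str,
                tensor_file_map: Dict[str, str],
                translate: Callable[[str], str] = lambda k: k) -> str:
    """Prefer the GGUF-translated alias when present, else the raw key.

    `translate` is a stand-in for ktransformers' translate_name_to_gguf;
    the identity default keeps this sketch self-contained.
    """
    if translate(key) in tensor_file_map:
        return translate(key)
    if key not in tensor_file_map:
        raise KeyError(f"Key {key} not found in Safetensor files")
    return key

# Usage: resolve_key("model.embed_tokens.weight",
#                    {"token_embd.weight": "f0.safetensors"},
#                    translate=lambda k: "token_embd.weight")
# -> "token_embd.weight" (the GGUF alias wins when it is present)
```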
@@ -252,20 +264,57 @@
     def has_tensor(self, name: str):
         return name in self.tensor_file_map or translate_name_to_gguf(name) in self.tensor_file_map
 
+class W8A8SafeTensorLoader(SafeTensorLoader):
+    def load_tensor(self, key: str, device: str = "cpu"):
+        if key not in self.tensor_file_map:
+            raise KeyError(f"Key {key} not found in Safetensor files")
+        file = self.tensor_file_map[key]
+        f = self.file_handle_map.get(file)
+        if f is None:
+            raise FileNotFoundError(f"File {file} not found in Safetensor files")
+        tensor = f.get_tensor(key)
+        if "deq_scale" in key:
+            tensor = torch.from_numpy(
+                np.frombuffer(tensor.to(torch.float16).to(torch.float32).numpy().tobytes(),
+                              dtype=np.int32).astype(np.int64))
+        if "input_scale" in key:
+            tensor = tensor.to(torch.float16)
+        if "weight_scale" in key or "weight_offset" in key:
+            if "ffn" in key:
+                tensor = tensor.to(torch.float32)
+            else:
+                tensor = tensor.to(torch.float16)
+        if "input_offset" in key:
+            tensor = tensor.to(torch.int8)
+        if tensor.dtype == torch.bfloat16:
+            tensor = tensor.to(torch.float16)
+        return tensor.to(device)
+
+    def load_dequantized_tensor(self, key: str, device: str = "cpu"):
+        tensor = self.load_tensor(key, device)
+        return tensor
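The `deq_scale` branch above performs a bit-level reinterpretation rather than a numeric cast: the float32 byte buffer is re-read as int32 and widened to int64 (my reading: the Ascend quantized-matmul API expects the dequantization scale as an int64 view of float32 bits; the diff itself does not say why). A self-contained sketch of the same round trip, plus the copy-free equivalent via `Tensor.view`:

```python
import numpy as np
import torch

scale = torch.tensor([0.0125, 1.0], dtype=torch.float32)

# As in the commit: serialize the float32 buffer, re-read the identical
# bytes as int32, then widen to int64 (values are bit patterns, not 0.0125).
as_int64 = torch.from_numpy(
    np.frombuffer(scale.numpy().tobytes(), dtype=np.int32).astype(np.int64)
)

# Copy-free equivalent: reinterpret the storage in place, then widen.
same = scale.view(torch.int32).to(torch.int64)

assert torch.equal(as_int64, same)
print(as_int64)  # int64 tensor holding the float32 bit patterns
```

Note that the commit first routes the stored tensor through float16 and back to float32 before reinterpreting, which quantizes the scale to half precision; the sketch starts from float32 directly.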
 
 class GGUFLoader(ModelLoader):
     tensor_info: dict
     gguf_path: str
     tensor_file_map: dict # {tensor_name: tensor_file_path}
     gguf_file_meta: dict
     safetensor_loader: SafeTensorLoader
 
-    def __init__(self, gguf_path: str):
+    def __init__(self, gguf_path: str, quantize: str = None):
         # Check that the dir exists
         if not os.path.exists(gguf_path):
             raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
         if os.path.isfile(gguf_path):
             gguf_path = os.path.dirname(gguf_path)
 
         self.safetensor_loader = None
 
-        safetensor_loader = SafeTensorLoader(gguf_path)
+        if quantize == "w8a8_dynamic":
+            safetensor_loader = W8A8SafeTensorLoader(gguf_path)
+        else:
+            safetensor_loader = SafeTensorLoader(gguf_path)
         if safetensor_loader.tensor_file_map:
             self.safetensor_loader = safetensor_loader
             return
         self.tensor_info = {}
         self.gguf_path = gguf_path
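A usage sketch of the new constructor argument; the paths and the tensor key are hypothetical, and the module path is assumed rather than taken from the diff:

```python
# Module path assumed; adjust to wherever GGUFLoader is defined.
from ktransformers.util.custom_loader import GGUFLoader

# Default path: plain SafeTensorLoader semantics.
loader = GGUFLoader("/models/DeepSeek-V3")  # hypothetical checkpoint dir

# NPU w8a8 path: W8A8SafeTensorLoader applies the per-key dtype fixups
# (deq_scale -> int64 bit view, input_offset -> int8, bf16 -> fp16, ...).
w8a8 = GGUFLoader("/models/DeepSeek-V3-w8a8", quantize="w8a8_dynamic")
t = w8a8.safetensor_loader.load_tensor(
    "model.layers.0.self_attn.q_proj.deq_scale"  # hypothetical key
)
```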