mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00
86 lines
No EOL
3.3 KiB
Python
86 lines
No EOL
3.3 KiB
Python
import struct
|
|
import warnings
|
|
import numpy as np
|
|
import re
|
|
import numpy.typing as npt
|
|
from typing import Sequence
|
|
import os
|
|
from enum import IntEnum
|
|
import torch
|
|
import KTransformersOps
|
|
from safetensors import safe_open
|
|
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
|
|
from safetensors.torch import save_file
|
|
|
|
class SafeTensorLoader:
    """Index all ``.safetensors`` files under a directory and load
    individual tensors by key on demand, optionally dequantizing FP8
    weights that ship with a companion ``.weight_scale_inv`` tensor.
    """

    # Class-level defaults kept so external code referencing
    # SafeTensorLoader.<map> still resolves; every instance gets its own
    # fresh dicts in __init__ (see note there).
    tensor_file_map = {}   # tensor key -> name of the file containing it
    tensor_type_map = {}   # reserved; not populated in this class
    file_handle_map = {}   # file name -> open safe_open handle

    def __init__(self, file_path: str):
        # BUG FIX: the original mutated the shared class-level dicts, so a
        # second SafeTensorLoader instance silently saw (and reused) the
        # first instance's files and handles. Give each instance its own
        # state instead.
        self.tensor_file_map = {}
        self.tensor_type_map = {}
        self.file_handle_map = {}
        self.__load_tensor_file_map(file_path)

    def __load_tensor_file_map(self, file_path: str):
        """Walk the folder and index every tensor key found in any
        ``.safetensors`` file.

        ``file_path`` may be a directory or a file inside it; a file path
        is reduced to its parent directory.

        Raises:
            FileNotFoundError: if ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Path not found: {file_path}")
        if os.path.isfile(file_path):
            folder_path = os.path.dirname(file_path)
        else:
            folder_path = file_path

        found_safetensor = False
        for root, _, files in os.walk(folder_path):
            # Sort for a deterministic mapping when the same key appears
            # in several files (the last file visited wins).
            for file in sorted(files):
                if not file.endswith(".safetensors"):
                    continue
                found_safetensor = True
                # Local name instead of rebinding the file_path parameter
                # (the original shadowed it inside the loop).
                full_path = os.path.join(root, file)
                if file not in self.file_handle_map:
                    try:
                        handle = safe_open(full_path, framework="pt")
                        self.file_handle_map[file] = handle
                    except Exception as e:
                        # Best-effort: skip unreadable files, keep indexing.
                        print(f"Error opening Safetensor file {full_path}: {e}")
                        continue

                f = self.file_handle_map.get(file)
                if f is None:
                    continue
                try:
                    for key in f.keys():
                        self.tensor_file_map[key] = file
                except Exception as e:
                    print(f"Error reading Safetensor file {full_path}: {e}")

        if not found_safetensor:
            # The original had a raise here commented out; warn instead so
            # callers get a signal without breaking best-effort usage.
            warnings.warn(f"No Safetensor files found in {folder_path}")

    def load_tensor(self, key: str, device: str = "cpu"):
        """Return the tensor stored under ``key``, moved to ``device``.

        Raises:
            KeyError: if ``key`` was not indexed.
            FileNotFoundError: if the owning file's handle is missing
                (e.g. after ``close_all_handles``).
        """
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key)
        return tensor.to(device)

    def close_all_handles(self):
        """Close every open safetensors handle and clear the handle map."""
        for handle in self.file_handle_map.values():
            handle.close()
        self.file_handle_map.clear()

    def load_dequantized_tensor(self, key: str, device: str = "cpu"):
        """Like :meth:`load_tensor`, but if ``key`` names a ``.weight``
        tensor with a companion ``.weight_scale_inv`` tensor, dequantize
        it via ``weight_dequant`` before returning.

        Raises:
            KeyError: if ``key`` was not indexed.
            FileNotFoundError: if the owning file's handle is missing.
        """
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key).to(device)
        if key.endswith(".weight"):
            scale_key = key[:-len(".weight")] + ".weight_scale_inv"
            if scale_key in self.tensor_file_map:
                # Companion inverse-scale tensor present -> FP8-quantized
                # weight; dequantize on the target device.
                weight_scale_inv = f.get_tensor(scale_key).to(device)
                tensor = weight_dequant(tensor, weight_scale_inv)
        return tensor.to(device)