mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-13 00:29:59 +00:00
[feature] support q2_k & q3_k dequantize on gpu
This commit is contained in:
parent
650c368c18
commit
7c4cb520bd
5 changed files with 161 additions and 12 deletions
|
@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
|
|||
Date : 2024-07-26 08:48:54
|
||||
Version : 1.0.0
|
||||
LastEditors : kkk1nak0
|
||||
LastEditTime : 2024-08-09 08:03:44
|
||||
LastEditTime : 2024-08-12 07:21:55
|
||||
Adapted from https://github.com/99991/pygguf/blob/main/gguf.py
|
||||
Copyright (c) 2023-2024 The ggml authors
|
||||
Copyright (c) 2024 Thomas Germer
|
||||
|
@ -390,8 +390,14 @@ def dequantize_q2_k(data):
|
|||
|
||||
return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
|
||||
|
||||
def dequantize_q2_k_gpu(data):
|
||||
raise NotImplementedError()
|
||||
def dequantize_q2_k_gpu(data, device:str ="cuda"):
|
||||
block_size = GGML_BLOCK_SIZES["Q2_K"]
|
||||
data = np.frombuffer(data, dtype=data.dtype)
|
||||
device = torch.device(device)
|
||||
# TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
|
||||
# the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
|
||||
data = torch.from_numpy(data)
|
||||
return KTransformersOps.dequantize_q2_k(data, block_size, device)
|
||||
|
||||
def dequantize_q3_k(data):
|
||||
# C implementation
|
||||
|
@ -435,8 +441,14 @@ def dequantize_q3_k(data):
|
|||
(((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
|
||||
], axis=1)
|
||||
|
||||
def dequantize_q3_k_gpu(data):
|
||||
raise NotImplementedError()
|
||||
def dequantize_q3_k_gpu(data, device:str ="cuda"):
|
||||
block_size = GGML_BLOCK_SIZES["Q3_K"]
|
||||
data = np.frombuffer(data, dtype=data.dtype)
|
||||
device = torch.device(device)
|
||||
# TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
|
||||
# the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
|
||||
data = torch.from_numpy(data)
|
||||
return KTransformersOps.dequantize_q3_k(data, block_size, device)
|
||||
|
||||
def dequantize_q4_k(data):
|
||||
# C implementation
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue