mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-11 07:44:35 +00:00
Ensure backward compatibility with Torch 2.2
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
This commit is contained in:
parent
4b5991e77e
commit
f88c05a6f1
2 changed files with 32 additions and 25 deletions
|
@ -1,9 +1,9 @@
|
||||||
/**
|
/**
|
||||||
* @Description :
|
* @Description :
|
||||||
* @Author : Azure-Tang, Boxin Zhang
|
* @Author : Azure-Tang, Boxin Zhang
|
||||||
* @Date : 2024-07-25 13:38:30
|
* @Date : 2024-07-25 13:38:30
|
||||||
* @Version : 0.2.2
|
* @Version : 0.2.2
|
||||||
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||||
**/
|
**/
|
||||||
|
|
||||||
#include "custom_gguf/ops.h"
|
#include "custom_gguf/ops.h"
|
||||||
|
@ -20,38 +20,45 @@
|
||||||
|
|
||||||
PYBIND11_MODULE(KTransformersOps, m) {
|
PYBIND11_MODULE(KTransformersOps, m) {
|
||||||
|
|
||||||
m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q8_0 data.",
|
}, "Function to dequantize q8_0 data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q6_k data.",
|
}, "Function to dequantize q6_k data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q5_k data.",
|
}, "Function to dequantize q5_k data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q4_k data.",
|
}, "Function to dequantize q4_k data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q3_k data.",
|
}, "Function to dequantize q3_k data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize q2_k data.",
|
}, "Function to dequantize q2_k data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
|
m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {
|
||||||
return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
|
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);
|
||||||
|
return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);
|
||||||
}, "Function to dequantize iq4_xs data.",
|
}, "Function to dequantize iq4_xs data.",
|
||||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
/**
|
/**
|
||||||
* @Description :
|
* @Description :
|
||||||
* @Author : Azure-Tang
|
* @Author : Azure-Tang
|
||||||
* @Date : 2024-07-22 09:27:55
|
* @Date : 2024-07-22 09:27:55
|
||||||
* @Version : 1.0.0
|
* @Version : 1.0.0
|
||||||
* @LastEditors : kkk1nak0
|
* @LastEditors : kkk1nak0
|
||||||
* @LastEditTime : 2024-08-12 03:48:46
|
* @LastEditTime : 2024-08-12 03:48:46
|
||||||
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||||
**/
|
**/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
@ -13,10 +13,10 @@
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
|
|
||||||
torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
|
torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
|
||||||
torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
|
torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
|
||||||
torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
|
torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
|
||||||
torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
|
torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
|
||||||
torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
|
torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
|
||||||
torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
|
torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
|
||||||
torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
|
torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue