diff --git a/ktransformers/ktransformers_ext/cuda/binding.cpp b/ktransformers/ktransformers_ext/cuda/binding.cpp index 0b1994d..5bba873 100644 --- a/ktransformers/ktransformers_ext/cuda/binding.cpp +++ b/ktransformers/ktransformers_ext/cuda/binding.cpp @@ -1,9 +1,9 @@ /** - * @Description : + * @Description : * @Author : Azure-Tang, Boxin Zhang * @Date : 2024-07-25 13:38:30 * @Version : 0.2.2 - * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "custom_gguf/ops.h" @@ -20,38 +20,45 @@ PYBIND11_MODULE(KTransformersOps, m) { - m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q8_0 data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q6_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q5_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q4_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q3_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q2_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); - m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { - return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); + m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); + return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize iq4_xs data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h index a52db2d..1740cbf 100644 --- a/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h +++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/ops.h @@ -1,11 +1,11 @@ /** - * @Description : + * @Description : * @Author : Azure-Tang * @Date : 2024-07-22 09:27:55 * @Version : 1.0.0 * @LastEditors : kkk1nak0 * @LastEditTime : 2024-08-12 03:48:46 - * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #pragma once @@ -13,10 +13,10 @@ #include #include -torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); -torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype); +torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); +torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);