optimize GPU

This commit is contained in:
Atream 2025-02-21 05:06:57 +00:00
parent cf4da5fd47
commit 7e1fe256c8
8 changed files with 677 additions and 156 deletions

View file

@ -20,19 +20,19 @@
PYBIND11_MODULE(KTransformersOps, m) {
m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),