Merge branch 'main' into feat-more-context

This commit is contained in:
Atream 2025-02-22 06:17:39 +00:00
commit 024009675e
18 changed files with 273 additions and 61 deletions

90
.github/workflows/docker-image.yml vendored Normal file
View file

@ -0,0 +1,90 @@
name: DockerHub CI
on:
release:
types: [published]
# push:
# branches:
# - main
env:
DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/ktransformers
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Run tests
run: |
if [ -f docker-compose.test.yml ]; then
docker-compose --file docker-compose.test.yml build
docker-compose --file docker-compose.test.yml run sut
else
docker build . --file Dockerfile
fi
docker_task:
needs: test
name: ${{ matrix.instruct}}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
# for amd64
- {instruct: "FANCY", platform: "linux/amd64"}
- {instruct: "AVX512", platform: "linux/amd64"}
- {instruct: "AVX2", platform: "linux/amd64"}
- {instruct: "NATIVE", platform: "linux/amd64"}
# for arm64
- {instruct: "NATIVE", platform: "linux/arm64"}
steps:
- name: Move Docker data directory
run: |
sudo systemctl stop docker
sudo mkdir -p /mnt/docker
sudo rsync -avz /var/lib/docker/ /mnt/docker
sudo rm -rf /var/lib/docker
sudo ln -s /mnt/docker /var/lib/docker
sudo systemctl start docker
-
name: Set up QEMU
uses: docker/setup-qemu-action@v3
-
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
-
name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
name: Build and push for amd64
if: matrix.platform == 'linux/amd64'
uses: docker/build-push-action@v6
with:
push: true
platforms: |
linux/amd64
tags: |
${{ env.DOCKERHUB_REPO }}:latest-${{ matrix.instruct }}
${{ env.DOCKERHUB_REPO }}:${{ github.event.release.tag_name }}-${{ matrix.instruct }}
build-args: |
CPU_INSTRUCT=${{ matrix.instruct }}
-
name: Build and push for arm64
if: matrix.platform == 'linux/arm64'
uses: docker/build-push-action@v6
with:
push: true
platforms: |
linux/arm64
tags: |
${{ env.DOCKERHUB_REPO }}:latest-${{ matrix.instruct }}
${{ env.DOCKERHUB_REPO }}:${{ github.event.release.tag_name }}-${{ matrix.instruct }}
build-args: |
CPU_INSTRUCT=${{ matrix.instruct }}

1
.gitignore vendored
View file

@ -28,3 +28,4 @@ ktransformers/tests/chat_txt.txt
mmlu_result_q4km.json mmlu_result_q4km.json
mmlu_result_q4km.log mmlu_result_q4km.log
ktransformers/tests/mmlu_result_silicon.log ktransformers/tests/mmlu_result_silicon.log
ktransformers/ktransformers_ext/cuda_musa/

View file

@ -11,6 +11,7 @@ EOF
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel as compile_server FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel as compile_server
ARG CPU_INSTRUCT=NATIVE
WORKDIR /workspace WORKDIR /workspace
ENV CUDA_HOME /usr/local/cuda ENV CUDA_HOME /usr/local/cuda
COPY --from=web_compile /home/ktransformers /workspace/ktransformers COPY --from=web_compile /home/ktransformers /workspace/ktransformers
@ -28,8 +29,9 @@ git submodule init &&
git submodule update && git submodule update &&
pip install ninja pyproject numpy cpufeature && pip install ninja pyproject numpy cpufeature &&
pip install flash-attn && pip install flash-attn &&
CPU_INSTRUCT=NATIVE KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9;9.0+PTX" pip install . --no-build-isolation --verbose && CPU_INSTRUCT=${CPU_INSTRUCT} KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9;9.0+PTX" pip install . --no-build-isolation --verbose &&
pip cache purge pip cache purge &&
cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/
EOF EOF
ENTRYPOINT ["tail", "-f", "/dev/null"] ENTRYPOINT ["tail", "-f", "/dev/null"]

View file

@ -226,6 +226,7 @@ Intel is currently the only CPU vendor that supports AMX-like instructions, whic
### Easier ### Easier
* Official Docker images to simplify installation * Official Docker images to simplify installation
* Fix the server integration for web API access * Fix the server integration for web API access
* Fix the local chat only accepting a single line prompt (currently \n begins generating prompt)
* Support for more quantization types, including the highly requested dynamic quantization from unsloth * Support for more quantization types, including the highly requested dynamic quantization from unsloth
Stay tuned for more updates! Stay tuned for more updates!

View file

@ -8,6 +8,20 @@ This document provides the necessary steps to set up and run the web service for
Before you can compile the web code, make sure you have installed [Node.js](https://nodejs.org) version 18.3 or higher Before you can compile the web code, make sure you have installed [Node.js](https://nodejs.org) version 18.3 or higher
Note: The version of Node.js in the Ubuntu or Debian GNU/Linux software repository is too low, causing compilation errors. Users can also install Node.js through the Nodesource repository, provided they uninstall the outdated version first.
```bash
# sudo apt-get remove nodejs npm -y && sudo apt-get autoremove -y
sudo apt-get update -y && sudo apt-get install -y apt-transport-https ca-certificates curl gnupg
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | sudo gpg --dearmor -o /usr/share/keyrings/nodesource.gpg
sudo chmod 644 /usr/share/keyrings/nodesource.gpg
echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/nodesource.gpg] https://deb.nodesource.com/node_23.x nodistro main" | sudo tee /etc/apt/sources.list.d/nodesource.list
sudo apt-get update -y
sudo apt-get install nodejs -y
```
Once npm is installed, navigate to the `ktransformers/website` directory: Once npm is installed, navigate to the `ktransformers/website` directory:
```bash ```bash

View file

@ -27,11 +27,11 @@ Some preparation:
fi fi
``` ```
- Linux-x86_64 with gcc, g++ and cmake - Linux-x86_64 with gcc, g++ and cmake (using Ubuntu as an example)
```sh ```sh
sudo apt-get update sudo apt-get update
sudo apt-get install gcc g++ cmake ninja-build sudo apt-get install build-essential cmake ninja-build
``` ```
- We recommend using [Miniconda3](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh) or [Anaconda3](https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh) to create a virtual environment with Python=3.11 to run our program. Assuming your Anaconda installation directory is `~/anaconda3`, you should ensure that the version identifier of the GNU C++standard library used by Anaconda includes `GLIBCXX-3.4.32` - We recommend using [Miniconda3](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh) or [Anaconda3](https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh) to create a virtual environment with Python=3.11 to run our program. Assuming your Anaconda installation directory is `~/anaconda3`, you should ensure that the version identifier of the GNU C++standard library used by Anaconda includes `GLIBCXX-3.4.32`

View file

@ -160,9 +160,14 @@ DeepSeek 的 MLA 操作符计算密集。虽然全部在 CPU 上运行是可行
5. 为什么选择英特尔 CPU 5. 为什么选择英特尔 CPU
英特尔目前是唯一支持 AMX 类似指令的 CPU 供应商,与仅支持 AVX 的替代方案相比,性能显著更好。 英特尔目前是唯一支持 AMX 类似指令的 CPU 供应商,与仅支持 AVX 的替代方案相比,性能显著更好。
## 常见问题解答 ## 常见问题解答
### R1 不返回思考过程 ### R1 不返回思考过程
注意!如果测试 R1 可能会跳过思考。因此,可以添加参数:`--force_think true`。详细信息在 [常见问题解答](./FAQ.md) 部分中。 <br> 注意!如果测试 R1 可能会跳过思考。因此,可以添加参数:`--force_think true`。详细信息在 [常见问题解答](./FAQ.md) 部分中。 <br>
## 问题
* 修复服务器集成功能以实现网络API访问支持
* 修复本地聊天功能仅支持单行提示输入的问题(目前输入换行符(\n)即开始生成提示)
### 更多常见问题解答 ### 更多常见问题解答
[详见](./FAQ.md) [详见](./FAQ.md)

View file

@ -30,6 +30,8 @@ if (NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" OFF) option(LLAMA_F16C "llama: enable F16C" OFF)
endif() endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF) option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
# Architecture specific # Architecture specific
# TODO: probably these flags need to be tweaked on some architectures # TODO: probably these flags need to be tweaked on some architectures
@ -208,8 +210,31 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
if (WIN32) if (WIN32)
include_directories("$ENV{CUDA_PATH}/include") include_directories("$ENV{CUDA_PATH}/include")
elseif (UNIX) elseif (UNIX)
find_package(CUDA REQUIRED) if (KTRANSFORMERS_USE_CUDA)
include_directories("${CUDA_INCLUDE_DIRS}") find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
endif()
if (KTRANSFORMERS_USE_MUSA)
if (NOT EXISTS $ENV{MUSA_PATH})
if (NOT EXISTS /opt/musa)
set(MUSA_PATH /usr/local/musa)
else()
set(MUSA_PATH /opt/musa)
endif()
else()
set(MUSA_PATH $ENV{MUSA_PATH})
endif()
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
find_package(MUSAToolkit)
if (MUSAToolkit_FOUND)
message(STATUS "MUSA Toolkit found")
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
endif()
endif()
endif() endif()
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
@ -225,10 +250,15 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if(WIN32) if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX) elseif(UNIX)
if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "") if(KTRANSFORMERS_USE_CUDA)
set(ENV{CUDA_HOME} "/usr/local/cuda") if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
set(ENV{CUDA_HOME} "/usr/local/cuda")
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif()
if(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif() endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif() endif()
# Define the USE_NUMA option # Define the USE_NUMA option

View file

@ -17,7 +17,11 @@
#include <queue> #include <queue>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "cuda_runtime.h" #ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif KTRANSFORMERS_USE_MUSA
#include "vendors/musa.h"
#endif
#include "backend.h" #include "backend.h"
#include "task_queue.h" #include "task_queue.h"

View file

@ -0,0 +1,3 @@
## TODO
This directory can be removed after updating the version of `llama.cpp`.

View file

@ -0,0 +1,3 @@
#pragma once
#include <cuda_runtime.h>

View file

@ -0,0 +1,7 @@
#pragma once
#include <musa_runtime.h>
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaStream_t musaStream_t
#define cudaHostFn_t musaHostFn_t

View file

@ -7,7 +7,9 @@
**/ **/
#include "custom_gguf/ops.h" #include "custom_gguf/ops.h"
#ifdef KTRANSFORMERS_USE_CUDA
#include "gptq_marlin/ops.h" #include "gptq_marlin/ops.h"
#endif
// Python bindings // Python bindings
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
@ -53,6 +55,7 @@ PYBIND11_MODULE(KTransformersOps, m) {
}, "Function to dequantize iq4_xs data.", }, "Function to dequantize iq4_xs data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
#ifdef KTRANSFORMERS_USE_CUDA
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"), py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),

View file

@ -58,18 +58,10 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank) self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank) self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
bias=False, dtype=q_absorb.dtype, device=q_absorb.device) return self.q_absorb, self.out_absorb
self.q_absorb.weight.data = q_absorb
self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
self.out_absorb.weight.data = out_absorb
#del self.orig_module.kv_b_proj
q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
return q_absorb, out_absorb
def forward_chunck( def forward_chunck(
self, self,
@ -105,7 +97,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
if past_key_value is not None: if past_key_value is not None:
if self.layer_idx is None: if self.layer_idx is None:
raise ValueError( raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index." "with a layer index."
) )
@ -129,8 +121,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# compressed_kv [pages, page_size, 1, self.kv_lora_rank] # compressed_kv [pages, page_size, 1, self.kv_lora_rank]
q_absorb, out_absorb = self.get_absorbed() q_absorb, out_absorb = self.get_absorbed()
# if hasattr(self.orig_module, 'kv_b_proj'):
# del self.orig_module.kv_b_proj
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim] # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim] # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
@ -227,7 +217,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
if past_key_value is not None: if past_key_value is not None:
if self.layer_idx is None: if self.layer_idx is None:
raise ValueError( raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index." "with a layer index."
) )
@ -379,7 +369,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
if past_key_value is not None: if past_key_value is not None:
if self.layer_idx is None: if self.layer_idx is None:
raise ValueError( raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " f"The cache structure has changed since version transformer verision v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index." "with a layer index."
) )

View file

@ -9,7 +9,7 @@ flashinfer_enabled = False
try: try:
import flashinfer import flashinfer
flashinfer_enabled = False flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
print("found flashinfer") print("found flashinfer")
except ImportError: except ImportError:

View file

@ -381,13 +381,13 @@ class TransformersInterface(BackendInterfaceBase):
self.profiler.create_and_start_timer("prefill") self.profiler.create_and_start_timer("prefill")
if Config().user_force_think:
think = '<think>\n'
print(think, end="",flush=True)
yield think
for t in self.prefill(input_ids, self.check_is_new(thread_id)): for t in self.prefill(input_ids, self.check_is_new(thread_id)):
# output think token after prefill done # output think token after prefill done
if Config().user_force_think:
think = '<think>\n'
print(think, end="",flush=True)
yield think
if t is not None: if t is not None:
print(t, end="",flush=True) print(t, end="",flush=True)
yield t yield t

View file

@ -176,7 +176,7 @@ if __name__ == "__main__":
parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file") parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file")
parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file") parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file")
parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path") parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL")
# parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL") # parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
args = parser.parse_args() args = parser.parse_args()

View file

@ -30,6 +30,11 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from setuptools import setup, Extension from setuptools import setup, Extension
from cpufeature.extension import CPUFeature from cpufeature.extension import CPUFeature
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
try:
from torch_musa.utils.simple_porting import SimplePorting
from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
except ImportError:
MUSA_HOME=None
class CpuInstructInfo: class CpuInstructInfo:
CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE") CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@ -49,6 +54,16 @@ class VersionInfo:
) )
FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE" FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE"
def get_musa_bare_metal_version(self, musa_dir):
raw_output = subprocess.run(
[musa_dir + "/bin/mcc", "-v"], check=True,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode("utf-8")
output = raw_output.split()
release_idx = output.index("version") + 1
bare_metal_version = parse(output[release_idx].split(",")[0])
musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
return musa_version
def get_cuda_bare_metal_version(self, cuda_dir): def get_cuda_bare_metal_version(self, cuda_dir):
raw_output = subprocess.check_output( raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@ -58,7 +73,7 @@ class VersionInfo:
cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}" cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
return cuda_version return cuda_version
def get_cuda_version_of_torch(self,): def get_cuda_version_of_torch(self):
torch_cuda_version = parse(torch.version.cuda) torch_cuda_version = parse(torch.version.cuda)
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
return cuda_version return cuda_version
@ -128,12 +143,21 @@ class VersionInfo:
return flash_version return flash_version
def get_package_version(self, full_version=False): def get_package_version(self, full_version=False):
flash_version = self.get_flash_version() flash_version = str(self.get_flash_version())
package_version = f"{str(flash_version)}+cu{self.get_cuda_bare_metal_version(CUDA_HOME)}torch{self.get_torch_version()}{self.get_cpu_instruct()}" torch_version = self.get_torch_version()
cpu_instruct = self.get_cpu_instruct()
backend_version = ""
if CUDA_HOME is not None:
backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
elif MUSA_HOME is not None:
backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
if full_version: if full_version:
return package_version return package_version
if not VersionInfo.FORCE_BUILD: if not VersionInfo.FORCE_BUILD:
return str(flash_version) return flash_version
return package_version return package_version
@ -218,6 +242,14 @@ class CMakeBuild(BuildExtension):
f"-DPYTHON_EXECUTABLE={sys.executable}", f"-DPYTHON_EXECUTABLE={sys.executable}",
f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
] ]
if CUDA_HOME is not None:
cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
elif MUSA_HOME is not None:
cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
build_args = [] build_args = []
if "CMAKE_ARGS" in os.environ: if "CMAKE_ARGS" in os.environ:
cmake_args += [ cmake_args += [
@ -288,28 +320,55 @@ class CMakeBuild(BuildExtension):
print("Standard output:", result.stdout) print("Standard output:", result.stdout)
print("Standard error:", result.stderr) print("Standard error:", result.stderr)
subprocess.run( subprocess.run(
["cmake", "--build", ".", *build_args], cwd=build_temp, check=True ["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
) )
if CUDA_HOME is not None:
ops_module = CUDAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
'ktransformers/ktransformers_ext/cuda/binding.cpp',
'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
],
extra_compile_args={
'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
'nvcc': [
'-O3',
'--use_fast_math',
'-Xcompiler', '-fPIC',
'-DKTRANSFORMERS_USE_CUDA',
]
}
)
elif MUSA_HOME is not None:
SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
# Common rules
"at::cuda": "at::musa",
"#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
"#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
}).run()
ops_module = MUSAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
# TODO: Add Marlin support for MUSA.
# 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
],
extra_compile_args={
'cxx': ['force_mcc'],
'mcc': [
'-O3',
'-DKTRANSFORMERS_USE_MUSA',
'-DTHRUST_IGNORE_CUB_VERSION_CHECK',
]
}
)
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
setup( setup(
version=VersionInfo().get_package_version(), version=VersionInfo().get_package_version(),
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild}, cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=[ ext_modules=[
CMakeExtension("cpuinfer_ext"), CMakeExtension("cpuinfer_ext"),
CUDAExtension('KTransformersOps', [ ops_module,
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
'ktransformers/ktransformers_ext/cuda/binding.cpp',
'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': [
'-O3',
'--use_fast_math',
'-Xcompiler', '-fPIC',
]
}
)
] ]
) )