mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-13 08:39:42 +00:00
Merge branch 'main' into feat-more-context
This commit is contained in:
commit
024009675e
18 changed files with 273 additions and 61 deletions
90
.github/workflows/docker-image.yml
vendored
Normal file
90
.github/workflows/docker-image.yml
vendored
Normal file
|
@ -0,0 +1,90 @@
|
|||
name: DockerHub CI
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
# push:
|
||||
# branches:
|
||||
# - main
|
||||
env:
|
||||
DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/ktransformers
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Run tests
|
||||
run: |
|
||||
if [ -f docker-compose.test.yml ]; then
|
||||
docker-compose --file docker-compose.test.yml build
|
||||
docker-compose --file docker-compose.test.yml run sut
|
||||
else
|
||||
docker build . --file Dockerfile
|
||||
fi
|
||||
|
||||
docker_task:
|
||||
needs: test
|
||||
name: ${{ matrix.instruct}}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
# for amd64
|
||||
- {instruct: "FANCY", platform: "linux/amd64"}
|
||||
- {instruct: "AVX512", platform: "linux/amd64"}
|
||||
- {instruct: "AVX2", platform: "linux/amd64"}
|
||||
- {instruct: "NATIVE", platform: "linux/amd64"}
|
||||
# for arm64
|
||||
- {instruct: "NATIVE", platform: "linux/arm64"}
|
||||
|
||||
steps:
|
||||
- name: Move Docker data directory
|
||||
run: |
|
||||
sudo systemctl stop docker
|
||||
sudo mkdir -p /mnt/docker
|
||||
sudo rsync -avz /var/lib/docker/ /mnt/docker
|
||||
sudo rm -rf /var/lib/docker
|
||||
sudo ln -s /mnt/docker /var/lib/docker
|
||||
sudo systemctl start docker
|
||||
|
||||
-
|
||||
name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
-
|
||||
name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
-
|
||||
name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
-
|
||||
name: Build and push for amd64
|
||||
if: matrix.platform == 'linux/amd64'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
push: true
|
||||
platforms: |
|
||||
linux/amd64
|
||||
tags: |
|
||||
${{ env.DOCKERHUB_REPO }}:latest-${{ matrix.instruct }}
|
||||
${{ env.DOCKERHUB_REPO }}:${{ github.event.release.tag_name }}-${{ matrix.instruct }}
|
||||
build-args: |
|
||||
CPU_INSTRUCT=${{ matrix.instruct }}
|
||||
-
|
||||
name: Build and push for arm64
|
||||
if: matrix.platform == 'linux/arm64'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
push: true
|
||||
platforms: |
|
||||
linux/arm64
|
||||
tags: |
|
||||
${{ env.DOCKERHUB_REPO }}:latest-${{ matrix.instruct }}
|
||||
${{ env.DOCKERHUB_REPO }}:${{ github.event.release.tag_name }}-${{ matrix.instruct }}
|
||||
build-args: |
|
||||
CPU_INSTRUCT=${{ matrix.instruct }}
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -28,3 +28,4 @@ ktransformers/tests/chat_txt.txt
|
|||
mmlu_result_q4km.json
|
||||
mmlu_result_q4km.log
|
||||
ktransformers/tests/mmlu_result_silicon.log
|
||||
ktransformers/ktransformers_ext/cuda_musa/
|
||||
|
|
|
@ -11,6 +11,7 @@ EOF
|
|||
|
||||
|
||||
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel as compile_server
|
||||
ARG CPU_INSTRUCT=NATIVE
|
||||
WORKDIR /workspace
|
||||
ENV CUDA_HOME /usr/local/cuda
|
||||
COPY --from=web_compile /home/ktransformers /workspace/ktransformers
|
||||
|
@ -28,8 +29,9 @@ git submodule init &&
|
|||
git submodule update &&
|
||||
pip install ninja pyproject numpy cpufeature &&
|
||||
pip install flash-attn &&
|
||||
CPU_INSTRUCT=NATIVE KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9;9.0+PTX" pip install . --no-build-isolation --verbose &&
|
||||
pip cache purge
|
||||
CPU_INSTRUCT=${CPU_INSTRUCT} KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9;9.0+PTX" pip install . --no-build-isolation --verbose &&
|
||||
pip cache purge &&
|
||||
cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/
|
||||
EOF
|
||||
|
||||
ENTRYPOINT ["tail", "-f", "/dev/null"]
|
|
@ -226,6 +226,7 @@ Intel is currently the only CPU vendor that supports AMX-like instructions, whic
|
|||
### Easier
|
||||
* Official Docker images to simplify installation
|
||||
* Fix the server integration for web API access
|
||||
* Fix the local chat only accepting a single line prompt (currently \n begins generating prompt)
|
||||
* Support for more quantization types, including the highly requested dynamic quantization from unsloth
|
||||
|
||||
Stay tuned for more updates!
|
||||
|
|
|
@ -8,6 +8,20 @@ This document provides the necessary steps to set up and run the web service for
|
|||
|
||||
Before you can compile the web code, make sure you have installed [Node.js](https://nodejs.org) version 18.3 or higher
|
||||
|
||||
Note: The version of Node.js in the Ubuntu or Debian GNU/Linux software repository is too low, causing compilation errors. Users can also install Node.js through the Nodesource repository, provided they uninstall the outdated version first.
|
||||
|
||||
```bash
|
||||
|
||||
# sudo apt-get remove nodejs npm -y && sudo apt-get autoremove -y
|
||||
sudo apt-get update -y && sudo apt-get install -y apt-transport-https ca-certificates curl gnupg
|
||||
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | sudo gpg --dearmor -o /usr/share/keyrings/nodesource.gpg
|
||||
sudo chmod 644 /usr/share/keyrings/nodesource.gpg
|
||||
echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/nodesource.gpg] https://deb.nodesource.com/node_23.x nodistro main" | sudo tee /etc/apt/sources.list.d/nodesource.list
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install nodejs -y
|
||||
|
||||
```
|
||||
|
||||
Once npm is installed, navigate to the `ktransformers/website` directory:
|
||||
|
||||
```bash
|
||||
|
|
|
@ -27,11 +27,11 @@ Some preparation:
|
|||
fi
|
||||
```
|
||||
|
||||
- Linux-x86_64 with gcc, g++ and cmake
|
||||
- Linux-x86_64 with gcc, g++ and cmake (using Ubuntu as an example)
|
||||
|
||||
```sh
|
||||
sudo apt-get update
|
||||
sudo apt-get install gcc g++ cmake ninja-build
|
||||
sudo apt-get install build-essential cmake ninja-build
|
||||
```
|
||||
|
||||
- We recommend using [Miniconda3](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh) or [Anaconda3](https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh) to create a virtual environment with Python=3.11 to run our program. Assuming your Anaconda installation directory is `~/anaconda3`, you should ensure that the version identifier of the GNU C++standard library used by Anaconda includes `GLIBCXX-3.4.32`
|
||||
|
|
|
@ -160,9 +160,14 @@ DeepSeek 的 MLA 操作符计算密集。虽然全部在 CPU 上运行是可行
|
|||
|
||||
5. 为什么选择英特尔 CPU?
|
||||
英特尔目前是唯一支持 AMX 类似指令的 CPU 供应商,与仅支持 AVX 的替代方案相比,性能显著更好。
|
||||
|
||||
## 常见问题解答
|
||||
### R1 不返回思考过程
|
||||
注意!如果测试 R1 可能会跳过思考。因此,可以添加参数:`--force_think true`。详细信息在 [常见问题解答](./FAQ.md) 部分中。 <br>
|
||||
|
||||
## 问题
|
||||
* 修复服务器集成功能以实现网络API访问支持
|
||||
* 修复本地聊天功能仅支持单行提示输入的问题(目前输入换行符(\n)即开始生成提示)
|
||||
|
||||
### 更多常见问题解答
|
||||
[详见](./FAQ.md)
|
||||
|
|
|
@ -30,6 +30,8 @@ if (NOT MSVC)
|
|||
option(LLAMA_F16C "llama: enable F16C" OFF)
|
||||
endif()
|
||||
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
|
||||
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
|
||||
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
|
||||
|
||||
# Architecture specific
|
||||
# TODO: probably these flags need to be tweaked on some architectures
|
||||
|
@ -208,8 +210,31 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
|
|||
if (WIN32)
|
||||
include_directories("$ENV{CUDA_PATH}/include")
|
||||
elseif (UNIX)
|
||||
if (KTRANSFORMERS_USE_CUDA)
|
||||
find_package(CUDA REQUIRED)
|
||||
include_directories("${CUDA_INCLUDE_DIRS}")
|
||||
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
|
||||
endif()
|
||||
|
||||
if (KTRANSFORMERS_USE_MUSA)
|
||||
if (NOT EXISTS $ENV{MUSA_PATH})
|
||||
if (NOT EXISTS /opt/musa)
|
||||
set(MUSA_PATH /usr/local/musa)
|
||||
else()
|
||||
set(MUSA_PATH /opt/musa)
|
||||
endif()
|
||||
else()
|
||||
set(MUSA_PATH $ENV{MUSA_PATH})
|
||||
endif()
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
|
||||
|
||||
find_package(MUSAToolkit)
|
||||
if (MUSAToolkit_FOUND)
|
||||
message(STATUS "MUSA Toolkit found")
|
||||
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
|
||||
|
@ -225,10 +250,15 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
|
|||
if(WIN32)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
|
||||
elseif(UNIX)
|
||||
if(KTRANSFORMERS_USE_CUDA)
|
||||
if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
|
||||
set(ENV{CUDA_HOME} "/usr/local/cuda")
|
||||
endif()
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
|
||||
endif()
|
||||
if(KTRANSFORMERS_USE_MUSA)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Define the USE_NUMA option
|
||||
|
|
|
@ -17,7 +17,11 @@
|
|||
#include <queue>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include "cuda_runtime.h"
|
||||
#ifdef KTRANSFORMERS_USE_CUDA
|
||||
#include "vendors/cuda.h"
|
||||
#elif KTRANSFORMERS_USE_MUSA
|
||||
#include "vendors/musa.h"
|
||||
#endif
|
||||
|
||||
#include "backend.h"
|
||||
#include "task_queue.h"
|
||||
|
|
3
ktransformers/ktransformers_ext/cpu_backend/vendors/README.md
vendored
Normal file
3
ktransformers/ktransformers_ext/cpu_backend/vendors/README.md
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
## TODO
|
||||
|
||||
This directory can be removed after updating the version of `llama.cpp`.
|
3
ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h
vendored
Normal file
3
ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
7
ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h
vendored
Normal file
7
ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h
vendored
Normal file
|
@ -0,0 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <musa_runtime.h>
|
||||
|
||||
#define cudaLaunchHostFunc musaLaunchHostFunc
|
||||
#define cudaStream_t musaStream_t
|
||||
#define cudaHostFn_t musaHostFn_t
|
|
@ -7,7 +7,9 @@
|
|||
**/
|
||||
|
||||
#include "custom_gguf/ops.h"
|
||||
#ifdef KTRANSFORMERS_USE_CUDA
|
||||
#include "gptq_marlin/ops.h"
|
||||
#endif
|
||||
// Python bindings
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
@ -53,6 +55,7 @@ PYBIND11_MODULE(KTransformersOps, m) {
|
|||
}, "Function to dequantize iq4_xs data.",
|
||||
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
|
||||
|
||||
#ifdef KTRANSFORMERS_USE_CUDA
|
||||
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
|
||||
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
|
||||
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
|
||||
|
|
|
@ -58,18 +58,10 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
|
||||
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
|
||||
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
|
||||
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
|
||||
self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
|
||||
bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
|
||||
self.q_absorb.weight.data = q_absorb
|
||||
self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
|
||||
bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
|
||||
self.out_absorb.weight.data = out_absorb
|
||||
#del self.orig_module.kv_b_proj
|
||||
q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
|
||||
out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
|
||||
return q_absorb, out_absorb
|
||||
self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
|
||||
self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
|
||||
|
||||
return self.q_absorb, self.out_absorb
|
||||
|
||||
def forward_chunck(
|
||||
self,
|
||||
|
@ -105,7 +97,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
|
@ -129,8 +121,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
# compressed_kv [pages, page_size, 1, self.kv_lora_rank]
|
||||
|
||||
q_absorb, out_absorb = self.get_absorbed()
|
||||
# if hasattr(self.orig_module, 'kv_b_proj'):
|
||||
# del self.orig_module.kv_b_proj
|
||||
|
||||
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
|
||||
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
|
||||
|
@ -227,7 +217,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
|
@ -379,7 +369,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
f"The cache structure has changed since version transformer verision v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
|
|
|
@ -9,7 +9,7 @@ flashinfer_enabled = False
|
|||
|
||||
try:
|
||||
import flashinfer
|
||||
flashinfer_enabled = False
|
||||
flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
|
||||
print("found flashinfer")
|
||||
|
||||
except ImportError:
|
||||
|
|
|
@ -381,13 +381,13 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
|
||||
self.profiler.create_and_start_timer("prefill")
|
||||
|
||||
|
||||
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
|
||||
# output think token after prefill done
|
||||
if Config().user_force_think:
|
||||
think = '<think>\n'
|
||||
print(think, end="",flush=True)
|
||||
yield think
|
||||
|
||||
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
|
||||
# output think token after prefill done
|
||||
if t is not None:
|
||||
print(t, end="",flush=True)
|
||||
yield t
|
||||
|
|
|
@ -176,7 +176,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file")
|
||||
parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file")
|
||||
parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
|
||||
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
|
||||
parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL")
|
||||
# parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
97
setup.py
97
setup.py
|
@ -30,6 +30,11 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
|||
from setuptools import setup, Extension
|
||||
from cpufeature.extension import CPUFeature
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
||||
try:
|
||||
from torch_musa.utils.simple_porting import SimplePorting
|
||||
from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
|
||||
except ImportError:
|
||||
MUSA_HOME=None
|
||||
|
||||
class CpuInstructInfo:
|
||||
CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
|
||||
|
@ -49,6 +54,16 @@ class VersionInfo:
|
|||
)
|
||||
FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE"
|
||||
|
||||
def get_musa_bare_metal_version(self, musa_dir):
|
||||
raw_output = subprocess.run(
|
||||
[musa_dir + "/bin/mcc", "-v"], check=True,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode("utf-8")
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("version") + 1
|
||||
bare_metal_version = parse(output[release_idx].split(",")[0])
|
||||
musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
|
||||
return musa_version
|
||||
|
||||
def get_cuda_bare_metal_version(self, cuda_dir):
|
||||
raw_output = subprocess.check_output(
|
||||
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
|
@ -58,7 +73,7 @@ class VersionInfo:
|
|||
cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
|
||||
return cuda_version
|
||||
|
||||
def get_cuda_version_of_torch(self,):
|
||||
def get_cuda_version_of_torch(self):
|
||||
torch_cuda_version = parse(torch.version.cuda)
|
||||
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
|
||||
return cuda_version
|
||||
|
@ -128,12 +143,21 @@ class VersionInfo:
|
|||
return flash_version
|
||||
|
||||
def get_package_version(self, full_version=False):
|
||||
flash_version = self.get_flash_version()
|
||||
package_version = f"{str(flash_version)}+cu{self.get_cuda_bare_metal_version(CUDA_HOME)}torch{self.get_torch_version()}{self.get_cpu_instruct()}"
|
||||
flash_version = str(self.get_flash_version())
|
||||
torch_version = self.get_torch_version()
|
||||
cpu_instruct = self.get_cpu_instruct()
|
||||
backend_version = ""
|
||||
if CUDA_HOME is not None:
|
||||
backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
|
||||
elif MUSA_HOME is not None:
|
||||
backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
|
||||
else:
|
||||
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
|
||||
package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
|
||||
if full_version:
|
||||
return package_version
|
||||
if not VersionInfo.FORCE_BUILD:
|
||||
return str(flash_version)
|
||||
return flash_version
|
||||
return package_version
|
||||
|
||||
|
||||
|
@ -218,6 +242,14 @@ class CMakeBuild(BuildExtension):
|
|||
f"-DPYTHON_EXECUTABLE={sys.executable}",
|
||||
f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
|
||||
]
|
||||
|
||||
if CUDA_HOME is not None:
|
||||
cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
|
||||
elif MUSA_HOME is not None:
|
||||
cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
|
||||
else:
|
||||
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
|
||||
|
||||
build_args = []
|
||||
if "CMAKE_ARGS" in os.environ:
|
||||
cmake_args += [
|
||||
|
@ -288,28 +320,55 @@ class CMakeBuild(BuildExtension):
|
|||
print("Standard output:", result.stdout)
|
||||
print("Standard error:", result.stderr)
|
||||
subprocess.run(
|
||||
["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
|
||||
["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
|
||||
)
|
||||
|
||||
if CUDA_HOME is not None:
|
||||
ops_module = CUDAExtension('KTransformersOps', [
|
||||
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
|
||||
'ktransformers/ktransformers_ext/cuda/binding.cpp',
|
||||
'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
|
||||
],
|
||||
extra_compile_args={
|
||||
'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
|
||||
'nvcc': [
|
||||
'-O3',
|
||||
'--use_fast_math',
|
||||
'-Xcompiler', '-fPIC',
|
||||
'-DKTRANSFORMERS_USE_CUDA',
|
||||
]
|
||||
}
|
||||
)
|
||||
elif MUSA_HOME is not None:
|
||||
SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
|
||||
# Common rules
|
||||
"at::cuda": "at::musa",
|
||||
"#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
|
||||
"#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
|
||||
}).run()
|
||||
ops_module = MUSAExtension('KTransformersOps', [
|
||||
'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
|
||||
'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
|
||||
# TODO: Add Marlin support for MUSA.
|
||||
# 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
|
||||
],
|
||||
extra_compile_args={
|
||||
'cxx': ['force_mcc'],
|
||||
'mcc': [
|
||||
'-O3',
|
||||
'-DKTRANSFORMERS_USE_MUSA',
|
||||
'-DTHRUST_IGNORE_CUB_VERSION_CHECK',
|
||||
]
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
|
||||
|
||||
setup(
|
||||
version=VersionInfo().get_package_version(),
|
||||
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
|
||||
ext_modules=[
|
||||
CMakeExtension("cpuinfer_ext"),
|
||||
CUDAExtension('KTransformersOps', [
|
||||
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
|
||||
'ktransformers/ktransformers_ext/cuda/binding.cpp',
|
||||
'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
|
||||
],
|
||||
extra_compile_args={
|
||||
'cxx': ['-O3'],
|
||||
'nvcc': [
|
||||
'-O3',
|
||||
'--use_fast_math',
|
||||
'-Xcompiler', '-fPIC',
|
||||
]
|
||||
}
|
||||
)
|
||||
ops_module,
|
||||
]
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue