Repository: https://github.com/kvcache-ai/ktransformers.git

commit 112cb3c962 (parent 25620829ce)
[feature] support python 310 and multi instruction

5 changed files with 188 additions and 8 deletions
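The "multi instruction" part of this change is driven by a new CPU_INSTRUCT environment variable read in setup.py (see the setup.py hunks below), with the values NATIVE (default), FANCY, AVX512, and AVX2. A minimal usage sketch, assuming a plain source install via pip (the exact install command is not part of this commit):

    import os
    import subprocess
    import sys

    # Assumed usage: CPU_INSTRUCT is read by setup.py at build time, so exporting
    # it before a source install should select the matching CMake flag set.
    env = dict(os.environ, CPU_INSTRUCT="AVX512")  # or "FANCY", "AVX2", "NATIVE"
    subprocess.run([sys.executable, "-m", "pip", "install", "."], env=env, check=True)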
@@ -1 +1 @@
-__version__ = "0.1.0"
+__version__ = "0.1.1"
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.16)
 project(cpuinfer_ext VERSION 0.1.0)
 
 set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math")
 set(CMAKE_BUILD_TYPE "Release")
 include(CheckCXXCompilerFlag)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
@@ -10,6 +10,27 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 
+option(LLAMA_NATIVE "llama: enable -march=native flag" ON)
+
+# instruction set specific
+if (LLAMA_NATIVE)
+    set(INS_ENB OFF)
+else()
+    set(INS_ENB ON)
+endif()
+
+option(LLAMA_AVX "llama: enable AVX" OFF)
+option(LLAMA_AVX2 "llama: enable AVX2" OFF)
+option(LLAMA_AVX512 "llama: enable AVX512" OFF)
+option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
+option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
+option(LLAMA_FMA "llama: enable FMA" OFF)
+# in MSVC F16C is implied with AVX2/AVX512
+if (NOT MSVC)
+    option(LLAMA_F16C "llama: enable F16C" OFF)
+endif()
+option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
+option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
 
 
 # Architecture specific
 # TODO: probably these flags need to be tweaked on some architectures
 # feel free to update the Makefile for your architecture and send a pull request or issue
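The option names and help strings above follow llama.cpp's CMake conventions. For illustration only, a hypothetical stand-alone configure of the extension with a fixed instruction set; the source path and build directory are assumptions, not taken from this diff:

    import subprocess

    # Hypothetical manual configure of cpuinfer_ext with explicit instruction-set
    # options instead of -march=native (paths are assumptions, not from the commit).
    subprocess.run([
        "cmake", "-S", "ktransformers/ktransformers_ext", "-B", "build",
        "-DLLAMA_NATIVE=OFF", "-DLLAMA_FMA=ON", "-DLLAMA_F16C=ON",
        "-DLLAMA_AVX=ON", "-DLLAMA_AVX2=ON",
    ], check=True)
    subprocess.run(["cmake", "--build", "build", "--config", "Release"], check=True)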
@@ -102,6 +123,20 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
             add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
             add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
         endif()
+        if (LLAMA_AVX512_FANCY_SIMD)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VL__>)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VL__>)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BW__>)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BW__>)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512DQ__>)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512DQ__>)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
+        endif()
+        if (LLAMA_AVX512_BF16)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
+            add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
+        endif()
     elseif (LLAMA_AVX2)
         list(APPEND ARCH_FLAGS /arch:AVX2)
     elseif (LLAMA_AVX)
@@ -133,6 +168,15 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
         if (LLAMA_AVX512_VNNI)
             list(APPEND ARCH_FLAGS -mavx512vnni)
         endif()
+        if (LLAMA_AVX512_FANCY_SIMD)
+            list(APPEND ARCH_FLAGS -mavx512vl)
+            list(APPEND ARCH_FLAGS -mavx512bw)
+            list(APPEND ARCH_FLAGS -mavx512dq)
+            list(APPEND ARCH_FLAGS -mavx512vnni)
+        endif()
+        if (LLAMA_AVX512_BF16)
+            list(APPEND ARCH_FLAGS -mavx512bf16)
+        endif()
     endif()
 elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
     message(STATUS "PowerPC detected")
ktransformers/ktransformers_ext/cmake/FindSIMD.cmake (new file, 100 lines)
@@ -0,0 +1,100 @@
+include(CheckCSourceRuns)
+
+set(AVX_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m256 a;
+        a = _mm256_set1_ps(0);
+        return 0;
+    }
+")
+
+set(AVX512_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0,
+                                    0, 0, 0, 0, 0, 0, 0, 0);
+        __m512i b = a;
+        __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
+        return 0;
+    }
+")
+
+set(AVX2_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m256i a = {0};
+        a = _mm256_abs_epi16(a);
+        __m256i x;
+        _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
+        return 0;
+    }
+")
+
+set(FMA_CODE "
+    #include <immintrin.h>
+    int main()
+    {
+        __m256 acc = _mm256_setzero_ps();
+        const __m256 d = _mm256_setzero_ps();
+        const __m256 p = _mm256_setzero_ps();
+        acc = _mm256_fmadd_ps( d, p, acc );
+        return 0;
+    }
+")
+
+macro(check_sse type flags)
+    set(__FLAG_I 1)
+    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
+    foreach (__FLAG ${flags})
+        if (NOT ${type}_FOUND)
+            set(CMAKE_REQUIRED_FLAGS ${__FLAG})
+            check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I})
+            if (HAS_${type}_${__FLAG_I})
+                set(${type}_FOUND TRUE CACHE BOOL "${type} support")
+                set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
+            endif()
+            math(EXPR __FLAG_I "${__FLAG_I}+1")
+        endif()
+    endforeach()
+    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
+
+    if (NOT ${type}_FOUND)
+        set(${type}_FOUND FALSE CACHE BOOL "${type} support")
+        set(${type}_FLAGS "" CACHE STRING "${type} flags")
+    endif()
+
+    mark_as_advanced(${type}_FOUND ${type}_FLAGS)
+endmacro()
+
+# flags are for MSVC only!
+check_sse("AVX" " ;/arch:AVX")
+if (NOT ${AVX_FOUND})
+    set(LLAMA_AVX OFF)
+else()
+    set(LLAMA_AVX ON)
+endif()
+
+check_sse("AVX2" " ;/arch:AVX2")
+check_sse("FMA" " ;/arch:AVX2")
+if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))
+    set(LLAMA_AVX2 OFF)
+else()
+    set(LLAMA_AVX2 ON)
+endif()
+
+check_sse("AVX512" " ;/arch:AVX512")
+if (NOT ${AVX512_FOUND})
+    set(LLAMA_AVX512 OFF)
+else()
+    set(LLAMA_AVX512 ON)
+endif()
@@ -1,7 +1,7 @@
 [build-system]
 requires = [
     "setuptools",
-    "torch == 2.3.1",
+    "torch >= 2.3.0",
     "ninja",
     "packaging"
 ]
@@ -29,7 +29,7 @@ dependencies = [
     "fire"
 ]
 
-requires-python = ">=3.11"
+requires-python = ">=3.10"
 
 authors = [
     {name = "KVCache.AI", email = "zhang.mingxing@outlook.com"}
@@ -50,6 +50,7 @@ keywords = ["ktransformers", "llm"]
 
 classifiers = [
     "Development Status :: 4 - Beta",
+    "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12"
 ]
setup.py (39 changed lines)
@@ -6,7 +6,7 @@ Author : chenxl
 Date : 2024-07-27 16:15:27
 Version : 1.0.0
 LastEditors : chenxl
-LastEditTime : 2024-07-29 09:40:24
+LastEditTime : 2024-07-31 09:44:46
 Adapted from:
 https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
 Copyright (c) 2023, Tri Dao.
@@ -19,6 +19,7 @@ import re
 import ast
 import subprocess
 import platform
+import http.client
 import urllib.request
 import urllib.error
 from pathlib import Path
@@ -28,6 +29,15 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 from setuptools import setup, Extension
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
 
+class CpuInstructInfo:
+    CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
+    FANCY = "FANCY"
+    AVX512 = "AVX512"
+    AVX2 = "AVX2"
+    CMAKE_NATIVE = "-DLLAMA_NATIVE=ON"
+    CMAKE_FANCY = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON"
+    CMAKE_AVX512 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON"
+    CMAKE_AVX2 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON"
 
 class VersionInfo:
     THIS_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -61,12 +71,24 @@ class VersionInfo:
         raise ValueError("Unsupported platform: {}".format(sys.platform))
 
     def get_cpu_instruct(self,):
+        if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
+            return "fancy"
+        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:
+            return "avx512"
+        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2:
+            return "avx2"
+        else:
+            print("Using native cpu instruct")
         if sys.platform.startswith("linux"):
             with open('/proc/cpuinfo', 'r', encoding="utf-8") as cpu_f:
                 cpuinfo = cpu_f.read()
                 flags_line = [line for line in cpuinfo.split(
                     '\n') if line.startswith('flags')][0]
                 flags = flags_line.split(':')[1].strip().split(' ')
+                # fancy with AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI
+                for flag in flags:
+                    if 'avx512bw' in flag:
+                        return 'fancy'
                 for flag in flags:
                     if 'avx512' in flag:
                         return 'avx512'
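For Linux users who want to preview which tier the native auto-detection above would report, a rough stand-alone sketch; the helper name and the final avx2 fallback are assumptions (the hunk ends before the non-AVX-512 path):

    def preview_native_tier(cpuinfo_path="/proc/cpuinfo"):
        """Roughly mirror the flag detection added above (Linux only, illustrative)."""
        with open(cpuinfo_path, "r", encoding="utf-8") as cpu_f:
            flags_line = [line for line in cpu_f.read().split("\n")
                          if line.startswith("flags")][0]
        flags = flags_line.split(":")[1].strip().split(" ")
        if any("avx512bw" in flag for flag in flags):
            return "fancy"
        if any("avx512" in flag for flag in flags):
            return "avx512"
        return "avx2"  # assumption: fallback when no AVX-512 flag is present

    print(preview_native_tier())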
@@ -116,6 +138,7 @@ class BuildWheelsCommand(_bdist_wheel):
     def run(self):
         if VersionInfo.FORCE_BUILD:
             super().run()
             return
         wheel_filename, wheel_url = self.get_wheel_name()
+        print("Guessing wheel URL: ", wheel_url)
         try:
@@ -132,7 +155,7 @@ class BuildWheelsCommand(_bdist_wheel):
             wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
             print("Raw wheel path", wheel_path)
             os.rename(wheel_filename, wheel_path)
-        except (urllib.error.HTTPError, urllib.error.URLError):
+        except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):
             print("Precompiled wheel not found. Building from source...")
             # If the wheel could not be downloaded, build from source
             super().run()
@@ -187,6 +210,18 @@ class CMakeBuild(BuildExtension):
             cmake_args += [
                 item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
 
+        if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
+            cpu_args = CpuInstructInfo.CMAKE_FANCY
+        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:
+            cpu_args = CpuInstructInfo.CMAKE_AVX512
+        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2:
+            cpu_args = CpuInstructInfo.CMAKE_AVX2
+        else:
+            cpu_args = CpuInstructInfo.CMAKE_NATIVE
+
+        cmake_args += [
+            item for item in cpu_args.split(" ") if item
+        ]
         # In this example, we pass in the version to C++. You might not need to.
         cmake_args += [
             f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"]
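The hunk above appends the CPU_INSTRUCT-derived flags after any user-supplied CMAKE_ARGS, so the two environment knobs compose. A hypothetical combined invocation (the extra CMake variable is only an example):

    import os
    import subprocess
    import sys

    # Hypothetical: user CMAKE_ARGS are kept, and the CPU_INSTRUCT flags are
    # appended after them by the hunk above.
    env = dict(os.environ,
               CMAKE_ARGS="-DCMAKE_VERBOSE_MAKEFILE=ON",  # example only
               CPU_INSTRUCT="FANCY")
    subprocess.run([sys.executable, "-m", "pip", "install", "."], env=env, check=True)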