[feature] support python 310 and multi instruction

This commit is contained in:
chenxl 2024-07-31 13:58:17 +00:00
parent 25620829ce
commit 112cb3c962
5 changed files with 188 additions and 8 deletions

View file

@ -1 +1 @@
__version__ = "0.1.0" __version__ = "0.1.1"

View file

@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.16)
project(cpuinfer_ext VERSION 0.1.0) project(cpuinfer_ext VERSION 0.1.0)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math")
set(CMAKE_BUILD_TYPE "Release") set(CMAKE_BUILD_TYPE "Release")
include(CheckCXXCompilerFlag) include(CheckCXXCompilerFlag)
set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON)
@ -10,6 +10,27 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
option(LLAMA_NATIVE "llama: enable -march=native flag" ON) option(LLAMA_NATIVE "llama: enable -march=native flag" ON)
# instruction set specific
if (LLAMA_NATIVE)
set(INS_ENB OFF)
else()
set(INS_ENB ON)
endif()
option(LLAMA_AVX "llama: enable AVX" OFF)
option(LLAMA_AVX2 "llama: enable AVX2" OFF)
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_FMA "llama: enable FMA" OFF)
# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" OFF)
endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
# Architecture specific # Architecture specific
# TODO: probably these flags need to be tweaked on some architectures # TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue # feel free to update the Makefile for your architecture and send a pull request or issue
@ -102,6 +123,20 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>) add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>) add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif() endif()
if (LLAMA_AVX512_FANCY_SIMD)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VL__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VL__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BW__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BW__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512DQ__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512DQ__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
if (LLAMA_AVX512_BF16)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
endif()
elseif (LLAMA_AVX2) elseif (LLAMA_AVX2)
list(APPEND ARCH_FLAGS /arch:AVX2) list(APPEND ARCH_FLAGS /arch:AVX2)
elseif (LLAMA_AVX) elseif (LLAMA_AVX)
@ -133,6 +168,15 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
if (LLAMA_AVX512_VNNI) if (LLAMA_AVX512_VNNI)
list(APPEND ARCH_FLAGS -mavx512vnni) list(APPEND ARCH_FLAGS -mavx512vnni)
endif() endif()
if (LLAMA_AVX512_FANCY_SIMD)
list(APPEND ARCH_FLAGS -mavx512vl)
list(APPEND ARCH_FLAGS -mavx512bw)
list(APPEND ARCH_FLAGS -mavx512dq)
list(APPEND ARCH_FLAGS -mavx512vnni)
endif()
if (LLAMA_AVX512_BF16)
list(APPEND ARCH_FLAGS -mavx512bf16)
endif()
endif() endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
message(STATUS "PowerPC detected") message(STATUS "PowerPC detected")

View file

@ -0,0 +1,100 @@
include(CheckCSourceRuns)
set(AVX_CODE "
#include <immintrin.h>
int main()
{
__m256 a;
a = _mm256_set1_ps(0);
return 0;
}
")
set(AVX512_CODE "
#include <immintrin.h>
int main()
{
__m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0);
__m512i b = a;
__mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
return 0;
}
")
set(AVX2_CODE "
#include <immintrin.h>
int main()
{
__m256i a = {0};
a = _mm256_abs_epi16(a);
__m256i x;
_mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
return 0;
}
")
set(FMA_CODE "
#include <immintrin.h>
int main()
{
__m256 acc = _mm256_setzero_ps();
const __m256 d = _mm256_setzero_ps();
const __m256 p = _mm256_setzero_ps();
acc = _mm256_fmadd_ps( d, p, acc );
return 0;
}
")
macro(check_sse type flags)
set(__FLAG_I 1)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
foreach (__FLAG ${flags})
if (NOT ${type}_FOUND)
set(CMAKE_REQUIRED_FLAGS ${__FLAG})
check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I})
if (HAS_${type}_${__FLAG_I})
set(${type}_FOUND TRUE CACHE BOOL "${type} support")
set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
endif()
math(EXPR __FLAG_I "${__FLAG_I}+1")
endif()
endforeach()
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
if (NOT ${type}_FOUND)
set(${type}_FOUND FALSE CACHE BOOL "${type} support")
set(${type}_FLAGS "" CACHE STRING "${type} flags")
endif()
mark_as_advanced(${type}_FOUND ${type}_FLAGS)
endmacro()
# flags are for MSVC only!
check_sse("AVX" " ;/arch:AVX")
if (NOT ${AVX_FOUND})
set(LLAMA_AVX OFF)
else()
set(LLAMA_AVX ON)
endif()
check_sse("AVX2" " ;/arch:AVX2")
check_sse("FMA" " ;/arch:AVX2")
if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))
set(LLAMA_AVX2 OFF)
else()
set(LLAMA_AVX2 ON)
endif()
check_sse("AVX512" " ;/arch:AVX512")
if (NOT ${AVX512_FOUND})
set(LLAMA_AVX512 OFF)
else()
set(LLAMA_AVX512 ON)
endif()

View file

@ -1,7 +1,7 @@
[build-system] [build-system]
requires = [ requires = [
"setuptools", "setuptools",
"torch == 2.3.1", "torch >= 2.3.0",
"ninja", "ninja",
"packaging" "packaging"
] ]
@ -29,7 +29,7 @@ dependencies = [
"fire" "fire"
] ]
requires-python = ">=3.11" requires-python = ">=3.10"
authors = [ authors = [
{name = "KVCache.AI", email = "zhang.mingxing@outlook.com"} {name = "KVCache.AI", email = "zhang.mingxing@outlook.com"}
@ -50,6 +50,7 @@ keywords = ["ktransformers", "llm"]
classifiers = [ classifiers = [
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12" "Programming Language :: Python :: 3.12"
] ]

View file

@ -6,7 +6,7 @@ Author : chenxl
Date : 2024-07-27 16:15:27 Date : 2024-07-27 16:15:27
Version : 1.0.0 Version : 1.0.0
LastEditors : chenxl LastEditors : chenxl
LastEditTime : 2024-07-29 09:40:24 LastEditTime : 2024-07-31 09:44:46
Adapted from: Adapted from:
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
Copyright (c) 2023, Tri Dao. Copyright (c) 2023, Tri Dao.
@ -19,6 +19,7 @@ import re
import ast import ast
import subprocess import subprocess
import platform import platform
import http.client
import urllib.request import urllib.request
import urllib.error import urllib.error
from pathlib import Path from pathlib import Path
@ -28,7 +29,16 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from setuptools import setup, Extension from setuptools import setup, Extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
class CpuInstructInfo:
CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
FANCY = "FANCY"
AVX512 = "AVX512"
AVX2 = "AVX2"
CMAKE_NATIVE = "-DLLAMA_NATIVE=ON"
CMAKE_FANCY = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON"
CMAKE_AVX512 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON"
CMAKE_AVX2 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON"
class VersionInfo: class VersionInfo:
THIS_DIR = os.path.dirname(os.path.abspath(__file__)) THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = "ktransformers" PACKAGE_NAME = "ktransformers"
@ -61,12 +71,24 @@ class VersionInfo:
raise ValueError("Unsupported platform: {}".format(sys.platform)) raise ValueError("Unsupported platform: {}".format(sys.platform))
def get_cpu_instruct(self,): def get_cpu_instruct(self,):
if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
return "fancy"
elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:
return "avx512"
elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2:
return "avx2"
else:
print("Using native cpu instruct")
if sys.platform.startswith("linux"): if sys.platform.startswith("linux"):
with open('/proc/cpuinfo', 'r', encoding="utf-8") as cpu_f: with open('/proc/cpuinfo', 'r', encoding="utf-8") as cpu_f:
cpuinfo = cpu_f.read() cpuinfo = cpu_f.read()
flags_line = [line for line in cpuinfo.split( flags_line = [line for line in cpuinfo.split(
'\n') if line.startswith('flags')][0] '\n') if line.startswith('flags')][0]
flags = flags_line.split(':')[1].strip().split(' ') flags = flags_line.split(':')[1].strip().split(' ')
# fancy with AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI
for flag in flags:
if 'avx512bw' in flag:
return 'fancy'
for flag in flags: for flag in flags:
if 'avx512' in flag: if 'avx512' in flag:
return 'avx512' return 'avx512'
@ -116,6 +138,7 @@ class BuildWheelsCommand(_bdist_wheel):
def run(self): def run(self):
if VersionInfo.FORCE_BUILD: if VersionInfo.FORCE_BUILD:
super().run() super().run()
return
wheel_filename, wheel_url = self.get_wheel_name() wheel_filename, wheel_url = self.get_wheel_name()
print("Guessing wheel URL: ", wheel_url) print("Guessing wheel URL: ", wheel_url)
try: try:
@ -132,7 +155,7 @@ class BuildWheelsCommand(_bdist_wheel):
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path) print("Raw wheel path", wheel_path)
os.rename(wheel_filename, wheel_path) os.rename(wheel_filename, wheel_path)
except (urllib.error.HTTPError, urllib.error.URLError): except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):
print("Precompiled wheel not found. Building from source...") print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source # If the wheel could not be downloaded, build from source
super().run() super().run()
@ -186,7 +209,19 @@ class CMakeBuild(BuildExtension):
if "CMAKE_ARGS" in os.environ: if "CMAKE_ARGS" in os.environ:
cmake_args += [ cmake_args += [
item for item in os.environ["CMAKE_ARGS"].split(" ") if item] item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
cpu_args = CpuInstructInfo.CMAKE_FANCY
elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:
cpu_args = CpuInstructInfo.CMAKE_AVX512
elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2:
cpu_args = CpuInstructInfo.CMAKE_AVX2
else:
cpu_args = CpuInstructInfo.CMAKE_NATIVE
cmake_args += [
item for item in cpu_args.split(" ") if item
]
# In this example, we pass in the version to C++. You might not need to. # In this example, we pass in the version to C++. You might not need to.
cmake_args += [ cmake_args += [
f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"] f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"]