mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-11 07:44:35 +00:00
410 lines
No EOL
16 KiB
CMake
410 lines
No EOL
16 KiB
CMake
cmake_minimum_required(VERSION 3.16)
project(cpuinfer_ext VERSION 0.1.0)

# The extension is written against C++17.
set(CMAKE_CXX_STANDARD 17)

# Optimisation and OpenMP flags for every C++ translation unit.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math -fopenmp")

# NOTE: _GLIBCXX_USE_CXX11_ABI is intentionally NOT added as a compile
# definition here. At this point the variable may still be undefined, which
# would put an empty `_GLIBCXX_USE_CXX11_ABI=` on the command line. The
# definition is added further down, after the value has been resolved.

# Default to a Release build, but only when the caller did not pick a build
# type and the generator is single-config (multi-config generators ignore
# CMAKE_BUILD_TYPE).
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
    set(CMAKE_BUILD_TYPE "Release")
endif()

# Manual debug recipe (kept for reference):
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -ffast-math -fopenmp")
# set(CMAKE_BUILD_TYPE "Debug")

# Emit compile_commands.json for tooling.
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
|
|
|
|
|
# Provides check_cxx_compiler_flag(), used by the ARM flag probing later on.
include(CheckCXXCompilerFlag)

# Compile all targets as position-independent code.
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# ---- CPU instruction-set options -------------------------------------------
option(LLAMA_NATIVE "llama: enable -march=native flag" ON)

# INS_ENB is the logical inverse of LLAMA_NATIVE. It is not read again in this
# file; presumably it is consumed by the llama.cpp subdirectory -- verify.
if (NOT LLAMA_NATIVE)
    set(INS_ENB ON)
else()
    set(INS_ENB OFF)
endif()

# Individual instruction-set switches; all default to OFF.
option(LLAMA_AVX                "llama: enable AVX"         OFF)
option(LLAMA_AVX2               "llama: enable AVX2"        OFF)
option(LLAMA_AVX512             "llama: enable AVX512"      OFF)
option(LLAMA_AVX512_VBMI        "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI        "llama: enable AVX512-VNNI" OFF)
option(LLAMA_AVX512_BF16        "llama: enable AVX512-BF16" OFF)
option(LLAMA_FMA                "llama: enable FMA"         OFF)
if (NOT MSVC)
    # Under MSVC, F16C is implied by AVX2/AVX512, so the switch is only
    # offered for other compilers.
    option(LLAMA_F16C           "llama: enable F16C"        OFF)
endif()
option(LLAMA_AVX512_FANCY_SIMD  "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)

# ---- GPU backend options -----------------------------------------------------
option(KTRANSFORMERS_USE_CUDA   "ktransformers: use CUDA"   ON)
option(KTRANSFORMERS_USE_MUSA   "ktransformers: use MUSA"   OFF)
option(KTRANSFORMERS_USE_ROCM   "ktransformers: use ROCM"   OFF)
|
|
|
|
# Architecture-specific configuration.
# TODO: these flags probably need tweaking on some architectures; feel free to
# update this file for your architecture and send a pull request or issue.
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")

# CMAKE_GENERATOR_PLATFORM_LWR holds the lower-cased generator platform when
# building with MSVC and is the empty string for every other compiler.
if (NOT MSVC)
    set(CMAKE_GENERATOR_PLATFORM_LWR "")
else ()
    string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR)
    message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}")
endif ()
|
|
|
|
# Resolve _GLIBCXX_USE_CXX11_ABI.
#
# When the caller has not supplied a value (e.g. -D_GLIBCXX_USE_CXX11_ABI=1),
# ask the PyTorch installation of the discovered Python interpreter whether it
# was compiled with the C++11 ABI, and cache the answer ("1" or "0").
if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
    find_package(Python3 REQUIRED COMPONENTS Interpreter)

    execute_process(
        COMMAND "${Python3_EXECUTABLE}" -c
            "import torch; print('1' if torch.compiled_with_cxx11_abi() else '0')"
        RESULT_VARIABLE ABI_QUERY_RESULT
        OUTPUT_VARIABLE ABI_FLAG
        OUTPUT_STRIP_TRAILING_WHITESPACE
    )

    # Fail at configure time when the query did not produce a clean 0/1.
    # Otherwise an empty or garbled value would be baked into the compile
    # definition below and only surface later as a confusing compiler error.
    if(NOT ABI_QUERY_RESULT EQUAL 0 OR NOT "${ABI_FLAG}" MATCHES "^[01]$")
        message(FATAL_ERROR
            "Could not query PyTorch for its C++11 ABI setting "
            "(result: '${ABI_QUERY_RESULT}', output: '${ABI_FLAG}'). "
            "Make sure torch is importable by ${Python3_EXECUTABLE}, or pass "
            "-D_GLIBCXX_USE_CXX11_ABI=0 or -D_GLIBCXX_USE_CXX11_ABI=1 explicitly.")
    endif()

    # No FORCE needed: this branch only runs when the variable is not defined,
    # so there is no existing cache entry to override.
    set(_GLIBCXX_USE_CXX11_ABI "${ABI_FLAG}" CACHE STRING "C++11 ABI setting from PyTorch")
endif()

add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
|
|
|
|
# Optional static linking (non-MSVC toolchains only). MinGW additionally links
# libgcc and libstdc++ statically.
if (LLAMA_STATIC AND NOT MSVC)
    add_link_options(-static)
    if (MINGW)
        add_link_options(-static-libgcc -static-libstdc++)
    endif()
endif()

# Optional gprof instrumentation (non-MSVC toolchains only).
if (LLAMA_GPROF AND NOT MSVC)
    add_compile_options(-pg)
endif()

# Accumulator for the architecture-specific compiler flags selected below.
set(ARCH_FLAGS "")
|
|
|
|
if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR
|
|
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
|
CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$"))
|
|
message(STATUS "ARM detected")
|
|
# ARM compile settings.
#   * MSVC: define the GCC-style feature macros by hand and probe (by compiling
#     small test programs) for dot-product and FP16 vector arithmetic support.
#   * Other compilers: append the matching -m flags to ARCH_FLAGS.
if (MSVC)
    add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead
    add_compile_definitions(__ARM_NEON)
    add_compile_definitions(__ARM_FEATURE_FMA)

    # check_cxx_source_compiles() is provided by its own module; include it
    # explicitly instead of relying on another module pulling it in.
    include(CheckCXXSourceCompiles)

    # Run the probes with /arch:armv8.2 added, then restore the previous flags.
    set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS})
    string(JOIN " " CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} "/arch:armv8.2")

    check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD)
    if (GGML_COMPILER_SUPPORT_DOTPROD)
        add_compile_definitions(__ARM_FEATURE_DOTPROD)
    endif ()

    check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
    if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
        add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
    endif ()

    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
else()
    check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
    if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
        list(APPEND ARCH_FLAGS -mfp16-format=ieee)
    endif()

    # The processor name is quoted in every test below: this branch can also be
    # entered via CMAKE_OSX_ARCHITECTURES / the generator platform, in which
    # case CMAKE_SYSTEM_PROCESSOR may be empty and an unquoted expansion would
    # leave if() with a malformed argument list.
    if ("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "armv6")
        # Raspberry Pi 1, Zero
        list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access)
    endif()
    if ("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "armv7")
        if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
            # Android armeabi-v7a
            list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations)
        else()
            # Raspberry Pi 2
            list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)
        endif()
    endif()
    if ("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "armv8")
        # Android arm64-v8a
        # Raspberry Pi 3, 4, Zero 2 (32-bit)
        list(APPEND ARCH_FLAGS -mno-unaligned-access)
    endif()
endif()
|
|
elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
|
|
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
|
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
|
|
message(STATUS "x86 detected")
set(HOST_IS_X86 TRUE)
# Optimistic defaults; each one is cleared below when the feature is not
# reported for this host.
set(HAS_AVX512 TRUE)
set(__HAS_AMX__ TRUE)
add_compile_definitions(__x86_64__)

# Feature detection is a plain substring search over the output of `lscpu`.
# NOTE: despite the wording of the status messages below, no test program is
# compiled. If lscpu produces no output, both searches return -1 and the
# features are treated as absent.
execute_process(
    COMMAND lscpu
    OUTPUT_VARIABLE LSCPU_OUTPUT
    OUTPUT_STRIP_TRAILING_WHITESPACE
)
# message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")

# check AVX512 (string(FIND) stores -1 when the substring is absent)
string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)

if (COMPILER_SUPPORTS_AVX512F GREATER -1)
    message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
    add_compile_definitions(__HAS_AVX512F__)
else()
    message(STATUS "Compiler and/or CPU do NOT support AVX512F")
    set(HAS_AVX512 False)
endif()

# check AMX
string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)

if(COMPILER_SUPPORTS_AMX GREATER -1)
    message(STATUS "Compiler supports AMX")
    add_compile_definitions(__HAS_AMX__)
else()
    message(STATUS "Compiler does NOT support AMX")
    # Clear the optimistic default, mirroring the AVX512 handling above, so the
    # `HOST_IS_X86 AND HAS_AVX512 AND __HAS_AMX__` test that gates the AMX
    # sources agrees with the __HAS_AMX__ compile definition.
    set(__HAS_AMX__ FALSE)
endif()
|
|
# Translate the LLAMA_* instruction-set options into compiler flags
# (appended to ARCH_FLAGS) and, for MSVC, into hand-written feature macros.
if (MSVC)
    # instruction set detection for MSVC only
    if (LLAMA_NATIVE)
        include(cmake/FindSIMD.cmake)
    endif ()

    # MSVC takes a single /arch: level; the highest enabled one wins.
    if (LLAMA_AVX512)
        list(APPEND ARCH_FLAGS /arch:AVX512)

        # MSVC offers no compile-time flags for the individual AVX512
        # extensions and does not define the corresponding macros either, so
        # they are defined manually for both C and C++ sources.
        if (LLAMA_AVX512_VBMI)
            add_compile_definitions(
                $<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>
                $<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
        endif()
        if (LLAMA_AVX512_VNNI)
            add_compile_definitions(
                $<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>
                $<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
        endif()
        if (LLAMA_AVX512_FANCY_SIMD)
            foreach(kt_macro __AVX512VL__ __AVX512BW__ __AVX512DQ__ __AVX512VNNI__)
                add_compile_definitions(
                    $<$<COMPILE_LANGUAGE:C>:${kt_macro}>
                    $<$<COMPILE_LANGUAGE:CXX>:${kt_macro}>)
            endforeach()
        endif()
        if (LLAMA_AVX512_BF16)
            add_compile_definitions(
                $<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>
                $<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
        endif()
    elseif (LLAMA_AVX2)
        list(APPEND ARCH_FLAGS /arch:AVX2)
    elseif (LLAMA_AVX)
        list(APPEND ARCH_FLAGS /arch:AVX)
    endif()
else()
    # GCC-style compilers: every enabled option contributes its own -m flag(s).
    if (LLAMA_NATIVE)
        list(APPEND ARCH_FLAGS -mfma -mavx -mavx2 -march=native)
    endif()
    if (LLAMA_F16C)
        list(APPEND ARCH_FLAGS -mf16c)
    endif()
    if (LLAMA_FMA)
        list(APPEND ARCH_FLAGS -mfma)
    endif()
    if (LLAMA_AVX)
        list(APPEND ARCH_FLAGS -mavx)
    endif()
    if (LLAMA_AVX2)
        list(APPEND ARCH_FLAGS -mavx2)
    endif()
    if (LLAMA_AVX512)
        list(APPEND ARCH_FLAGS -mavx512f -mavx512bw)
    endif()
    if (LLAMA_AVX512_VBMI)
        list(APPEND ARCH_FLAGS -mavx512vbmi)
    endif()
    if (LLAMA_AVX512_VNNI)
        list(APPEND ARCH_FLAGS -mavx512vnni)
    endif()
    if (LLAMA_AVX512_FANCY_SIMD)
        message(STATUS "AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI enabled")
        list(APPEND ARCH_FLAGS
            -mavx512vl
            -mavx512bw
            -mavx512dq
            -mavx512vnni
            -mavx512vpopcntdq)
    endif()
    if (LLAMA_AVX512_BF16)
        list(APPEND ARCH_FLAGS -mavx512bf16)
    endif()
endif()
|
|
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
|
message(STATUS "PowerPC detected")
|
|
# Little-endian POWER gets an explicit -mcpu; every other ppc64 variant is
# tuned for the build host. The processor name is quoted so that if() compares
# the literal string and never re-interprets it as a variable name.
if ("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "ppc64le")
    list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
else()
    list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
    #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
endif()
|
|
else()
|
|
message(STATUS "Unknown architecture")
|
|
endif()
|
|
|
|
# message(STATUS "CUDAToolkit_ROOT:${CUDAToolkit_ROOT}")
|
|
# find_package(FindCUDAToolkit REQUIRED)
|
|
# if(CUDAToolkit_FOUND)
|
|
# message(STATUS "Found CUDA cudart lib at:${CUDAToolkit_LIBRARY_DIR}")
|
|
# else()
|
|
# message(STATUS "Can't found CUDA lib")
|
|
# endif()
|
|
|
|
# Locate the ROCm installation, in order of preference:
#   1. the ROCM_PATH environment variable, when it names an existing path;
#   2. /opt/rocm, when it exists;
#   3. /usr as the last resort.
# The environment expansion is quoted so that an unset variable or a path
# containing spaces still reaches EXISTS as exactly one argument.
if (EXISTS "$ENV{ROCM_PATH}")
    set(ROCM_PATH "$ENV{ROCM_PATH}")
elseif (EXISTS /opt/rocm)
    set(ROCM_PATH /opt/rocm)
else()
    set(ROCM_PATH /usr)
endif()

# Let find_package() search the ROCm prefix and its lib64/cmake directory.
list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}")
list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")
|
|
|
|
# Locate the MUSA installation, in order of preference:
#   1. the MUSA_PATH environment variable, when it names an existing path;
#   2. /opt/musa, when it exists;
#   3. /usr/local/musa as the last resort.
# The environment expansion is quoted so that an unset variable or a path
# containing spaces still reaches EXISTS as exactly one argument.
if (EXISTS "$ENV{MUSA_PATH}")
    set(MUSA_PATH "$ENV{MUSA_PATH}")
elseif (EXISTS /opt/musa)
    set(MUSA_PATH /opt/musa)
else()
    set(MUSA_PATH /usr/local/musa)
endif()

# Make the CMake modules under the MUSA installation visible to find_package().
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
|
|
|
|
# Apply ARCH_FLAGS to C++ and C translation units only.
foreach(kt_lang CXX C)
    add_compile_options("$<$<COMPILE_LANGUAGE:${kt_lang}>:${ARCH_FLAGS}>")
endforeach()

# Bundled third-party projects, each built into its own directory under
# <build>/third_party.
foreach(kt_dep pybind11 llama.cpp)
    add_subdirectory(
        ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/${kt_dep}
        ${CMAKE_CURRENT_BINARY_DIR}/third_party/${kt_dep})
endforeach()

# Headers are included relative to the third_party root.
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
|
|
# Select the GPU backend, add its include directories and publish the matching
# KTRANSFORMERS_USE_<BACKEND>=1 compile definition.
#   * Windows: CUDA, located through the CUDA_PATH environment variable.
#   * UNIX:    ROCm or MUSA when the corresponding option is ON, otherwise CUDA.
if (WIN32)
    include_directories("$ENV{CUDA_PATH}/include")
    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX)
    if (KTRANSFORMERS_USE_ROCM)
        # REQUIRED already stops the configure step when HIP is missing, so no
        # separate HIP_FOUND test is needed.
        find_package(HIP REQUIRED)
        include_directories("${HIP_INCLUDE_DIRS}")
        add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
    elseif (KTRANSFORMERS_USE_MUSA)
        # MUSA_PATH and the CMAKE_MODULE_PATH entry for "${MUSA_PATH}/cmake"
        # are set up earlier in this file, before the third-party
        # subdirectories are added; they are not recomputed here.
        #
        # REQUIRED: the module is later linked against MUSA::musart
        # unconditionally, so a missing toolkit must stop the configure step
        # here with a clear message rather than at generate time.
        find_package(MUSAToolkit REQUIRED)
        message(STATUS "MUSA Toolkit found")
        add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
    else()
        find_package(CUDA REQUIRED)
        include_directories("${CUDA_INCLUDE_DIRS}")

        # Probe for a CUDA compiler first; when one is available also locate
        # the toolkit, whose library directory is needed when linking cudart.
        include(CheckLanguage)
        check_language(CUDA)
        if(CMAKE_CUDA_COMPILER)
            message(STATUS "CUDA detected")
            find_package(CUDAToolkit REQUIRED)
            include_directories(${CUDAToolkit_INCLUDE_DIRS})
        endif()

        message(STATUS "enabling CUDA")
        enable_language(CUDA)
        add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
    endif()
endif()
|
|
|
|
# Collect the source files of the extension, one variable per directory.
aux_source_directory("${CMAKE_CURRENT_SOURCE_DIR}"                           SOURCE_DIR1)
aux_source_directory("${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend"               SOURCE_DIR2)
aux_source_directory("${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile"       SOURCE_DIR3)
aux_source_directory("${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile" SOURCE_DIR4)
aux_source_directory("${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache"         SOURCE_DIR5)

# The AMX operators are only collected when all three host-capability
# variables set during x86 detection are true.
if (__HAS_AMX__ AND HAS_AVX512 AND HOST_IS_X86)
    aux_source_directory("${CMAKE_CURRENT_SOURCE_DIR}/operators/amx" SOURCE_DIR6)
endif()

# SOURCE_DIR6 expands to nothing when the AMX directory was skipped.
set(ALL_SOURCES
    ${SOURCE_DIR1}
    ${SOURCE_DIR2}
    ${SOURCE_DIR3}
    ${SOURCE_DIR4}
    ${SOURCE_DIR5}
    ${SOURCE_DIR6})
|
|
|
|
# `format` target: runs clang-format in place (-i) over every .cpp/.hpp/.h file
# found recursively under this directory, using the style file picked up by
# -style=file.
file(GLOB_RECURSE FMT_SOURCES
    "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/*.h")

add_custom_target(
    format
    COMMAND clang-format
    -i
    -style=file
    ${FMT_SOURCES}
    COMMENT "Running clang-format on all source files"
    # VERBATIM makes argument escaping identical on every platform, so file
    # names with spaces or shell metacharacters reach clang-format unchanged.
    VERBATIM
)
|
|
|
|
|
|
# Report the effective flags before the targets are declared.
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
message(STATUS "ARCH_FLAGS: ${ARCH_FLAGS}")

# Static library built from the third_party/llamafile sources.
add_library(llamafile STATIC ${SOURCE_DIR4})

# The Python extension module itself, named after the project and linked
# against the `llama` target.
pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE llama)
|
|
|
|
|
|
# Link the extension module against the runtime library of the selected GPU
# backend.
if(WIN32)
    target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX)
    if (KTRANSFORMERS_USE_ROCM)
        add_compile_definitions(USE_HIP=1)
        target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
        message(STATUS "Building for HIP")
    elseif(KTRANSFORMERS_USE_MUSA)
        target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
    elseif(TARGET CUDA::cudart)
        # Preferred: the imported target created by find_package(CUDAToolkit).
        # It carries the real location of the CUDA runtime, so this also works
        # for toolkit layouts where libcudart.so does not sit directly in
        # CUDAToolkit_LIBRARY_DIR.
        target_link_libraries(${PROJECT_NAME} PRIVATE CUDA::cudart)
    else()
        # Fallback for configurations where the imported target is unavailable.
        target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
    endif()
endif()
|
|
|
|
# USE_NUMA: opt-in NUMA support. When enabled, libnuma is required; the module
# is linked against it and compiled with the USE_NUMA definition.
option(USE_NUMA "Enable NUMA support" OFF)

# The USE_NUMA environment variable also switches the feature on. An explicit
# false value (0/OFF/NO/FALSE/N, case-insensitive) is respected, so
# `USE_NUMA=0` no longer enables NUMA; any other value keeps enabling it.
if(DEFINED ENV{USE_NUMA})
    string(TOUPPER "$ENV{USE_NUMA}" KT_USE_NUMA_ENV)
    if(NOT "${KT_USE_NUMA_ENV}" MATCHES "^(0|OFF|NO|FALSE|N)$")
        set(USE_NUMA ON)
    endif()
endif()

if(USE_NUMA)
    message(STATUS "NUMA support is enabled")
else()
    message(STATUS "NUMA support is disabled")
endif()

find_library(NUMA_LIBRARY NAMES numa)

if(USE_NUMA AND NUMA_LIBRARY)
    message(STATUS "NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support")
    target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY})
    target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA)
elseif(USE_NUMA)
    # NUMA was requested but libnuma is missing: stop instead of silently
    # building without it.
    message(FATAL_ERROR "NUMA library not found - maybe sudo apt install libnuma-dev")
else()
    message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support")
endif()