diff --git a/.github/workflows/kcpp-build-release-linux-rocm.yaml b/.github/workflows/kcpp-build-release-linux-rocm.yaml index 3010a31d1..96b0a641a 100644 --- a/.github/workflows/kcpp-build-release-linux-rocm.yaml +++ b/.github/workflows/kcpp-build-release-linux-rocm.yaml @@ -12,6 +12,7 @@ env: BRANCH_NAME: ${{ github.head_ref || github.ref_name }} KCPP_CUDA: rocm ARCHES_CU12: 1 + NO_WMMA: 1 jobs: linux: diff --git a/Makefile b/Makefile index e762f9327..9c71cb84f 100644 --- a/Makefile +++ b/Makefile @@ -244,7 +244,7 @@ ifdef LLAMA_HIPBLAS ifeq ($(wildcard /opt/rocm),) ROCM_PATH ?= /usr ifdef LLAMA_PORTABLE - GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx942 gfx1010 gfx1030 gfx1031 gfx1032 gfx1100 gfx1101 gfx1102 $(shell $(shell which amdgpu-arch)) + GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx942 gfx1010 gfx1030 gfx1031 gfx1032 gfx1100 gfx1101 gfx1102 gfx1200 gfx1201 $(shell $(shell which amdgpu-arch)) else GPU_TARGETS ?= $(shell $(shell which amdgpu-arch)) endif @@ -252,13 +252,17 @@ endif HCXX := $(ROCM_PATH)/bin/hipcc else ROCM_PATH ?= /opt/rocm - GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx942 gfx1010 gfx1030 gfx1031 gfx1032 gfx1100 gfx1101 gfx1102 $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch) + GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx942 gfx1010 gfx1030 gfx1031 gfx1032 gfx1100 gfx1101 gfx1102 gfx1200 gfx1201 $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch) HCC := $(ROCM_PATH)/llvm/bin/clang HCXX := $(ROCM_PATH)/llvm/bin/clang++ endif +ifdef LLAMA_NO_WMMA + HIPFLAGS += -DGGML_HIP_NO_ROCWMMA_FATTN +else DETECT_ROCWMMA := $(shell find -L /opt/rocm/include /usr/include -type f -name rocwmma.hpp 2>/dev/null | head -n 1) ifdef DETECT_ROCWMMA HIPFLAGS += -DGGML_HIP_ROCWMMA_FATTN -I$(dir $(DETECT_ROCWMMA)) +endif endif HIPFLAGS += -DGGML_USE_HIP -DGGML_HIP_NO_VMM -DGGML_USE_CUDA -DSD_USE_CUDA $(shell $(ROCM_PATH)/bin/hipconfig -C) diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 8cf649d31..3031cd201 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -56,7 +56,7 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if defined(FLASH_ATTN_AVAILABLE) && ((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ == GGML_CUDA_CC_TURING) || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) +#if !defined(GGML_HIP_NO_ROCWMMA_FATTN) && defined(FLASH_ATTN_AVAILABLE) && ((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ == GGML_CUDA_CC_TURING) || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; diff --git a/koboldcpp.sh b/koboldcpp.sh index 14be57776..508cd0f30 100755 --- a/koboldcpp.sh +++ b/koboldcpp.sh @@ -29,6 +29,7 @@ KCPP_CUDAAPPEND=-cuda${KCPP_CUDA//.}$KCPP_APPEND LLAMA_NOAVX2_FLAG="" ARCHES_FLAG="" +NO_WMMA_FLAG="" if [ -n "$NOAVX2" ]; then LLAMA_NOAVX2_FLAG="LLAMA_NOAVX2=1" fi @@ -38,11 +39,14 @@ fi if [ -n "$ARCHES_CU12" ]; then ARCHES_FLAG="LLAMA_ARCHES_CU12=1" fi +if [ -n "$NO_WMMA" ]; then + NO_WMMA_FLAG="LLAMA_NO_WMMA=1" +fi if [ "$KCPP_CUDA" = "rocm" ]; then - bin/micromamba run -r conda -p conda/envs/linux make -j$(nproc) LLAMA_VULKAN=1 LLAMA_CLBLAST=1 LLAMA_HIPBLAS=1 LLAMA_PORTABLE=1 LLAMA_USE_BUNDLED_GLSLC=1 LLAMA_ADD_CONDA_PATHS=1 $LLAMA_NOAVX2_FLAG $ARCHES_FLAG + bin/micromamba run -r conda -p conda/envs/linux make -j$(nproc) LLAMA_VULKAN=1 LLAMA_CLBLAST=1 LLAMA_HIPBLAS=1 LLAMA_PORTABLE=1 LLAMA_USE_BUNDLED_GLSLC=1 LLAMA_ADD_CONDA_PATHS=1 $LLAMA_NOAVX2_FLAG $ARCHES_FLAG $NO_WMMA_FLAG else - bin/micromamba run -r conda -p conda/envs/linux make -j$(nproc) LLAMA_VULKAN=1 LLAMA_CLBLAST=1 LLAMA_CUBLAS=1 LLAMA_PORTABLE=1 LLAMA_USE_BUNDLED_GLSLC=1 LLAMA_ADD_CONDA_PATHS=1 $LLAMA_NOAVX2_FLAG $ARCHES_FLAG + bin/micromamba run -r conda -p conda/envs/linux make -j$(nproc) LLAMA_VULKAN=1 LLAMA_CLBLAST=1 LLAMA_CUBLAS=1 LLAMA_PORTABLE=1 LLAMA_USE_BUNDLED_GLSLC=1 LLAMA_ADD_CONDA_PATHS=1 $LLAMA_NOAVX2_FLAG $ARCHES_FLAG $NO_WMMA_FLAG fi if [ $? -ne 0 ]; then