diff --git a/src/llama.cpp b/src/llama.cpp index 882e90be6..355770826 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -12,10 +12,13 @@ #include "ggml-backend.h" #include "ggml-cpp.h" -#if defined(GGML_USE_CLBLAST) +#ifdef GGML_USE_CUDA +# include "ggml-cuda.h" +#elif defined(GGML_USE_CLBLAST) # include "ggml-opencl.h" #endif + // TODO: replace with ggml API call #define QK_K 256