Wip, CUDA porting malloc improvements, gpu accel for non-llama, backport old quants

This commit is contained in:
Concedo 2023-06-28 18:20:46 +08:00
parent 9527a783ea
commit b4698abafc
10 changed files with 842 additions and 24 deletions

View file

@ -14,7 +14,9 @@
#include <iostream>
#include <algorithm>
#if defined(GGML_USE_CLBLAST)
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#elif defined(GGML_USE_CLBLAST)
#include "ggml-opencl.h"
#endif
@ -324,7 +326,7 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
fin.close();
//gpu offload
#if defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CLBLAST) || defined(GGML_USE_CUBLAS)
if(gpulayers>0)
{
const auto & hparams = model.hparams;
@ -337,10 +339,17 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
layer.c_attn_proj_w->backend = GGML_BACKEND_GPU;
layer.c_mlp_fc_w->backend = GGML_BACKEND_GPU;
layer.c_mlp_proj_w->backend = GGML_BACKEND_GPU;
#if defined(GGML_USE_CLBLAST)
ggml_cl_transform_tensor(layer.c_attn_attn_w->data,layer.c_attn_attn_w); vram_total += ggml_nbytes(layer.c_attn_attn_w);
ggml_cl_transform_tensor(layer.c_attn_proj_w->data,layer.c_attn_proj_w); vram_total += ggml_nbytes(layer.c_attn_proj_w);
ggml_cl_transform_tensor(layer.c_mlp_fc_w->data,layer.c_mlp_fc_w); vram_total += ggml_nbytes(layer.c_mlp_fc_w);
ggml_cl_transform_tensor(layer.c_mlp_proj_w->data,layer.c_mlp_proj_w); vram_total += ggml_nbytes(layer.c_mlp_proj_w);
#else
ggml_cuda_transform_tensor(layer.c_attn_attn_w->data,layer.c_attn_attn_w); vram_total += ggml_nbytes(layer.c_attn_attn_w);
ggml_cuda_transform_tensor(layer.c_attn_proj_w->data,layer.c_attn_proj_w); vram_total += ggml_nbytes(layer.c_attn_proj_w);
ggml_cuda_transform_tensor(layer.c_mlp_fc_w->data,layer.c_mlp_fc_w); vram_total += ggml_nbytes(layer.c_mlp_fc_w);
ggml_cuda_transform_tensor(layer.c_mlp_proj_w->data,layer.c_mlp_proj_w); vram_total += ggml_nbytes(layer.c_mlp_proj_w);
#endif
}
fprintf(stderr, "%s: [opencl] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
}