reinstated the q4_3 format, for backwards compatibility.

This commit is contained in:
Concedo 2023-04-29 11:42:04 +08:00
parent 0fc1772a8f
commit bb282a4ecf
12 changed files with 364 additions and 26 deletions

View file

@ -24,7 +24,7 @@ static cl_device_id device;
static cl_context context;
static cl_command_queue queue;
static cl_program program;
static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2;
static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q4_3;
static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
@ -57,6 +57,21 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
return p;
}
static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
if (req_size <= *cur_size) {
return;
}
// Reallocate buffer with enough space
if (*cur_size > 0) {
clReleaseMemObject(*buf);
}
cl_int err;
*buf = clCreateBuffer(context, flags, req_size, NULL, &err);
*cur_size = req_size;
CL_CHECK(err, "clCreateBuffer");
}
void ggml_cl_init(void) {
cl_int err = 0;
char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
@ -97,21 +112,15 @@ void ggml_cl_init(void) {
CL_CHECK(err, "clCreateKernel");
kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err);
CL_CHECK(err, "clCreateKernel");
}
kernel_q4_3 = clCreateKernel(program, "dequantize_row_q4_3", &err);
CL_CHECK(err, "clCreateKernel");
static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
if (req_size <= *cur_size) {
return;
}
// Reallocate buffer with enough space
if (*cur_size > 0) {
clReleaseMemObject(*buf);
}
cl_int err;
*buf = clCreateBuffer(context, flags, req_size, NULL, &err);
*cur_size = req_size;
CL_CHECK(err, "clCreateBuffer");
//preallocate buffers
const size_t defaultBufSize = 16*1024*1024;
ggml_cl_malloc(defaultBufSize, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a);
ggml_cl_malloc(defaultBufSize, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb);
ggml_cl_malloc(defaultBufSize, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b);
ggml_cl_malloc(defaultBufSize, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c);
}
void ggml_cl_sgemm_wrapper(
@ -148,6 +157,12 @@ void ggml_cl_sgemm_wrapper(
local = 8;
size_qb = global * (sizeof(short) + local) / 16;
break;
case GGML_TYPE_Q4_3:
dequant = true;
kernel = kernel_q4_3;
local = 8;
size_qb = global * (sizeof(short) * 2 + local) / 16;
break;
default:
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
abort();