From 9439da6f9521ee67a3c802bcedac8dffdf589b83 Mon Sep 17 00:00:00 2001
From: 0cc4m
Date: Sat, 29 Apr 2023 07:43:15 +0200
Subject: [PATCH] Implement q5_0, q5_1 and q8_0

---
Review notes:
- block_q5_0 is declared packed so the device struct is 22 bytes, the same
  per-block size the host uploads (sizeof(short) + 4 + 16). Without it the
  compiler pads qh to a 4-byte boundary and qh/qs are read from the wrong
  offsets.
- block_q8_0 quants are read as signed char instead of uchar, so negative
  quantized weights keep their sign when multiplied by d.
- The qh bit masks are built from 1u so that selecting bit 31 never shifts
  a signed literal into its sign bit.

 ggml-opencl-dequant.cl | 71 +++++++++++++++++++++++++++++++++++++++++-
 ggml-opencl.c          | 58 ++++++++++++++++++++++++++--------
 2 files changed, 115 insertions(+), 14 deletions(-)

diff --git a/ggml-opencl-dequant.cl b/ggml-opencl-dequant.cl
index a65a79f4d..a5cac1b6a 100644
--- a/ggml-opencl-dequant.cl
+++ b/ggml-opencl-dequant.cl
@@ -51,7 +51,7 @@ __kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global f
     const uint i = get_global_id(0) / 16;
     const uint l = get_local_id(0);
 
-    const float d = vload_half(0, (__global half*) &blocks[i].d);;
+    const float d = vload_half(0, (__global half*) &blocks[i].d);
 
     const uchar vi = blocks[i].qs[l];
 
@@ -60,4 +60,73 @@ __kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global f
     result[index + 1] = ((vi >> 4) - 8)*d;
 }
 
+
+struct __attribute__((packed)) block_q5_0 // packed: 2 + 4 + 16 = 22 bytes as uploaded by the host
+{
+    ushort d;
+    uint qh;
+    uchar qs[16];
+};
+
+__kernel void dequantize_row_q5_0(__global struct block_q5_0* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 32;
+    const uint l = get_local_id(0);
+
+    const float d = vload_half(0, (__global half*) &blocks[i].d);
+
+    const uchar vi = blocks[i].qs[l];
+
+    const uint l2 = l * 2;
+
+    const uchar vh0 = ((blocks[i].qh & (1u << (l2 + 0))) >> (l2 + 0)) << 4;
+    const uchar vh1 = ((blocks[i].qh & (1u << (l2 + 1))) >> (l2 + 1)) << 4;
+
+    const uint index = i*32 + l2;
+    result[index + 0] = (((vi & 0xf) | vh0) - 16)*d;
+    result[index + 1] = (((vi >> 4) | vh1) - 16)*d;
+}
+
+struct block_q5_1
+{
+    ushort d;
+    ushort m;
+    uint qh;
+    uchar qs[16];
+};
+
+__kernel void dequantize_row_q5_1(__global struct block_q5_1* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 32;
+    const uint l = get_local_id(0);
+
+    const float d = vload_half(0, (__global half*) &blocks[i].d);
+    const float m = vload_half(0, (__global half*) &blocks[i].m);
+
+    const uchar vi = blocks[i].qs[l];
+
+    const uint l2 = l * 2;
+
+    const uchar vh0 = ((blocks[i].qh & (1u << (l2 + 0))) >> (l2 + 0)) << 4;
+    const uchar vh1 = ((blocks[i].qh & (1u << (l2 + 1))) >> (l2 + 1)) << 4;
+
+    const uint index = i*32 + l2;
+    result[index + 0] = ((vi & 0xf) | vh0)*d + m;
+    result[index + 1] = ((vi >> 4) | vh1)*d + m;
+}
+
+struct block_q8_0
+{
+    float d;
+    char qs[32]; // signed 8-bit quants
+};
+
+__kernel void dequantize_row_q8_0(__global struct block_q8_0* blocks, __global float* result) {
+    const uint i = get_global_id(0) / 32;
+    const uint l = get_local_id(0);
+
+    const float d = blocks[i].d;
+
+    const uint index = i*32 + l;
+    result[index] = blocks[i].qs[l] * d;
+}
+
 );
diff --git a/ggml-opencl.c b/ggml-opencl.c
index b748f86b7..026f206ce 100644
--- a/ggml-opencl.c
+++ b/ggml-opencl.c
@@ -24,7 +24,7 @@ static cl_device_id device;
 static cl_context context;
 static cl_command_queue queue;
 static cl_program program;
-static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2;
+static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q5_0, kernel_q5_1, kernel_q8_0;
 static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
 static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
 
@@ -97,6 +97,12 @@ void ggml_cl_init(void) {
     CL_CHECK(err, "clCreateKernel");
     kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err);
     CL_CHECK(err, "clCreateKernel");
+    kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err);
+    CL_CHECK(err, "clCreateKernel");
+    kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err);
+    CL_CHECK(err, "clCreateKernel");
+    kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err);
+    CL_CHECK(err, "clCreateKernel");
 }
 
 static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
@@ -148,6 +154,24 @@ void ggml_cl_sgemm_wrapper(
         local = 8;
         size_qb = global * (sizeof(short) + local) / 16;
         break;
+    case GGML_TYPE_Q5_0:
+        dequant = true;
+        kernel = kernel_q5_0;
+        local = 16;
+        size_qb = global * (sizeof(short) + 4 + local) / 32;
+        break;
+    case GGML_TYPE_Q5_1:
+        dequant = true;
+        kernel = kernel_q5_1;
+        local = 16;
+        size_qb = global * (sizeof(short) * 2 + 4 + local) / 32;
+        break;
+    case GGML_TYPE_Q8_0:
+        dequant = true;
+        kernel = kernel_q8_0;
+        local = 32;
+        size_qb = global * (sizeof(float) + local) / 32;
+        break;
     default:
         fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
         abort();
@@ -171,12 +195,15 @@ void ggml_cl_sgemm_wrapper(
         err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
         err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
         CL_CHECK(err, "clSetKernelArg");
-        clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
+        err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
+        CL_CHECK(err, "clEnqueueWriteBuffer qb");
     } else {
-        clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
+        err = clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
+        CL_CHECK(err, "clEnqueueWriteBuffer b");
     }
 
-    clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
+    err = clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
+    CL_CHECK(err, "clEnqueueWriteBuffer a");
     if (dequant) {
         err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b);
         CL_CHECK(err, "clEnqueueNDRangeKernel");
@@ -188,15 +215,20 @@ void ggml_cl_sgemm_wrapper(
     clReleaseEvent(ev_b);
 
     cl_event ev_sgemm;
-    CLBlastSgemm((CLBlastLayout)order,
-                 (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
-                 m, n, k,
-                 alpha,
-                 cl_buffer_a, 0, lda,
-                 cl_buffer_b, 0, ldb,
-                 beta,
-                 cl_buffer_c, 0, ldc,
-                 &queue, &ev_sgemm);
+    CLBlastStatusCode status = CLBlastSgemm((CLBlastLayout)order,
+                                            (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
+                                            m, n, k,
+                                            alpha,
+                                            cl_buffer_a, 0, lda,
+                                            cl_buffer_b, 0, ldb,
+                                            beta,
+                                            cl_buffer_c, 0, ldc,
+                                            &queue, &ev_sgemm);
+
+    if (status != CLBlastSuccess) {
+        fprintf(stderr, "Error: CLBlast SGEMM %d\n", status);
+        abort();
+    }
 
     cl_event ev_c;
     clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);