add op GGML_OP_READ

This commit is contained in:
Zonghang Li 2024-11-26 22:28:34 +04:00
parent f78c437172
commit 3f008f2ad9
4 changed files with 59 additions and 0 deletions

View file

@ -530,6 +530,7 @@ extern "C" {
GGML_OP_OPT_STEP_ADAMW,
GGML_OP_COUNT,
GGML_OP_READ,
};
enum ggml_unary_op {

View file

@ -28,6 +28,7 @@
#include "ggml-cuda/pad.cuh"
#include "ggml-cuda/pool2d.cuh"
#include "ggml-cuda/quantize.cuh"
#include "ggml-cuda/read.cuh"
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/scale.cuh"
#include "ggml-cuda/softmax.cuh"
@ -2145,6 +2146,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
}
switch (dst->op) {
case GGML_OP_READ:
ggml_cuda_read(dst);
break;
case GGML_OP_ARGMAX:
ggml_cuda_argmax(ctx, dst);
break;

View file

@ -0,0 +1,49 @@
#include "common.cuh"
#include "read.cuh"
__global__ void read_vram_f32(
const float * data, int64_t ne,
int64_t nb00, int64_t nb01, int64_t nb02, int64_t nb03,
int64_t ne00, int64_t ne01, int64_t ne02
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= ne) return;
int i = idx % ne00;
int j = (idx / ne00) % ne01;
int k = (idx / (ne00 * ne01)) % ne02;
int64_t offset = i * nb00 + j * nb01 + k * nb02;
volatile float value = data[offset / sizeof(float)];
asm volatile("" : : "f"(value) : "memory");
}
void ggml_cuda_read(ggml_tensor * dst) {
const int64_t ne = ggml_nelements(dst);
GGML_ASSERT(ggml_nbytes(dst) <= INT_MAX);
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = dst->ne[0];
const int64_t ne01 = dst->ne[1];
const int64_t ne02 = dst->ne[2];
const int64_t nb00 = dst->nb[0];
const int64_t nb01 = dst->nb[1];
const int64_t nb02 = dst->nb[2];
const int64_t nb03 = dst->nb[3];
const char * dst_ddc = (const char *)dst->data;
cudaStream_t stream;
cudaStreamCreate(&stream);
const int num_blocks = (ne + CUDA_READ_BLOCK_SIZE - 1) / CUDA_READ_BLOCK_SIZE;
read_vram_f32<<<num_blocks, CUDA_READ_BLOCK_SIZE, 0, stream>>>(
(const float *)dst_ddc, ne, nb00, nb01, nb02, nb03, ne00, ne01, ne02
);
cudaStreamSynchronize(stream);
cudaStreamDestroy(stream);
}

View file

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_READ_BLOCK_SIZE 32
void ggml_cuda_read(ggml_tensor * dst);