mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-28 03:30:20 +00:00
CUDA: fuse relu + sqr (#22249)
This commit is contained in:
parent
6217b49583
commit
86db42e97f
4 changed files with 95 additions and 0 deletions
|
|
@ -3592,6 +3592,30 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
|||
return true;
|
||||
}
|
||||
|
||||
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_SQR
|
||||
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_RELU) {
|
||||
const ggml_tensor * unary = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * sqr = cgraph->nodes[node_idx+1];
|
||||
|
||||
if (ggml_get_unary_op(unary) != GGML_UNARY_OP_RELU) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (unary->type != sqr->type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(unary->src[0])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
|
||||
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
|
||||
const ggml_tensor *scale = cgraph->nodes[node_idx];
|
||||
|
|
@ -4100,6 +4124,12 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
|||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) {
|
||||
ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i+1]);
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
|
||||
i += 2;
|
||||
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
|
||||
|
|
|
|||
|
|
@ -65,6 +65,11 @@ static __device__ __forceinline__ float op_sqr(float x) {
|
|||
return x * x;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_relu_sqr(float x) {
|
||||
const float r = fmaxf(x, 0.0f);
|
||||
return r * r;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_sqrt(float x) {
|
||||
return sqrtf(x);
|
||||
}
|
||||
|
|
@ -615,3 +620,21 @@ void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary
|
|||
GGML_ABORT("Unsupported unary op for fused unary+mul");
|
||||
}
|
||||
}
|
||||
|
||||
/* fused relu + sqr */
|
||||
|
||||
void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node) {
|
||||
const ggml_tensor * src = relu_node->src[0];
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src));
|
||||
GGML_ASSERT(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src->type == sqr_node->type);
|
||||
|
||||
const int k = ggml_nelements(src);
|
||||
if (src->type == GGML_TYPE_F16) {
|
||||
unary_cuda<op_relu_sqr>((const half *)src->data, (half *)sqr_node->data, k, stream);
|
||||
} else {
|
||||
unary_cuda<op_relu_sqr>((const float *)src->data, (float *)sqr_node->data, k, stream);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -91,6 +91,8 @@ void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|||
|
||||
void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node);
|
||||
|
||||
void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node);
|
||||
|
||||
__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
|
||||
return x / (1.0f + expf(-x));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3522,6 +3522,40 @@ struct test_add_rms_norm : public test_case {
|
|||
}
|
||||
};
|
||||
|
||||
// GGML_OP_UNARY(RELU) + GGML_OP_SQR (fused operation)
|
||||
struct test_relu_sqr : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "RELU_SQR";
|
||||
}
|
||||
|
||||
bool run_whole_graph() override { return true; }
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
}
|
||||
|
||||
test_relu_sqr(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {128, 2, 2, 2})
|
||||
: type(type), ne(ne) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * r = ggml_relu(ctx, a);
|
||||
ggml_set_name(r, "relu");
|
||||
|
||||
ggml_tensor * out = ggml_sqr(ctx, r);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SSM_CONV
|
||||
struct test_ssm_conv : public test_case {
|
||||
const ggml_type type;
|
||||
|
|
@ -7311,6 +7345,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
}
|
||||
}
|
||||
|
||||
// fused relu + sqr (squared ReLU)
|
||||
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_relu_sqr(type, { 128, 2, 2, 2 }));
|
||||
test_cases.emplace_back(new test_relu_sqr(type, { 5, 7, 11, 13 }));
|
||||
}
|
||||
|
||||
// glu ops
|
||||
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
||||
for (int v : {0, 1}) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue