mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
CUDA: add fused rms norm (#14800)
This commit is contained in:
parent
acd6cb1c41
commit
8c988fa41d
4 changed files with 144 additions and 9 deletions
|
@ -2641,6 +2641,7 @@ struct test_rms_norm_mul_add : public test_case {
|
|||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const float eps;
|
||||
const bool broadcast;
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
|
@ -2650,18 +2651,21 @@ struct test_rms_norm_mul_add : public test_case {
|
|||
bool run_whole_graph() override { return true; }
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR3(type, ne, eps);
|
||||
return VARS_TO_STR4(type, ne, eps, broadcast);
|
||||
}
|
||||
|
||||
test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {64, 5, 4, 3},
|
||||
float eps = 1e-6f)
|
||||
: type(type), ne(ne), eps(eps) {}
|
||||
float eps = 1e-6f, bool broadcast = false)
|
||||
: type(type), ne(ne), eps(eps), broadcast(broadcast) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
|
||||
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
|
||||
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
ggml_set_param(b);
|
||||
|
@ -5354,6 +5358,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
}
|
||||
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue