diff --git a/common/arg.cpp b/common/arg.cpp index 862755bb5..5d1393c79 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2572,5 +2572,43 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--fim-qwen-7b-spec"}, + string_format("use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet)"), + [](common_params & params) { + params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; + params.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; + params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.speculative.n_gpu_layers = 99; + params.port = 8012; + params.n_gpu_layers = 99; + params.flash_attn = true; + params.n_ubatch = 1024; + params.n_batch = 1024; + params.n_ctx = 0; + params.n_cache_reuse = 256; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + + add_opt(common_arg( + {"--fim-qwen-14b-spec"}, + string_format("use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet)"), + [](common_params & params) { + params.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF"; + params.hf_file = "qwen2.5-coder-14b-q8_0.gguf"; + params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.speculative.n_gpu_layers = 99; + params.port = 8012; + params.n_gpu_layers = 99; + params.flash_attn = true; + params.n_ubatch = 1024; + params.n_batch = 1024; + params.n_ctx = 0; + params.n_cache_reuse = 256; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + return ctx_arg; } diff --git a/common/minja/minja.hpp b/common/minja/minja.hpp index c58dd66e0..fa4c34d6e 100644 --- a/common/minja/minja.hpp +++ b/common/minja/minja.hpp @@ -1378,13 +1378,27 @@ struct ArgumentsExpression { } }; -static std::string strip(const std::string & s) { - auto start = s.find_first_not_of(" \t\n\r"); +static std::string strip(const std::string & s, const std::string & chars = "", bool left = true, bool right = true) { + auto charset = chars.empty() ? " \t\n\r" : chars; + auto start = left ? s.find_first_not_of(charset) : 0; if (start == std::string::npos) return ""; - auto end = s.find_last_not_of(" \t\n\r"); + auto end = right ? s.find_last_not_of(charset) : s.size() - 1; return s.substr(start, end - start + 1); } +static std::vector split(const std::string & s, const std::string & sep) { + std::vector result; + size_t start = 0; + size_t end = s.find(sep); + while (end != std::string::npos) { + result.push_back(s.substr(start, end - start)); + start = end + sep.length(); + end = s.find(sep, start); + } + result.push_back(s.substr(start)); + return result; +} + static std::string capitalize(const std::string & s) { if (s.empty()) return s; auto result = s; @@ -1467,8 +1481,26 @@ public: } else if (obj.is_string()) { auto str = obj.get(); if (method->get_name() == "strip") { - vargs.expectArgs("strip method", {0, 0}, {0, 0}); - return Value(strip(str)); + vargs.expectArgs("strip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars)); + } else if (method->get_name() == "lstrip") { + vargs.expectArgs("lstrip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars, /* left= */ true, /* right= */ false)); + } else if (method->get_name() == "rstrip") { + vargs.expectArgs("rstrip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars, /* left= */ false, /* right= */ true)); + } else if (method->get_name() == "split") { + vargs.expectArgs("split method", {1, 1}, {0, 0}); + auto sep = vargs.args[0].get(); + auto parts = split(str, sep); + Value result = Value::array(); + for (const auto& part : parts) { + result.push_back(Value(part)); + } + return result; } else if (method->get_name() == "capitalize") { vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); return Value(capitalize(str)); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2a526b0e7..8386f4eeb 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1312,7 +1312,7 @@ struct server_slot { return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; } - bool can_batch_with(server_slot & other_slot) { + bool can_batch_with(server_slot & other_slot) const { return is_non_causal() == other_slot.is_non_causal() && are_lora_equal(lora, other_slot.lora); } @@ -1900,6 +1900,7 @@ struct server_context { try { common_chat_format_example(chat_templates.get(), params.use_jinja); } catch (const std::exception & e) { + SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); chat_templates = common_chat_templates_init(model, "chatml"); } @@ -2156,14 +2157,6 @@ struct server_context { } if (slot.has_new_line) { - // if we have already seen a new line, we stop after a certain time limit - if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); - } - // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent if (slot.params.n_indent > 0) { // check the current indentation @@ -2202,6 +2195,14 @@ struct server_context { // check if there is a new line in the generated text if (result.text_to_send.find('\n') != std::string::npos) { slot.has_new_line = true; + + // if we have seen a new line, we stop after a certain time limit, but only upon another new line + if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + } } // if context shift is disabled, we stop when it reaches the context limit diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 0a3e96463..8b959e355 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -76,7 +76,14 @@ namespace fs = std::filesystem; static std::string path_str(const fs::path & path) { std::string u8path; try { +#if defined(__cpp_lib_char8_t) + // C++20 and later: u8string() returns std::u8string + std::u8string u8str = path.u8string(); + u8path = std::string(reinterpret_cast(u8str.c_str())); +#else + // C++17: u8string() returns std::string u8path = path.u8string(); +#endif } catch (...) { } return u8path; diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 4d0bc72fe..dbeeeee25 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -11719,9 +11719,12 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const #elif defined __AVX2__ - const __m256i mask = _mm256_set1_epi16(2 * 0x7); + const __m256i mask = _mm256_set1_epi16(0x7); const __m256i mone = _mm256_set1_epi16(1); const __m256i mone8 = _mm256_set1_epi8(1); + const __m256i mtwo8 = _mm256_set1_epi8(2); + // VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half. + const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0); __m256 accum1 = _mm256_setzero_ps(); __m256 accum2 = _mm256_setzero_ps(); @@ -11733,6 +11736,14 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const const uint16_t * sc = (const uint16_t *)x[i].scales; scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + // Extract 3-bit scales (16 values) + __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc); + scales = _mm256_srlv_epi64(scales, scales_shift); + scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone); + + // Indices to repeat each scale 8 times. + __m256i scales_idx1 = _mm256_set1_epi16(0x0100); + __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8)); __m256i sumi1 = _mm256_setzero_si256(); __m256i sumi2 = _mm256_setzero_si256(); @@ -11778,11 +11789,12 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1)); const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2)); - __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 2), _mm_set1_epi16(sc[ib/2] << 1)); - __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 8), _mm_set1_epi16(sc[ib/2] >> 5)); + __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1); + __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2); + + scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8); + scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8); - scale1 = _mm256_add_epi16(_mm256_and_si256(scale1, mask), mone); - scale2 = _mm256_add_epi16(_mm256_and_si256(scale2, mask), mone); const __m256i p1 = _mm256_madd_epi16(dot1, scale1); const __m256i p2 = _mm256_madd_epi16(dot2, scale2); const __m256i p3 = _mm256_madd_epi16(dot3, scale1); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index acda631e9..8c5e807a3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -6678,6 +6678,135 @@ static void ggml_compute_forward_repeat_back( // ggml_compute_forward_concat +static void ggml_compute_forward_concat_any( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + const size_t len = ggml_type_size(src0->type); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim >= 0 && dim < 4); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const char * x; + + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03; + } else { + x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13; + } + + char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3; + + memcpy(y, x, len); + } + } + } + } +} + +static void ggml_compute_forward_concat_i8( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim >= 0 && dim < 4); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const int8_t * x; + + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); + } else { + x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); + } + + int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; + } + } + } + } +} + +static void ggml_compute_forward_concat_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim >= 0 && dim < 4); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const ggml_fp16_t * x; + + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); + } else { + x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); + } + + ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; + } + } + } + } +} + static void ggml_compute_forward_concat_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -6685,7 +6814,7 @@ static void ggml_compute_forward_concat_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -6728,6 +6857,16 @@ static void ggml_compute_forward_concat( const struct ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_I16: + { + ggml_compute_forward_concat_f16(params, dst); + } break; + case GGML_TYPE_I8: + { + ggml_compute_forward_concat_i8(params, dst); + } break; case GGML_TYPE_F32: case GGML_TYPE_I32: { @@ -6735,7 +6874,7 @@ static void ggml_compute_forward_concat( } break; default: { - GGML_ABORT("fatal error"); + ggml_compute_forward_concat_any(params, dst); } } } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 24f973056..2e72fc8fd 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: - if (cc == GGML_CUDA_CC_VOLTA) { + if (fp16_mma_available(cc) && !new_mma_available(cc)) { ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); return; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index e3dc25f16..a58c474eb 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -285,4 +285,239 @@ typedef struct { float eps; } ggml_metal_kargs_rms_norm; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t n_groups; + float eps; +} ggml_metal_kargs_group_norm; + +typedef struct { + int32_t IC; + int32_t IL; + int32_t K; + int32_t s0; + uint64_t nb0; + uint64_t nb1; +} ggml_metal_kargs_conv_transpose_1d; + +typedef struct { + uint64_t ofs0; + uint64_t ofs1; + int32_t IW; + int32_t IH; + int32_t CHW; + int32_t s0; + int32_t s1; + int32_t p0; + int32_t p1; + int32_t d0; + int32_t d1; + int32_t N; + int32_t KH; + int32_t KW; + int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources +} ggml_metal_kargs_im2col; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne10; + int64_t ne11; + int64_t ne12; + int64_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_sum_rows; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; +} ggml_metal_kargs_soft_max; + +typedef struct { + int64_t ne00; + int64_t ne01; + int n_past; +} ggml_metal_kargs_diag_mask_inf; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int64_t ne10; + int64_t ne11; + uint64_t nb10; + uint64_t nb11; + int64_t ne0; + int64_t ne1; + int64_t ne2; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_ssm_conv; + +typedef struct { + int64_t d_state; + int64_t d_inner; + int64_t n_seq_tokens; + int64_t n_seqs; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb30; + uint64_t nb31; + uint64_t nb40; + uint64_t nb41; + uint64_t nb42; + uint64_t nb50; + uint64_t nb51; + uint64_t nb52; +} ggml_metal_kargs_ssm_scan; + +typedef struct { + int64_t ne00; + uint64_t nb01; + uint64_t nb02; + int64_t ne10; + uint64_t nb10; + uint64_t nb11; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_get_rows; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float sf0; + float sf1; + float sf2; + float sf3; +} ggml_metal_kargs_upscale; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_pad; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t p0; + int32_t p1; +} ggml_metal_kargs_pad_reflect_1d; + +typedef struct { + uint64_t nb1; + int dim; + int max_period; +} ggml_metal_kargs_timestep_embedding; + +typedef struct { + float slope; +} ggml_metal_kargs_leaky_relu; + +typedef struct { + int64_t ncols; + int64_t ncols_pad; +} ggml_metal_kargs_argsort; + +typedef struct { + int64_t ne0; + float start; + float step; +} ggml_metal_kargs_arange; + +typedef struct { + int32_t k0; + int32_t k1; + int32_t s0; + int32_t s1; + int32_t p0; + int32_t p1; + int64_t IH; + int64_t IW; + int64_t OH; + int64_t OW; + int64_t parallel_elements; +} ggml_metal_kargs_pool_2d; + #endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index d1ac28f46..22babf9c2 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1945,34 +1945,38 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; - // TODO: add ggml_metal_kargs struct + + ggml_metal_kargs_sum_rows args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; @@ -2021,8 +2025,17 @@ static void ggml_metal_encode_node( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - // TODO: add ggml_metal_kargs struct - // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6) + ggml_metal_kargs_soft_max args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { @@ -2031,14 +2044,7 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; } [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; @@ -2056,13 +2062,16 @@ static void ggml_metal_encode_node( pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; } - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_diag_mask_inf args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.n_past =*/ n_past, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; if (ne00%8 == 0) { [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; @@ -2081,27 +2090,30 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_ssm_conv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; @@ -2152,7 +2164,31 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_ssm_scan args = { + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_seq_tokens =*/ n_seq_tokens, + /*.n_seqs =*/ n_seqs, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb20 =*/ nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb30 =*/ nb30, + /*.nb31 =*/ nb31, + /*.nb40 =*/ nb40, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.nb50 =*/ nb50, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2161,30 +2197,7 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + [encoder setBytes:&args length:sizeof(args) atIndex:7]; [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; @@ -3041,19 +3054,22 @@ static void ggml_metal_encode_node( default: GGML_ABORT("not implemented"); } - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_get_rows args = { + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; @@ -3110,18 +3126,21 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_group_norm args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.n_groups =*/ n_groups, + /*.eps =*/ eps, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&eps length:sizeof( float) atIndex:9]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; @@ -3279,8 +3298,8 @@ static void ggml_metal_encode_node( const int32_t CHW = IC * KH * KW; - const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; - const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; @@ -3302,27 +3321,30 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_im2col args = { + /*.ofs0 =*/ ofs0, + /*.ofs1 =*/ ofs1, + /*.IW =*/ IW, + /*.IH =*/ IH, + /*.CHW =*/ CHW, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.p0 =*/ p0, + /*.p1 =*/ p1, + /*.d0 =*/ d0, + /*.d1 =*/ d1, + /*.N =*/ N, + /*.KH =*/ KH, + /*.KW =*/ KW, + /*.KHW =*/ KH * KW, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2]; - [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3]; - [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4]; - [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5]; - [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6]; - [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7]; - [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8]; - [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9]; - [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10]; - [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11]; - [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; if (is_gt_mttpt) { - [encoder setBytes:&N length:sizeof(int32_t) atIndex:13]; - [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14]; - [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15]; - const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N); const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); @@ -3362,16 +3384,20 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; + ggml_metal_kargs_conv_transpose_1d args = { + /*.IC =*/ IC, + /*.IL =*/ IL, + /*.K =*/ K, + /*.s0 =*/ s0, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3]; - [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4]; - [encoder setBytes:&K length:sizeof( int32_t) atIndex:5]; - [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; @@ -3386,30 +3412,33 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_upscale args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.sf0 =*/ sf0, + /*.sf1 =*/ sf1, + /*.sf2 =*/ sf2, + /*.sf3 =*/ sf3 + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; - [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; - [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; - [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); @@ -3421,26 +3450,29 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_pad args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3 + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; const int nth = MIN(1024, ne0); @@ -3455,24 +3487,31 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline; + ggml_metal_kargs_pad_reflect_1d args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.p0 =*/ p0, + /*.p1 =*/ p1 + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14]; - [encoder setBytes:&p0 length:sizeof(p0) atIndex:15]; - [encoder setBytes:&p1 length:sizeof(p1) atIndex:16]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; const int nth = MIN(1024, ne0); @@ -3490,12 +3529,15 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_arange args = { + /*.ne0 =*/ ne0, + /*.start =*/ start, + /*.step =*/ step + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; - [encoder setBytes:&start length:sizeof(start) atIndex:2]; - [encoder setBytes:&step length:sizeof(step) atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + [encoder setBytes:&args length:sizeof(args) atIndex:1]; const int nth = MIN(1024, ne0); @@ -3512,13 +3554,16 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_timestep_embedding args = { + /*.nb1 =*/ nb1, + /*.dim =*/ dim, + /*.max_period =*/ max_period + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; - [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; const int nth = MIN(1024, half); @@ -3551,12 +3596,15 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_argsort args = { + /*.ncols =*/ ne00, + /*.ncols_pad =*/ ne00_padded + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; @@ -3570,11 +3618,14 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_leaky_relu args = { + /*.slope =*/ slope + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; const int64_t n = ggml_nelements(dst); @@ -4150,21 +4201,24 @@ static void ggml_metal_encode_node( const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_pool_2d args_pool_2d = { + /* .k0 = */ k0, + /* .k1 = */ k1, + /* .s0 = */ s0, + /* .s1 = */ s1, + /* .p0 = */ p0, + /* .p1 = */ p1, + /* .IH = */ IH, + /* .IW = */ IW, + /* .OH = */ OH, + /* .OW = */ OW, + /* .parallel_elements = */ parallel_elements + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2]; - [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3]; - [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4]; - [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5]; - [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6]; - [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7]; - [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8]; - [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9]; - [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10]; - [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11]; - [encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; } break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index d092a1690..ad9d42a3e 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3,8 +3,7 @@ #if defined(GGML_METAL_EMBED_LIBRARY) __embed_ggml-common.h__ #else -// TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift -#include "../ggml-common.h" +#include "ggml-common.h" #endif #include "ggml-metal-impl.h" @@ -948,45 +947,22 @@ kernel void kernel_cos( kernel void kernel_sum_rows( device const float * src0, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, + constant ggml_metal_kargs_sum_rows & args, uint3 tpig[[thread_position_in_grid]]) { int64_t i3 = tpig.z; int64_t i2 = tpig.y; int64_t i1 = tpig.x; - if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { return; } - device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); float row_sum = 0; - for (int64_t i0 = 0; i0 < ne00; i0++) { + for (int64_t i0 = 0; i0 < args.ne00; i0++) { row_sum += src_row[i0]; } @@ -998,36 +974,29 @@ kernel void kernel_soft_max( device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, + constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + const int64_t i03 = (tgpig) / (args.ne02*args.ne01); + const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; + const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); - device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; - device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); float slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const int64_t h = i02; - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exp); } @@ -1035,8 +1004,8 @@ kernel void kernel_soft_max( // parallel max float lmax = -INFINITY; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block @@ -1060,8 +1029,8 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; } @@ -1091,7 +1060,7 @@ kernel void kernel_soft_max( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { pdst[i00] *= inv_sum; } } @@ -1101,35 +1070,28 @@ kernel void kernel_soft_max_4( device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, + constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + const int64_t i03 = (tgpig) / (args.ne02*args.ne01); + const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; + const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); - device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; - device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; float slope = 1.0f; - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const int64_t h = i02; - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exp); } @@ -1137,8 +1099,8 @@ kernel void kernel_soft_max_4( // parallel max float4 lmax4 = -INFINITY; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -1163,8 +1125,8 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -1196,7 +1158,7 @@ kernel void kernel_soft_max_4( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { pdst4[i00] *= inv_sum; } } @@ -1212,27 +1174,23 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne kernel void kernel_diag_mask_inf( device const float * src0, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, + constant ggml_metal_kargs_diag_mask_inf & args, uint3 tpig[[thread_position_in_grid]]) { const int64_t i02 = tpig[2]; const int64_t i01 = tpig[1]; const int64_t i00 = tpig[0]; - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + if (i00 > args.n_past + i01) { + dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY; } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00]; } } kernel void kernel_diag_mask_inf_8( device const float4 * src0, device float4 * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, + constant ggml_metal_kargs_diag_mask_inf & args, uint3 tpig[[thread_position_in_grid]]) { const int64_t i = 2*tpig[0]; @@ -1240,42 +1198,26 @@ kernel void kernel_diag_mask_inf_8( dst[i+0] = src0[i+0]; dst[i+1] = src0[i+1]; int64_t i4 = 4*i; - const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - const int64_t i01 = i4/(ne00); i4 -= i01*ne00; + const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01; + const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00; const int64_t i00 = i4; for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { + if (i00 + 4 + k <= args.n_past + i01) { break; } dst[i+1][k] = -INFINITY; - if (i00 + k > n_past + i01) { + if (i00 + k > args.n_past + i01) { dst[i][k] = -INFINITY; } } } // ref: ggml.c:ggml_compute_forward_ssm_conv_f32 -// TODO: optimize kernel void kernel_ssm_conv_f32( device const void * src0, device const void * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, + constant ggml_metal_kargs_ssm_conv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -1283,15 +1225,15 @@ kernel void kernel_ssm_conv_f32( const int64_t i2 = tgpig.y; const int64_t i3 = tgpig.z; - const int64_t nc = ne10; - //const int64_t ncs = ne00; - //const int64_t nr = ne01; - //const int64_t n_t = ne1; - //const int64_t n_s = ne2; + const int64_t nc = args.ne10; + //const int64_t ncs = args.ne00; + //const int64_t nr = args.ne01; + //const int64_t n_t = args.ne1; + //const int64_t n_s = args.ne2; - device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02); - device const float * c = (device const float *) ((device const char *) src1 + ir*nb11); - device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2); + device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); float sumf = 0.0f; @@ -1303,7 +1245,6 @@ kernel void kernel_ssm_conv_f32( } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32 -// TODO: optimize kernel void kernel_ssm_scan_f32( device const void * src0, device const void * src1, @@ -1312,48 +1253,27 @@ kernel void kernel_ssm_scan_f32( device const void * src4, device const void * src5, device float * dst, - constant int64_t & d_state, - constant int64_t & d_inner, - constant int64_t & n_seq_tokens, - constant int64_t & n_seqs, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb20, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb30, - constant uint64_t & nb31, - constant uint64_t & nb40, - constant uint64_t & nb41, - constant uint64_t & nb42, - constant uint64_t & nb50, - constant uint64_t & nb51, - constant uint64_t & nb52, + constant ggml_metal_kargs_ssm_scan & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { const int64_t ir = tgpig.x; const int64_t i3 = tgpig.y; - const int64_t nc = d_state; - //const int64_t nr = d_inner; - const int64_t n_t = n_seq_tokens; - //const int64_t n_s = n_seqs; + const int64_t nc = args.d_state; + // const int64_t nr = args.d_inner; + const int64_t n_t = args.n_seq_tokens; + // const int64_t n_s = args.n_seqs; for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); - device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); - device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02); + device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); + device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); + device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42); + device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52); + device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides + device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13); if (i2 > 0) { s0 = s; @@ -1546,22 +1466,15 @@ kernel void kernel_rms_norm( kernel void kernel_group_norm( device const float * src0, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int32_t & n_groups, - constant float & eps, + constant ggml_metal_kargs_group_norm & args, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - const int64_t ne = ne00*ne01*ne02; - const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); + const int64_t ne = args.ne00*args.ne01*args.ne02; + const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups); int start = tgpig * gs; int end = start + gs; @@ -1625,7 +1538,7 @@ kernel void kernel_group_norm( } const float variance = tmp / gs; - const float scale = 1.0f/sqrt(variance + eps); + const float scale = 1.0f/sqrt(variance + args.eps); for (int j = start; j < end; j += ntg) { dst[j] *= scale; } @@ -2589,17 +2502,7 @@ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_ typedef void (im2col_t)( device const float * x, device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, + constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2609,17 +2512,7 @@ template kernel void kernel_im2col( device const float * x, device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, + constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2640,17 +2533,17 @@ kernel void kernel_im2col( const int64_t ioh = tgpig[1]; const int64_t iow = tgpig[2]; - const int64_t iiw = iow*s0 + ikw*d0 - p0; - const int64_t iih = ioh*s1 + ikh*d1 - p1; + const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; + const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; - const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw); + const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); device T * pdst = (device T *) (dst); - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { pdst[offset_dst] = 0.0f; } else { - const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw; + const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; pdst[offset_dst] = x[offset_src]; } } @@ -2661,20 +2554,7 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; typedef void (im2col_ext_t)( device const float * x, device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, - constant int32_t & N, - constant int32_t & KH, - constant int32_t & KW, + constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2684,53 +2564,40 @@ template kernel void kernel_im2col_ext( device const float * x, device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, - constant int32_t & N, - constant int32_t & KH, - constant int32_t & KW, + constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] - const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2] + const int64_t KHW = (int64_t)args.KHW; - const int64_t d = tgpig[0] / CHW; - const int64_t chw = tgpig[0] % CHW; + const int64_t d = tgpig[0] / args.CHW; + const int64_t chw = tgpig[0] % args.CHW; const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) const int64_t HW = tgpig[0] % KHW; const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; - if (tpitg_0 >= N) { + if (tpitg_0 >= args.N) { return; } - const int64_t tpitg_1 = HW / KW; - const int64_t tpitg_2 = HW % KW; + const int64_t tpitg_1 = HW / args.KW; + const int64_t tpitg_2 = HW % args.KW; - const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0; - const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1; + const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; + const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; const int64_t offset_dst = - (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + - (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2); + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + + (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); device T * pdst = (device T *) (dst); - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { pdst[offset_dst] = 0.0f; } else { - const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1; - pdst[offset_dst] = x[offset_src + iih * IW + iiw]; + const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; + pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; } } @@ -2741,12 +2608,7 @@ typedef void (conv_transpose_1d_t)( device const float * src0, device const float * src1, device char * dst, - constant int32_t & IC, - constant int32_t & IL, - constant int32_t & K, - constant int32_t & s0, - constant uint64_t & nb0, - constant uint64_t & nb1, + constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); @@ -2755,29 +2617,24 @@ kernel void kernel_conv_transpose_1d( device const T * src0, device const float * src1, device char * dst, - constant int32_t & IC, - constant int32_t & IL, - constant int32_t & K, - constant int32_t & s0, - constant uint64_t & nb0, - constant uint64_t & nb1, + constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]) { float v = 0.0f; - for (int64_t c = 0; c < IC; c++) { - const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1]; - const int32_t input_offset = c * IL; + for (int64_t c = 0; c < args.IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1]; + const int32_t input_offset = c * args.IL; - for (int64_t i = 0; i < IL; i++) { - if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) { - v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i]; + for (int64_t i = 0; i < args.IL; i++) { + if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) { + v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i]; } } } - device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1); + device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1); dst_ptr[0] = v; } @@ -2787,12 +2644,7 @@ kernel void kernel_conv_transpose_1d( device const float * src0, device const float * src1, device char * dst, - constant int32_t & IC, - constant int32_t & IL, - constant int32_t & K, - constant int32_t & s0, - constant uint64_t & nb0, - constant uint64_t & nb1, + constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); @@ -2801,38 +2653,14 @@ kernel void kernel_conv_transpose_1d( device const half * src0, device const float * src1, device char * dst, - constant int32_t & IC, - constant int32_t & IL, - constant int32_t & K, - constant int32_t & s0, - constant uint64_t & nb0, - constant uint64_t & nb1, + constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); kernel void kernel_upscale_f32( device const char * src0, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & sf0, - constant float & sf1, - constant float & sf2, - constant float & sf3, + constant ggml_metal_kargs_upscale & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -2841,15 +2669,15 @@ kernel void kernel_upscale_f32( const int64_t i2 = tgpig.y; const int64_t i1 = tgpig.x; - const int64_t i03 = i3/sf3; - const int64_t i02 = i2/sf2; - const int64_t i01 = i1/sf1; + const int64_t i03 = i3/args.sf3; + const int64_t i02 = i2/args.sf2; + const int64_t i01 = i1/args.sf1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int64_t i00 = i0/sf0; + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int64_t i00 = i0/args.sf0; - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); dst_ptr[0] = src0_ptr[0]; } @@ -2858,22 +2686,7 @@ kernel void kernel_upscale_f32( kernel void kernel_pad_f32( device const char * src0, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, + constant ggml_metal_kargs_pad & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -2886,12 +2699,12 @@ kernel void kernel_pad_f32( const int64_t i02 = i2; const int64_t i01 = i1; - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); - if (i1 < ne01 && i2 < ne02 && i3 < ne03) { - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < ne00) { + if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.ne00) { dst_ptr[i0] = src0_ptr[i0]; } else { dst_ptr[i0] = 0.0f; @@ -2901,7 +2714,7 @@ kernel void kernel_pad_f32( return; } - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { dst_ptr[i0] = 0.0f; } } @@ -2909,21 +2722,7 @@ kernel void kernel_pad_f32( kernel void kernel_pad_reflect_1d_f32( device const char * src0, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant int64_t & ne0, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int32_t & p0, - constant int32_t & p1, + constant ggml_metal_kargs_pad_reflect_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2937,17 +2736,17 @@ kernel void kernel_pad_reflect_1d_f32( const int64_t i02 = i2; const int64_t i01 = i1; - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); - if (i1 < ne01 && i2 < ne02 && i3 < ne03) { - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < p0) { - dst_ptr[i0] = src0_ptr[p0 - i0]; - } else if (i0 < ne0 - p1) { - dst_ptr[i0] = src0_ptr[i0 - p0]; + if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.p0) { + dst_ptr[i0] = src0_ptr[args.p0 - i0]; + } else if (i0 < args.ne0 - args.p1) { + dst_ptr[i0] = src0_ptr[i0 - args.p0]; } else { - dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1]; + dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1]; } } } @@ -2955,44 +2754,40 @@ kernel void kernel_pad_reflect_1d_f32( kernel void kernel_arange_f32( device char * dst, - constant int64_t & ne0, - constant float & start, - constant float & step, + constant ggml_metal_kargs_arange & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { device float * dst_ptr = (device float *) dst; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = start + step * i0; + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + dst_ptr[i0] = args.start + args.step * i0; } } kernel void kernel_timestep_embedding_f32( device const char * src0, device char * dst, - constant uint64_t & nb1, - constant int & dim, - constant int & max_period, + constant ggml_metal_kargs_timestep_embedding & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { int i = tgpig.x; - device float * embed_data = (device float *)(dst + i*nb1); + device float * embed_data = (device float *)(dst + i*args.nb1); - int half_ = dim / 2; + int half_ = args.dim / 2; for (int j = tpitg.x; j < half_; j += ntg.x) { float timestep = ((device float *)src0)[i]; - float freq = (float)exp(-log((float)max_period) * j / half_); + float freq = (float)exp(-log((float)args.max_period) * j / half_); float arg = timestep * freq; embed_data[j ] = cos(arg); embed_data[j + half_] = sin(arg); } - if (dim % 2 != 0 && tpitg.x == 0) { - embed_data[dim] = 0.f; + if (args.dim % 2 != 0 && tpitg.x == 0) { + embed_data[args.dim] = 0.f; } } @@ -3000,8 +2795,7 @@ kernel void kernel_timestep_embedding_f32( typedef void (argsort_t)( device const float * x, device int32_t * dst, - constant int64_t & ncols, - constant int64_t & ncols_pad, + constant ggml_metal_kargs_argsort & args, threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]); @@ -3010,8 +2804,7 @@ template kernel void kernel_argsort_f32_i32( device const float * x, device int32_t * dst, - constant int64_t & ncols, - constant int64_t & ncols_pad, + constant ggml_metal_kargs_argsort & args, threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]) { @@ -3019,9 +2812,9 @@ kernel void kernel_argsort_f32_i32( int col = tpitg[0]; int row = tgpig[1]; - if (col >= ncols_pad) return; + if (col >= args.ncols_pad) return; - device const float * x_row = x + row * ncols; + device const float * x_row = x + row * args.ncols; threadgroup int32_t * dst_row = shared_values; // initialize indices @@ -3029,21 +2822,21 @@ kernel void kernel_argsort_f32_i32( threadgroup_barrier(mem_flags::mem_threadgroup); - for (int k = 2; k <= ncols_pad; k *= 2) { + for (int k = 2; k <= args.ncols_pad; k *= 2) { for (int j = k / 2; j > 0; j /= 2) { int ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (dst_row[col] >= ncols || - (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + if (dst_row[col] >= args.ncols || + (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]])) ) { SWAP(dst_row[col], dst_row[ixj]); } } else { - if (dst_row[ixj] >= ncols || - (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + if (dst_row[ixj] >= args.ncols || + (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]])) ) { @@ -3056,8 +2849,8 @@ kernel void kernel_argsort_f32_i32( } // copy the result to dst without the padding - if (col < ncols) { - dst[row * ncols + col] = dst_row[col]; + if (col < args.ncols) { + dst[row * args.ncols + col] = dst_row[col]; } } @@ -3067,9 +2860,9 @@ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_ar kernel void kernel_leaky_relu_f32( device const float * src0, device float * dst, - constant float & slope, + constant ggml_metal_kargs_leaky_relu & args, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; + dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope; } // ref: https://arxiv.org/pdf/2307.08691.pdf @@ -6010,28 +5803,21 @@ kernel void kernel_get_rows_q( device const void * src0, device const void * src1, device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, + constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; const int64_t i02 = i11; - for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { + for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) { float4x4 temp; - dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp; } } @@ -6040,27 +5826,20 @@ kernel void kernel_get_rows_f( device const void * src0, device const void * src1, device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, + constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; const int64_t i02 = i11; - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; + for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { + (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = + ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; } } @@ -6068,27 +5847,20 @@ kernel void kernel_get_rows_i32( device const void * src0, device const void * src1, device int32_t * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, + constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; const int64_t i02 = i11; - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; + for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { + (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = + ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; } } @@ -6690,98 +6462,78 @@ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel void kernel_pool_2d_max_f32( device const float * src0, device float * dst, - constant int32_t & k0, - constant int32_t & k1, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int64_t & IH, - constant int64_t & IW, - constant int64_t & OH, - constant int64_t & OW, - constant int64_t & parallel_elements, + constant ggml_metal_kargs_pool_2d & args, uint gid[[thread_position_in_grid]]) { - if (gid >= parallel_elements) { + if (gid >= args.parallel_elements) { return; } const int idx = gid; - const int I_HW = IH * IW; - const int O_HW = OH * OW; + const int I_HW = args.IH * args.IW; + const int O_HW = args.OH * args.OW; const int nc = idx / O_HW; - const int cur_oh = idx % O_HW / OW; - const int cur_ow = idx % O_HW % OW; + const int cur_oh = idx % O_HW / args.OW; + const int cur_ow = idx % O_HW % args.OW; device const float * i_ptr = src0 + nc * I_HW; device float * o_ptr = dst + nc * O_HW; - const int start_h = cur_oh * s1 - p1; + const int start_h = cur_oh * args.s1 - args.p1; const int bh = MAX(0, start_h); - const int eh = MIN(IH, start_h + k1); - const int start_w = cur_ow * s0 - p0; + const int eh = MIN(args.IH, start_h + args.k1); + const int start_w = cur_ow * args.s0 - args.p0; const int bw = MAX(0, start_w); - const int ew = MIN(IW, start_w + k0); + const int ew = MIN(args.IW, start_w + args.k0); float res = -INFINITY; for (int i = bh; i < eh; i += 1) { for (int j = bw; j < ew; j += 1) { - res = MAX(res, i_ptr[i * IW + j]); + res = MAX(res, i_ptr[i * args.IW + j]); } } - o_ptr[cur_oh * OW + cur_ow] = res; + o_ptr[cur_oh * args.OW + cur_ow] = res; } kernel void kernel_pool_2d_avg_f32( device const float * src0, device float * dst, - constant int32_t & k0, - constant int32_t & k1, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int64_t & IH, - constant int64_t & IW, - constant int64_t & OH, - constant int64_t & OW, - constant int64_t & parallel_elements, + constant ggml_metal_kargs_pool_2d & args, uint gid[[thread_position_in_grid]]) { - if (gid >= parallel_elements) { + if (gid >= args.parallel_elements) { return; } const int idx = gid; - const int I_HW = IH * IW; - const int O_HW = OH * OW; + const int I_HW = args.IH * args.IW; + const int O_HW = args.OH * args.OW; const int nc = idx / O_HW; - const int cur_oh = idx % O_HW / OW; - const int cur_ow = idx % O_HW % OW; + const int cur_oh = idx % O_HW / args.OW; + const int cur_ow = idx % O_HW % args.OW; device const float * i_ptr = src0 + nc * I_HW; device float * o_ptr = dst + nc * O_HW; - const int start_h = cur_oh * s1 - p1; + const int start_h = cur_oh * args.s1 - args.p1; const int bh = MAX(0, start_h); - const int eh = MIN(IH, start_h + k1); - const int start_w = cur_ow * s0 - p0; + const int eh = MIN(args.IH, start_h + args.k1); + const int start_w = cur_ow * args.s0 - args.p0; const int bw = MAX(0, start_w); - const int ew = MIN(IW, start_w + k0); + const int ew = MIN(args.IW, start_w + args.k0); // const float scale = 1. / ((eh - bh) * (ew - bw)); - const float scale = 1. / (k0 * k1); + const float scale = 1. / (args.k0 * args.k1); float res = 0; for (int i = bh; i < eh; i += 1) { for (int j = bw; j < ew; j += 1) { - float cur = i_ptr[i * IW + j]; + float cur = i_ptr[i * args.IW + j]; res += cur * scale; } } - o_ptr[cur_oh * OW + cur_ow] = res; + o_ptr[cur_oh * args.OW + cur_ow] = res; } diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index bc2ea06b5..b85a895c4 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -1007,17 +1007,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_ADD: case GGML_OP_SCALE: case GGML_OP_MUL: - return true; + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: - return ggml_is_contiguous(op->src[0]); + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; } case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SOFT_MAX: case GGML_OP_NORM: case GGML_OP_RMS_NORM: @@ -2573,26 +2574,33 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const memcpy(&eps, dst->op_params, sizeof(float)); const int ne00 = src0 ? src0->ne[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; - GGML_ASSERT(ggml_is_contiguous_1(src0)); + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; const int nth = MIN(64, ne00); cl_kernel kernel = backend_ctx->kernel_norm; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth, NULL)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL)); - const int64_t nrows = ggml_nrows(src0); - - size_t global_work_size[] = {(size_t)nrows*nth, 1, 1}; + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; #ifdef GGML_OPENCL_PROFILING @@ -2630,16 +2638,19 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c memcpy(&eps, dst->op_params, sizeof(float)); const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); const int nth = MIN(64, ne00); - const int64_t nrows = ggml_nrows(src0); - - size_t global_work_size[] = {(size_t)nrows*nth, 1, 1}; + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; cl_kernel kernel = backend_ctx->kernel_rms_norm; @@ -2654,15 +2665,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c sizeof(local_work_size), local_work_size, sizeof(size_t), &sgs, NULL)); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); // This is local memory - the size depends on subgroup size. - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth/sgs, NULL)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs, NULL)); #ifdef GGML_OPENCL_PROFILING cl_event evt; diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl index 8882a8c9c..1d43642a9 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl @@ -506,14 +506,23 @@ kernel void kernel_norm( global float * dst, ulong offsetd, int ne00, + int ne01, + int ne02, + int ne03, ulong nb01, + ulong nb02, + ulong nb03, float eps, local float * sum ) { src0 = (global void*)((global char*)src0 + offset0); dst = (global void*)((global char*)dst + offsetd); - global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01); + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); // MEAN // parallel sum @@ -533,7 +542,7 @@ kernel void kernel_norm( // recenter and VARIANCE barrier(CLK_LOCAL_MEM_FENCE); - global float * y = dst + get_group_id(0)*ne00; + global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; sum[get_local_id(0)] = 0.0f; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { y[i00] = x[i00] - mean; @@ -566,14 +575,23 @@ kernel void kernel_rms_norm( global float * dst, ulong offsetd, int ne00, + int ne01, + int ne02, + int ne03, ulong nb01, + ulong nb02, + ulong nb03, float eps, local float * sum // Note, the size depends on number of subgroups ) { src0 = (global void*)((global char*)src0 + offset0); dst = (global float*)((global char*)dst + offsetd); - global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01); + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); global float * x_scalar = (global float *) x; float4 sumf = 0; float all_sum = 0; @@ -607,7 +625,7 @@ kernel void kernel_rms_norm( const float mean = sum[0]; const float scale = 1.0f/sqrt(mean + eps); - global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00); + global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); global float * y_scalar = (global float *) y; for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { y[i00] = x[i00] * scale; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a8771cd5b..d899f30bf 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2345,6 +2345,7 @@ struct ggml_tensor * ggml_concat( struct ggml_tensor * b, int dim) { GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); + GGML_ASSERT(a->type == b->type); int64_t ne[GGML_MAX_DIMS]; for (int d = 0; d < GGML_MAX_DIMS; ++d) {