Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	AUTHORS
#	README.md
#	ci/run.sh
#	docs/build.md
#	ggml/src/CMakeLists.txt
#	ggml/src/ggml-metal/CMakeLists.txt
#	scripts/sync-ggml.last
This commit is contained in:
Concedo 2025-03-10 10:32:41 +08:00
commit 6b7c3ae1d3
13 changed files with 1003 additions and 698 deletions

View file

@ -2572,5 +2572,43 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} }
).set_examples({LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--fim-qwen-7b-spec"},
string_format("use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet)"),
[](common_params & params) {
params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
params.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.n_gpu_layers = 99;
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--fim-qwen-14b-spec"},
string_format("use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet)"),
[](common_params & params) {
params.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
params.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.n_gpu_layers = 99;
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
return ctx_arg; return ctx_arg;
} }

View file

@ -1378,13 +1378,27 @@ struct ArgumentsExpression {
} }
}; };
static std::string strip(const std::string & s) { static std::string strip(const std::string & s, const std::string & chars = "", bool left = true, bool right = true) {
auto start = s.find_first_not_of(" \t\n\r"); auto charset = chars.empty() ? " \t\n\r" : chars;
auto start = left ? s.find_first_not_of(charset) : 0;
if (start == std::string::npos) return ""; if (start == std::string::npos) return "";
auto end = s.find_last_not_of(" \t\n\r"); auto end = right ? s.find_last_not_of(charset) : s.size() - 1;
return s.substr(start, end - start + 1); return s.substr(start, end - start + 1);
} }
static std::vector<std::string> split(const std::string & s, const std::string & sep) {
std::vector<std::string> result;
size_t start = 0;
size_t end = s.find(sep);
while (end != std::string::npos) {
result.push_back(s.substr(start, end - start));
start = end + sep.length();
end = s.find(sep, start);
}
result.push_back(s.substr(start));
return result;
}
static std::string capitalize(const std::string & s) { static std::string capitalize(const std::string & s) {
if (s.empty()) return s; if (s.empty()) return s;
auto result = s; auto result = s;
@ -1467,8 +1481,26 @@ public:
} else if (obj.is_string()) { } else if (obj.is_string()) {
auto str = obj.get<std::string>(); auto str = obj.get<std::string>();
if (method->get_name() == "strip") { if (method->get_name() == "strip") {
vargs.expectArgs("strip method", {0, 0}, {0, 0}); vargs.expectArgs("strip method", {0, 1}, {0, 0});
return Value(strip(str)); auto chars = vargs.args.empty() ? "" : vargs.args[0].get<std::string>();
return Value(strip(str, chars));
} else if (method->get_name() == "lstrip") {
vargs.expectArgs("lstrip method", {0, 1}, {0, 0});
auto chars = vargs.args.empty() ? "" : vargs.args[0].get<std::string>();
return Value(strip(str, chars, /* left= */ true, /* right= */ false));
} else if (method->get_name() == "rstrip") {
vargs.expectArgs("rstrip method", {0, 1}, {0, 0});
auto chars = vargs.args.empty() ? "" : vargs.args[0].get<std::string>();
return Value(strip(str, chars, /* left= */ false, /* right= */ true));
} else if (method->get_name() == "split") {
vargs.expectArgs("split method", {1, 1}, {0, 0});
auto sep = vargs.args[0].get<std::string>();
auto parts = split(str, sep);
Value result = Value::array();
for (const auto& part : parts) {
result.push_back(Value(part));
}
return result;
} else if (method->get_name() == "capitalize") { } else if (method->get_name() == "capitalize") {
vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
return Value(capitalize(str)); return Value(capitalize(str));

View file

@ -1312,7 +1312,7 @@ struct server_slot {
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
} }
bool can_batch_with(server_slot & other_slot) { bool can_batch_with(server_slot & other_slot) const {
return is_non_causal() == other_slot.is_non_causal() return is_non_causal() == other_slot.is_non_causal()
&& are_lora_equal(lora, other_slot.lora); && are_lora_equal(lora, other_slot.lora);
} }
@ -1900,6 +1900,7 @@ struct server_context {
try { try {
common_chat_format_example(chat_templates.get(), params.use_jinja); common_chat_format_example(chat_templates.get(), params.use_jinja);
} catch (const std::exception & e) { } catch (const std::exception & e) {
SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what());
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_init(model, "chatml"); chat_templates = common_chat_templates_init(model, "chatml");
} }
@ -2156,14 +2157,6 @@ struct server_context {
} }
if (slot.has_new_line) { if (slot.has_new_line) {
// if we have already seen a new line, we stop after a certain time limit
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
}
// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
if (slot.params.n_indent > 0) { if (slot.params.n_indent > 0) {
// check the current indentation // check the current indentation
@ -2202,6 +2195,14 @@ struct server_context {
// check if there is a new line in the generated text // check if there is a new line in the generated text
if (result.text_to_send.find('\n') != std::string::npos) { if (result.text_to_send.find('\n') != std::string::npos) {
slot.has_new_line = true; slot.has_new_line = true;
// if we have seen a new line, we stop after a certain time limit, but only upon another new line
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
}
} }
// if context shift is disabled, we stop when it reaches the context limit // if context shift is disabled, we stop when it reaches the context limit

View file

@ -76,7 +76,14 @@ namespace fs = std::filesystem;
static std::string path_str(const fs::path & path) { static std::string path_str(const fs::path & path) {
std::string u8path; std::string u8path;
try { try {
#if defined(__cpp_lib_char8_t)
// C++20 and later: u8string() returns std::u8string
std::u8string u8str = path.u8string();
u8path = std::string(reinterpret_cast<const char*>(u8str.c_str()));
#else
// C++17: u8string() returns std::string
u8path = path.u8string(); u8path = path.u8string();
#endif
} catch (...) { } catch (...) {
} }
return u8path; return u8path;

View file

@ -11719,9 +11719,12 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
#elif defined __AVX2__ #elif defined __AVX2__
const __m256i mask = _mm256_set1_epi16(2 * 0x7); const __m256i mask = _mm256_set1_epi16(0x7);
const __m256i mone = _mm256_set1_epi16(1); const __m256i mone = _mm256_set1_epi16(1);
const __m256i mone8 = _mm256_set1_epi8(1); const __m256i mone8 = _mm256_set1_epi8(1);
const __m256i mtwo8 = _mm256_set1_epi8(2);
// VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half.
const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0);
__m256 accum1 = _mm256_setzero_ps(); __m256 accum1 = _mm256_setzero_ps();
__m256 accum2 = _mm256_setzero_ps(); __m256 accum2 = _mm256_setzero_ps();
@ -11733,6 +11736,14 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
const uint16_t * sc = (const uint16_t *)x[i].scales; const uint16_t * sc = (const uint16_t *)x[i].scales;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
// Extract 3-bit scales (16 values)
__m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc);
scales = _mm256_srlv_epi64(scales, scales_shift);
scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone);
// Indices to repeat each scale 8 times.
__m256i scales_idx1 = _mm256_set1_epi16(0x0100);
__m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8));
__m256i sumi1 = _mm256_setzero_si256(); __m256i sumi1 = _mm256_setzero_si256();
__m256i sumi2 = _mm256_setzero_si256(); __m256i sumi2 = _mm256_setzero_si256();
@ -11778,11 +11789,12 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1)); const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1));
const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2)); const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2));
__m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 2), _mm_set1_epi16(sc[ib/2] << 1)); __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1);
__m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 8), _mm_set1_epi16(sc[ib/2] >> 5)); __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2);
scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8);
scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8);
scale1 = _mm256_add_epi16(_mm256_and_si256(scale1, mask), mone);
scale2 = _mm256_add_epi16(_mm256_and_si256(scale2, mask), mone);
const __m256i p1 = _mm256_madd_epi16(dot1, scale1); const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
const __m256i p2 = _mm256_madd_epi16(dot2, scale2); const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
const __m256i p3 = _mm256_madd_epi16(dot3, scale1); const __m256i p3 = _mm256_madd_epi16(dot3, scale1);

View file

@ -6678,6 +6678,135 @@ static void ggml_compute_forward_repeat_back(
// ggml_compute_forward_concat // ggml_compute_forward_concat
static void ggml_compute_forward_concat_any(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const size_t len = ggml_type_size(src0->type);
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_BINARY_OP_LOCALS
const int32_t dim = ggml_get_op_params_i32(dst, 0);
GGML_ASSERT(dim >= 0 && dim < 4);
int64_t o[4] = {0, 0, 0, 0};
o[dim] = src0->ne[dim];
const char * x;
// TODO: smarter multi-theading
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = ith; i2 < ne2; i2 += nth) {
for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
} else {
x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
}
char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
memcpy(y, x, len);
}
}
}
}
}
static void ggml_compute_forward_concat_i8(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_BINARY_OP_LOCALS
const int32_t dim = ggml_get_op_params_i32(dst, 0);
GGML_ASSERT(dim >= 0 && dim < 4);
int64_t o[4] = {0, 0, 0, 0};
o[dim] = src0->ne[dim];
const int8_t * x;
// TODO: smarter multi-theading
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = ith; i2 < ne2; i2 += nth) {
for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
} else {
x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
}
int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
*y = *x;
}
}
}
}
}
static void ggml_compute_forward_concat_f16(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_BINARY_OP_LOCALS
const int32_t dim = ggml_get_op_params_i32(dst, 0);
GGML_ASSERT(dim >= 0 && dim < 4);
int64_t o[4] = {0, 0, 0, 0};
o[dim] = src0->ne[dim];
const ggml_fp16_t * x;
// TODO: smarter multi-theading
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = ith; i2 < ne2; i2 += nth) {
for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
} else {
x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
}
ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
*y = *x;
}
}
}
}
}
static void ggml_compute_forward_concat_f32( static void ggml_compute_forward_concat_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
@ -6685,7 +6814,7 @@ static void ggml_compute_forward_concat_f32(
const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
@ -6728,6 +6857,16 @@ static void ggml_compute_forward_concat(
const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_I16:
{
ggml_compute_forward_concat_f16(params, dst);
} break;
case GGML_TYPE_I8:
{
ggml_compute_forward_concat_i8(params, dst);
} break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
case GGML_TYPE_I32: case GGML_TYPE_I32:
{ {
@ -6735,7 +6874,7 @@ static void ggml_compute_forward_concat(
} break; } break;
default: default:
{ {
GGML_ABORT("fatal error"); ggml_compute_forward_concat_any(params, dst);
} }
} }
} }

View file

@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
} }
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta: // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
if (cc == GGML_CUDA_CC_VOLTA) { if (fp16_mma_available(cc) && !new_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return; return;
} }

View file

@ -285,4 +285,239 @@ typedef struct {
float eps; float eps;
} ggml_metal_kargs_rms_norm; } ggml_metal_kargs_rms_norm;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int32_t n_groups;
float eps;
} ggml_metal_kargs_group_norm;
typedef struct {
int32_t IC;
int32_t IL;
int32_t K;
int32_t s0;
uint64_t nb0;
uint64_t nb1;
} ggml_metal_kargs_conv_transpose_1d;
typedef struct {
uint64_t ofs0;
uint64_t ofs1;
int32_t IW;
int32_t IH;
int32_t CHW;
int32_t s0;
int32_t s1;
int32_t p0;
int32_t p1;
int32_t d0;
int32_t d1;
int32_t N;
int32_t KH;
int32_t KW;
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne10;
int64_t ne11;
int64_t ne12;
int64_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_sum_rows;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
float scale;
float max_bias;
float m0;
float m1;
uint32_t n_head_log2;
} ggml_metal_kargs_soft_max;
typedef struct {
int64_t ne00;
int64_t ne01;
int n_past;
} ggml_metal_kargs_diag_mask_inf;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
int64_t ne11;
uint64_t nb10;
uint64_t nb11;
int64_t ne0;
int64_t ne1;
int64_t ne2;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_ssm_conv;
typedef struct {
int64_t d_state;
int64_t d_inner;
int64_t n_seq_tokens;
int64_t n_seqs;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb30;
uint64_t nb31;
uint64_t nb40;
uint64_t nb41;
uint64_t nb42;
uint64_t nb50;
uint64_t nb51;
uint64_t nb52;
} ggml_metal_kargs_ssm_scan;
typedef struct {
int64_t ne00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
uint64_t nb10;
uint64_t nb11;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_get_rows;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float sf0;
float sf1;
float sf2;
float sf3;
} ggml_metal_kargs_upscale;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_pad;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t p0;
int32_t p1;
} ggml_metal_kargs_pad_reflect_1d;
typedef struct {
uint64_t nb1;
int dim;
int max_period;
} ggml_metal_kargs_timestep_embedding;
typedef struct {
float slope;
} ggml_metal_kargs_leaky_relu;
typedef struct {
int64_t ncols;
int64_t ncols_pad;
} ggml_metal_kargs_argsort;
typedef struct {
int64_t ne0;
float start;
float step;
} ggml_metal_kargs_arange;
typedef struct {
int32_t k0;
int32_t k1;
int32_t s0;
int32_t s1;
int32_t p0;
int32_t p1;
int64_t IH;
int64_t IW;
int64_t OH;
int64_t OW;
int64_t parallel_elements;
} ggml_metal_kargs_pool_2d;
#endif // GGML_METAL_IMPL #endif // GGML_METAL_IMPL

View file

@ -1945,34 +1945,38 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
// TODO: add ggml_metal_kargs struct
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
@ -2021,8 +2025,17 @@ static void ggml_metal_encode_node(
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_soft_max args = {
// TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6) /*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
/*.m1 =*/ m1,
/*.n_head_log2 =*/ n_head_log2,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) { if (id_src1) {
@ -2031,14 +2044,7 @@ static void ggml_metal_encode_node(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
} }
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
@ -2056,13 +2062,16 @@ static void ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
} }
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_diag_mask_inf args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.n_past =*/ n_past,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
if (ne00%8 == 0) { if (ne00%8 == 0) {
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@ -2081,27 +2090,30 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_ssm_conv args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
@ -2152,7 +2164,31 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_ssm_scan args = {
/*.d_state =*/ d_state,
/*.d_inner =*/ d_inner,
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb20 =*/ nb20,
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb30 =*/ nb30,
/*.nb31 =*/ nb31,
/*.nb40 =*/ nb40,
/*.nb41 =*/ nb41,
/*.nb42 =*/ nb42,
/*.nb50 =*/ nb50,
/*.nb51 =*/ nb51,
/*.nb52 =*/ nb52,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@ -2161,30 +2197,7 @@ static void ggml_metal_encode_node(
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_dst offset:offs_dst atIndex:6]; [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
[encoder setBytes:&args length:sizeof(args) atIndex:7];
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
@ -3041,19 +3054,22 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("not implemented"); default: GGML_ABORT("not implemented");
} }
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_get_rows args = {
/*.ne00 =*/ ne00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.ne10 =*/ ne10,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break; } break;
@ -3110,18 +3126,21 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_group_norm args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.n_groups =*/ n_groups,
/*.eps =*/ eps,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@ -3279,8 +3298,8 @@ static void ggml_metal_encode_node(
const int32_t CHW = IC * KH * KW; const int32_t CHW = IC * KH * KW;
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
@ -3302,27 +3321,30 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("fatal error"); default: GGML_ABORT("fatal error");
}; };
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_im2col args = {
/*.ofs0 =*/ ofs0,
/*.ofs1 =*/ ofs1,
/*.IW =*/ IW,
/*.IH =*/ IH,
/*.CHW =*/ CHW,
/*.s0 =*/ s0,
/*.s1 =*/ s1,
/*.p0 =*/ p0,
/*.p1 =*/ p1,
/*.d0 =*/ d0,
/*.d1 =*/ d1,
/*.N =*/ N,
/*.KH =*/ KH,
/*.KW =*/ KW,
/*.KHW =*/ KH * KW,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
if (is_gt_mttpt) { if (is_gt_mttpt) {
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N); const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
@ -3362,16 +3384,20 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("fatal error"); default: GGML_ABORT("fatal error");
}; };
ggml_metal_kargs_conv_transpose_1d args = {
/*.IC =*/ IC,
/*.IL =*/ IL,
/*.K =*/ K,
/*.s0 =*/ s0,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&IC length:sizeof( int32_t) atIndex:3]; [encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
[encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
[encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
@ -3386,30 +3412,33 @@ static void ggml_metal_encode_node(
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_upscale args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.sf0 =*/ sf0,
/*.sf1 =*/ sf1,
/*.sf2 =*/ sf2,
/*.sf3 =*/ sf3
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
[encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
[encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
[encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
[encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
@ -3421,26 +3450,29 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_pad args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
const int nth = MIN(1024, ne0); const int nth = MIN(1024, ne0);
@ -3455,24 +3487,31 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
ggml_metal_kargs_pad_reflect_1d args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.p0 =*/ p0,
/*.p1 =*/ p1
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
[encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
[encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
const int nth = MIN(1024, ne0); const int nth = MIN(1024, ne0);
@ -3490,12 +3529,15 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_arange args = {
/*.ne0 =*/ ne0,
/*.start =*/ start,
/*.step =*/ step
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_dst offset:offs_dst atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; [encoder setBytes:&args length:sizeof(args) atIndex:1];
[encoder setBytes:&start length:sizeof(start) atIndex:2];
[encoder setBytes:&step length:sizeof(step) atIndex:3];
const int nth = MIN(1024, ne0); const int nth = MIN(1024, ne0);
@ -3512,13 +3554,16 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_timestep_embedding args = {
/*.nb1 =*/ nb1,
/*.dim =*/ dim,
/*.max_period =*/ max_period
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
const int nth = MIN(1024, half); const int nth = MIN(1024, half);
@ -3551,12 +3596,15 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("fatal error"); default: GGML_ABORT("fatal error");
}; };
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_argsort args = {
/*.ncols =*/ ne00,
/*.ncols_pad =*/ ne00_padded
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
[encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
@ -3570,11 +3618,14 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_leaky_relu args = {
/*.slope =*/ slope
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&slope length:sizeof(slope) atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:2];
const int64_t n = ggml_nelements(dst); const int64_t n = ggml_nelements(dst);
@ -4150,21 +4201,24 @@ static void ggml_metal_encode_node(
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
// TODO: add ggml_metal_kargs struct ggml_metal_kargs_pool_2d args_pool_2d = {
/* .k0 = */ k0,
/* .k1 = */ k1,
/* .s0 = */ s0,
/* .s1 = */ s1,
/* .p0 = */ p0,
/* .p1 = */ p1,
/* .IH = */ IH,
/* .IW = */ IW,
/* .OH = */ OH,
/* .OW = */ OW,
/* .parallel_elements = */ parallel_elements
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2]; [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2];
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
[encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} break; } break;

File diff suppressed because it is too large Load diff

View file

@ -1007,17 +1007,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_OP_ADD: case GGML_OP_ADD:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_MUL: case GGML_OP_MUL:
return true; return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UNARY: case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) { switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_RELU:
return ggml_is_contiguous(op->src[0]); return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default: default:
return false; return false;
} }
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_NORM: case GGML_OP_NORM:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
@ -2573,26 +2574,33 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
const int ne00 = src0 ? src0->ne[0] : 0; const int ne00 = src0 ? src0->ne[0] : 0;
const cl_ulong nb01 = src0 ? src0->nb[1] : 0; const int ne01 = src0 ? src0->ne[1] : 0;
const int ne02 = src0 ? src0->ne[2] : 0;
const int ne03 = src0 ? src0->ne[3] : 0;
GGML_ASSERT(ggml_is_contiguous_1(src0)); const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
const int nth = MIN(64, ne00); const int nth = MIN(64, ne00);
cl_kernel kernel = backend_ctx->kernel_norm; cl_kernel kernel = backend_ctx->kernel_norm;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth, NULL)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL));
const int64_t nrows = ggml_nrows(src0); size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
size_t local_work_size[] = {(size_t)nth, 1, 1}; size_t local_work_size[] = {(size_t)nth, 1, 1};
#ifdef GGML_OPENCL_PROFILING #ifdef GGML_OPENCL_PROFILING
@ -2630,16 +2638,19 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
memcpy(&eps, dst->op_params, sizeof(float)); memcpy(&eps, dst->op_params, sizeof(float));
const int ne00 = src0 ? src0->ne[0] : 0; const int ne00 = src0 ? src0->ne[0] : 0;
const int ne01 = src0 ? src0->ne[1] : 0;
const int ne02 = src0 ? src0->ne[2] : 0;
const int ne03 = src0 ? src0->ne[3] : 0;
const cl_ulong nb01 = src0 ? src0->nb[1] : 0; const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(src0));
const int nth = MIN(64, ne00); const int nth = MIN(64, ne00);
const int64_t nrows = ggml_nrows(src0); size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
size_t local_work_size[] = {(size_t)nth, 1, 1}; size_t local_work_size[] = {(size_t)nth, 1, 1};
cl_kernel kernel = backend_ctx->kernel_rms_norm; cl_kernel kernel = backend_ctx->kernel_rms_norm;
@ -2654,15 +2665,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
sizeof(local_work_size), local_work_size, sizeof(local_work_size), local_work_size,
sizeof(size_t), &sgs, NULL)); sizeof(size_t), &sgs, NULL));
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
// This is local memory - the size depends on subgroup size. // This is local memory - the size depends on subgroup size.
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth/sgs, NULL)); CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs, NULL));
#ifdef GGML_OPENCL_PROFILING #ifdef GGML_OPENCL_PROFILING
cl_event evt; cl_event evt;

View file

@ -506,14 +506,23 @@ kernel void kernel_norm(
global float * dst, global float * dst,
ulong offsetd, ulong offsetd,
int ne00, int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01, ulong nb01,
ulong nb02,
ulong nb03,
float eps, float eps,
local float * sum local float * sum
) { ) {
src0 = (global void*)((global char*)src0 + offset0); src0 = (global void*)((global char*)src0 + offset0);
dst = (global void*)((global char*)dst + offsetd); dst = (global void*)((global char*)dst + offsetd);
global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01); int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
// MEAN // MEAN
// parallel sum // parallel sum
@ -533,7 +542,7 @@ kernel void kernel_norm(
// recenter and VARIANCE // recenter and VARIANCE
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
global float * y = dst + get_group_id(0)*ne00; global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
sum[get_local_id(0)] = 0.0f; sum[get_local_id(0)] = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
y[i00] = x[i00] - mean; y[i00] = x[i00] - mean;
@ -566,14 +575,23 @@ kernel void kernel_rms_norm(
global float * dst, global float * dst,
ulong offsetd, ulong offsetd,
int ne00, int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01, ulong nb01,
ulong nb02,
ulong nb03,
float eps, float eps,
local float * sum // Note, the size depends on number of subgroups local float * sum // Note, the size depends on number of subgroups
) { ) {
src0 = (global void*)((global char*)src0 + offset0); src0 = (global void*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd); dst = (global float*)((global char*)dst + offsetd);
global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01); int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
global float * x_scalar = (global float *) x; global float * x_scalar = (global float *) x;
float4 sumf = 0; float4 sumf = 0;
float all_sum = 0; float all_sum = 0;
@ -607,7 +625,7 @@ kernel void kernel_rms_norm(
const float mean = sum[0]; const float mean = sum[0];
const float scale = 1.0f/sqrt(mean + eps); const float scale = 1.0f/sqrt(mean + eps);
global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00); global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
global float * y_scalar = (global float *) y; global float * y_scalar = (global float *) y;
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
y[i00] = x[i00] * scale; y[i00] = x[i00] * scale;

View file

@ -2345,6 +2345,7 @@ struct ggml_tensor * ggml_concat(
struct ggml_tensor * b, struct ggml_tensor * b,
int dim) { int dim) {
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
GGML_ASSERT(a->type == b->type);
int64_t ne[GGML_MAX_DIMS]; int64_t ne[GGML_MAX_DIMS];
for (int d = 0; d < GGML_MAX_DIMS; ++d) { for (int d = 0; d < GGML_MAX_DIMS; ++d) {