Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	examples/server/README.md
This commit is contained in:
Concedo 2024-11-06 23:13:00 +08:00
commit 628dcd640e
10 changed files with 1111 additions and 741 deletions

View file

@ -379,8 +379,8 @@ struct server_queue {
std::condition_variable condition_tasks; std::condition_variable condition_tasks;
// callback functions // callback functions
std::function<void(server_task&)> callback_new_task; std::function<void(server_task)> callback_new_task;
std::function<void(void)> callback_update_slots; std::function<void(void)> callback_update_slots;
// Add a new task to the end of the queue // Add a new task to the end of the queue
int post(server_task task, bool front = false) { int post(server_task task, bool front = false) {
@ -432,7 +432,7 @@ struct server_queue {
} }
// Register function to process a new task // Register function to process a new task
void on_new_task(std::function<void(server_task &)> callback) { void on_new_task(std::function<void(server_task)> callback) {
callback_new_task = std::move(callback); callback_new_task = std::move(callback);
} }
@ -482,7 +482,7 @@ struct server_queue {
lock.unlock(); lock.unlock();
QUE_DBG("processing task, id = %d\n", task.id); QUE_DBG("processing task, id = %d\n", task.id);
callback_new_task(task); callback_new_task(std::move(task));
} }
// all tasks in the current loop is processed, slots data is now ready // all tasks in the current loop is processed, slots data is now ready
@ -645,17 +645,12 @@ struct server_context {
bool load_model(const common_params & params_) { bool load_model(const common_params & params_) {
params = params_; params = params_;
// reserve one extra sequence (seq_id == 0) for extra features
params.n_parallel += 1;
common_init_result llama_init = common_init_from_params(params); common_init_result llama_init = common_init_from_params(params);
model = llama_init.model; model = llama_init.model;
ctx = llama_init.context; ctx = llama_init.context;
loras = llama_init.lora_adapters; loras = llama_init.lora_adapters;
params.n_parallel -= 1; // but be sneaky about it
if (model == nullptr) { if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params.model.c_str()); SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
return false; return false;
@ -1289,16 +1284,16 @@ struct server_context {
void send_embedding(const server_slot & slot, const llama_batch & batch) { void send_embedding(const server_slot & slot, const llama_batch & batch) {
server_task_result res; server_task_result res;
res.id = slot.id_task; res.id = slot.id_task;
res.error = false; res.error = false;
res.stop = true; res.stop = true;
const int n_embd = llama_n_embd(model); const int n_embd = llama_n_embd(model);
std::vector<float> embd_res(n_embd, 0.0f); std::vector<float> embd_res(n_embd, 0.0f);
for (int i = 0; i < batch.n_tokens; ++i) { for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue; continue;
} }
@ -1333,12 +1328,12 @@ struct server_context {
void send_rerank(const server_slot & slot, const llama_batch & batch) { void send_rerank(const server_slot & slot, const llama_batch & batch) {
server_task_result res; server_task_result res;
res.id = slot.id_task; res.id = slot.id_task;
res.error = false; res.error = false;
res.stop = true; res.stop = true;
for (int i = 0; i < batch.n_tokens; ++i) { for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue; continue;
} }
@ -1511,7 +1506,7 @@ struct server_context {
// Functions to process the task // Functions to process the task
// //
void process_single_task(const server_task & task) { void process_single_task(server_task task) {
switch (task.type) { switch (task.type) {
case SERVER_TASK_TYPE_INFERENCE: case SERVER_TASK_TYPE_INFERENCE:
{ {
@ -1567,11 +1562,11 @@ struct server_context {
for (server_slot & slot : slots) { for (server_slot & slot : slots) {
json slot_data = get_formated_generation(slot); json slot_data = get_formated_generation(slot);
slot_data["id"] = slot.id; slot_data["id"] = slot.id;
slot_data["id_task"] = slot.id_task; slot_data["id_task"] = slot.id_task;
slot_data["state"] = slot.state; slot_data["is_processing"] = slot.is_processing();
slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens); slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
slot_data["next_token"] = { slot_data["next_token"] = {
{"has_next_token", slot.has_next_token}, {"has_next_token", slot.has_next_token},
{"has_new_line", slot.has_new_line}, {"has_new_line", slot.has_new_line},
{"n_remain", slot.n_remaining}, {"n_remain", slot.n_remaining},
@ -1582,10 +1577,10 @@ struct server_context {
{"stopping_word", slot.stopping_word}, {"stopping_word", slot.stopping_word},
}; };
if (slot_data["state"] == SLOT_STATE_IDLE) { if (slot.is_processing()) {
n_idle_slots++;
} else {
n_processing_slots++; n_processing_slots++;
} else {
n_idle_slots++;
} }
slots_data.push_back(slot_data); slots_data.push_back(slot_data);
@ -1647,7 +1642,7 @@ struct server_context {
std::string filename = task.data.at("filename"); std::string filename = task.data.at("filename");
std::string filepath = task.data.at("filepath"); std::string filepath = task.data.at("filepath");
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count); const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
const int64_t t_end = ggml_time_us(); const int64_t t_end = ggml_time_us();
const double t_save_ms = (t_end - t_start) / 1000.0; const double t_save_ms = (t_end - t_start) / 1000.0;
@ -1689,7 +1684,7 @@ struct server_context {
slot->cache_tokens.resize(slot->n_ctx); slot->cache_tokens.resize(slot->n_ctx);
size_t token_count = 0; size_t token_count = 0;
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
if (nread == 0) { if (nread == 0) {
slot->cache_tokens.resize(0); slot->cache_tokens.resize(0);
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
@ -1732,7 +1727,7 @@ struct server_context {
// Erase token cache // Erase token cache
const size_t n_erased = slot->cache_tokens.size(); const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
slot->cache_tokens.clear(); slot->cache_tokens.clear();
server_task_result result; server_task_result result;
@ -1809,8 +1804,8 @@ struct server_context {
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard); llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
if (slot.params.cache_prompt) { if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@ -1837,7 +1832,7 @@ struct server_context {
slot.i_batch = batch.n_tokens; slot.i_batch = batch.n_tokens;
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true); common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
slot.n_past += 1; slot.n_past += 1;
@ -1984,8 +1979,8 @@ struct server_context {
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c); llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift); llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);
for (size_t i = 0; i < n_match; i++) { for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
@ -2034,9 +2029,9 @@ struct server_context {
} }
// keep only the common part // keep only the common part
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) { if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model) // could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1); llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
// there is no common part left // there is no common part left
slot.n_past = 0; slot.n_past = 0;
@ -2049,7 +2044,7 @@ struct server_context {
// add prompt tokens for processing in the current batch // add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false); common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
if (slot.params.cache_prompt) { if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);

View file

@ -260,13 +260,13 @@ async def step_wait_for_server_status(context, expecting_status: Literal['health
async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str): async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):
match expected_slot_status_string: match expected_slot_status_string:
case 'idle': case 'idle':
expected_slot_status = 0 expected_slot_status = False
case 'busy': case 'busy':
expected_slot_status = 1 expected_slot_status = True
case _: case _:
assert False, "unknown status" assert False, "unknown status"
expected_slots = [{'id': slot_id, 'state': expected_slot_status} expected_slots = [{'id': slot_id, 'is_processing': expected_slot_status}
for slot_id in range(context.n_slots)] for slot_id in range(context.n_slots)]
await request_slots_status(context, expected_slots) await request_slots_status(context, expected_slots)
@ -1354,8 +1354,8 @@ async def wait_for_slots_status(context,
if status_code == 503 and status_code == expected_http_status_code: if status_code == 503 and status_code == expected_http_status_code:
return return
if status_code == 200 and status_code == expected_http_status_code: if status_code == 200 and status_code == expected_http_status_code:
n_slots_idle = sum(1 if slot["state"] == 0 else 0 for slot in slots) n_slots_idle = sum(1 if not slot["is_processing"] else 0 for slot in slots)
n_slots_processing = sum(1 if slot["state"] != 0 else 0 for slot in slots) n_slots_processing = sum(1 if slot["is_processing"] else 0 for slot in slots)
if ((slots_idle is None or slots_idle == n_slots_idle) if ((slots_idle is None or slots_idle == n_slots_idle)
and (slots_processing is None or slots_processing == n_slots_processing)): and (slots_processing is None or slots_processing == n_slots_processing)):
return return

View file

@ -1227,7 +1227,6 @@ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggm
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size); ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
buffer->buft = buft; buffer->buft = buft;
buffer->iface.get_name = ggml_backend_cann_host_buffer_name;
buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free; buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
return buffer; return buffer;

View file

@ -307,6 +307,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.nrows = 1, .nrows = 1,
}, },
[GGML_TYPE_Q8_0] = { [GGML_TYPE_Q8_0] = {
.from_float_to_mat = quantize_mat_q8_0,
.vec_dot = ggml_vec_dot_q8_0_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
#if defined (__ARM_FEATURE_MATMUL_INT8) #if defined (__ARM_FEATURE_MATMUL_INT8)
@ -13718,6 +13719,13 @@ int ggml_cpu_get_sve_cnt(void) {
} }
void ggml_cpu_init(void) { void ggml_cpu_init(void) {
// needed to initialize f16 tables
{
struct ggml_init_params params = { 0, NULL, false };
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}
ggml_critical_section_start(); ggml_critical_section_start();
static bool is_first_call = true; static bool is_first_call = true;
@ -13725,24 +13733,21 @@ void ggml_cpu_init(void) {
if (is_first_call) { if (is_first_call) {
// initialize GELU, Quick GELU, SILU and EXP F32 tables // initialize GELU, Quick GELU, SILU and EXP F32 tables
{ {
// FIXME: this may be called before ggml_init const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
//const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
for (int i = 0; i < (1 << 16); ++i) { for (int i = 0; i < (1 << 16); ++i) {
union { union {
uint16_t u16; uint16_t u16;
ggml_fp16_t fp16; ggml_fp16_t fp16;
} u = {i}; } u = {i};
// FIXME: this table is used in conversion functions outside of compute float f = GGML_FP16_TO_FP32(u.fp16);
// current code depends on ggml_init initializing this table
float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
} }
//const uint64_t t_end = ggml_time_us(); UNUSED(t_end); const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
//GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0); GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
} }
#if defined(__ARM_ARCH) #if defined(__ARM_ARCH)

View file

@ -1297,11 +1297,17 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0); cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
if (err != cudaErrorPeerAccessAlreadyEnabled) { if (err != cudaErrorPeerAccessAlreadyEnabled) {
CUDA_CHECK(err); CUDA_CHECK(err);
} else {
// reset the error
cudaGetLastError();
} }
} else { } else {
cudaError_t err = cudaDeviceDisablePeerAccess(id_other); cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
if (err != cudaErrorPeerAccessNotEnabled) { if (err != cudaErrorPeerAccessNotEnabled) {
CUDA_CHECK(err); CUDA_CHECK(err);
} else {
// reset the error
cudaGetLastError();
} }
} }
} }

View file

@ -255,9 +255,49 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16, GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@ -710,9 +750,49 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
@ -869,13 +949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
return true; return true;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
if (op->src[1]->type != GGML_TYPE_F16) { if (op->src[1]->type != op->src[2]->type) {
return false;
}
if (op->src[2]->type != GGML_TYPE_F16) {
return false;
}
if (op->src[0]->ne[0] == 256) {
return false; return false;
} }
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
@ -2822,6 +2896,7 @@ static void ggml_metal_encode_node(
GGML_ASSERT(ne11 % 32 == 0); GGML_ASSERT(ne11 % 32 == 0);
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == src2->type);
GGML_ASSERT(ggml_are_same_shape (src1, src2)); GGML_ASSERT(ggml_are_same_shape (src1, src2));
@ -2869,26 +2944,154 @@ static void ggml_metal_encode_node(
bool use_vec_kernel = false; bool use_vec_kernel = false;
if (ne01 >= 4 || (ne00%128 != 0)) { if (ne01 >= 4 || (ne00%128 != 0)) {
switch (ne00) { switch (src1->type) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case GGML_TYPE_F16:
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; {
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; switch (ne00) {
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
case GGML_TYPE_Q4_0:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
case GGML_TYPE_Q4_1:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
case GGML_TYPE_Q5_0:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
case GGML_TYPE_Q5_1:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
case GGML_TYPE_Q8_0:
{
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_LOG_ERROR("add template specialization for this size\n");
GGML_ABORT("add template specialization for this size");
}
}
} break;
default: default:
{ {
GGML_LOG_ERROR("unsupported size: %lld\n", ne00); GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this size\n"); GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this size"); GGML_ABORT("add template specialization for this type");
} }
} }
} else { } else {
use_vec_kernel = true; use_vec_kernel = true;
switch (ne00) { switch (ne00) {
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; case 128:
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; {
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} break;
case 256:
{
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} break;
default: default:
{ {
GGML_LOG_ERROR("unsupported size: %lld\n", ne00); GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
@ -2942,10 +3145,16 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0); GGML_ASSERT(ncpsg % 32 == 0);
// 16*32*(nsg)
// the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
int64_t nsgmax = 2; int64_t nsgmax = 2;
while (true) { while (true) {
const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) { if (smem > device.maxThreadgroupMemoryLength) {
break; break;
} }
@ -2956,16 +3165,15 @@ static void ggml_metal_encode_node(
// simdgroups per threadgroup (a.k.a. warps) // simdgroups per threadgroup (a.k.a. warps)
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); const size_t smem = FATTN_SMEM(nsg);
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength); //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; #undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else { } else {
// half1x4 kernel // half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
@ -2973,8 +3181,28 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 1 == 0); GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0); GGML_ASSERT(ncpsg % 32 == 0);
// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the attention scores (nqptg x ncpsg) as f32
//
// 2*ne00*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
int64_t nsgmax = 2;
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
break;
}
nsgmax *= 2;
}
nsgmax /= 2;
// simdgroups per threadgroup (a.k.a. warps) // simdgroups per threadgroup (a.k.a. warps)
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
int64_t nsg = 1; int64_t nsg = 1;
while (nsg <= nsgt) { while (nsg <= nsgt) {
@ -2982,12 +3210,12 @@ static void ggml_metal_encode_node(
} }
nsg /= 2; nsg /= 2;
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); const size_t smem = FATTN_SMEM(nsg);
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength); //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; [encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} }
} break; } break;
@ -3844,7 +4072,7 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
} }
} }
return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
} }
static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
@ -3854,7 +4082,8 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
} }
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name; return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
UNUSED(dev); UNUSED(dev);
} }

File diff suppressed because it is too large Load diff

View file

@ -4,7 +4,7 @@
#include "ggml-quants.h" #include "ggml-quants.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#include "ggml-cpu-impl.h" #include "ggml-cpu-impl.h"
#include "ggml-cpu.h"
#include <math.h> #include <math.h>
#include <string.h> #include <string.h>
@ -9105,10 +9105,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
#elif defined __AVX__ #elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
const __m128i m3 = _mm_set1_epi8(3); const __m128i m3 = _mm_set1_epi8(3);
const __m128i m32s = _mm_set1_epi8(32); const __m128i m15 = _mm_set1_epi8(15);
const __m128i m2 = _mm_set1_epi8(2);
__m256 acc = _mm256_setzero_ps(); __m256 acc = _mm256_setzero_ps();
@ -9120,12 +9118,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const uint8_t * restrict qh = x[i].qh; const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs; const int8_t * restrict q8 = y[i].qs;
// handle the q6_k -32 offset separately using bsums
const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
__m128i sumi_0 = _mm_setzero_si128(); __m128i sumi_0 = _mm_setzero_si128();
__m128i sumi_1 = _mm_setzero_si128(); __m128i sumi_1 = _mm_setzero_si128();
__m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); int is = 0;
for (int j = 0; j < QK_K/128; ++j) { for (int j = 0; j < QK_K/128; ++j) {
const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
@ -9133,26 +9139,26 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
@ -9163,15 +9169,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
__m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
__m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
__m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
__m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
__m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
__m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
__m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
__m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
__m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
__m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
__m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
@ -9181,32 +9178,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
__m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
__m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
p16_0 = _mm_sub_epi16(p16_0, q8s_0); const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
p16_1 = _mm_sub_epi16(p16_1, q8s_1); const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
p16_2 = _mm_sub_epi16(p16_2, q8s_2); const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
p16_3 = _mm_sub_epi16(p16_3, q8s_3); const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
p16_4 = _mm_sub_epi16(p16_4, q8s_4); is += 4;
p16_5 = _mm_sub_epi16(p16_5, q8s_5);
p16_6 = _mm_sub_epi16(p16_6, q8s_6);
p16_7 = _mm_sub_epi16(p16_7, q8s_7);
const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
@ -9215,8 +9200,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
} }
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
} }
*s = hsum_float_8(acc); *s = hsum_float_8(acc);

View file

@ -220,8 +220,10 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, voi
void * ggml_aligned_malloc(size_t size) { void * ggml_aligned_malloc(size_t size) {
const int alignment = 64;
#if defined(_MSC_VER) || defined(__MINGW32__) #if defined(_MSC_VER) || defined(__MINGW32__)
return _aligned_malloc(size, TENSOR_ALIGNMENT); return _aligned_malloc(size, alignment);
#else #else
if (size == 0) { if (size == 0) {
GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n"); GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
@ -229,8 +231,9 @@ void * ggml_aligned_malloc(size_t size) {
} }
void * aligned_memory = NULL; void * aligned_memory = NULL;
#ifdef GGML_USE_CPU_HBM #ifdef GGML_USE_CPU_HBM
int result = hbw_posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size); int result = hbw_posix_memalign(&aligned_memory, alignment, size);
#elif TARGET_OS_OSX #elif TARGET_OS_OSX
GGML_UNUSED(alignment);
kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE); kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE);
int result = EFAULT; int result = EFAULT;
switch (alloc_status) { switch (alloc_status) {
@ -248,7 +251,7 @@ void * ggml_aligned_malloc(size_t size) {
break; break;
} }
#else #else
int result = posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size); int result = posix_memalign(&aligned_memory, alignment, size);
#endif #endif
if (result != 0) { if (result != 0) {
// Handle allocation failure // Handle allocation failure
@ -392,6 +395,8 @@ void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
16))); 16)));
} }
} }
#endif
#if defined(__AVX2__)
if (ggml_cpu_has_avx2()) { if (ggml_cpu_has_avx2()) {
for (; i + 8 <= n; i += 8) { for (; i + 8 <= n; i += 8) {
_mm256_storeu_ps(y + i, _mm256_storeu_ps(y + i,
@ -1402,11 +1407,11 @@ static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const str
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
struct ggml_context * ggml_init(struct ggml_init_params params) { struct ggml_context * ggml_init(struct ggml_init_params params) {
static bool is_first_call = false; static bool is_first_call = true;
ggml_critical_section_start(); ggml_critical_section_start();
if (!is_first_call) { if (is_first_call) {
// initialize time system (required on Windows) // initialize time system (required on Windows)
ggml_time_init(); ggml_time_init();
@ -1417,7 +1422,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
} u = {i}; } u = {i};
ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16); ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
} }
is_first_call = true;
is_first_call = false;
} }
ggml_critical_section_end(); ggml_critical_section_end();

View file

@ -9182,7 +9182,7 @@ static bool llm_load_tensors(
// print memory requirements per buffer type // print memory requirements per buffer type
for (auto & buf : model.bufs) { for (auto & buf : model.bufs) {
LLAMA_LOG_INFO("%s: %10s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
} }
// populate tensors_by_name // populate tensors_by_name
@ -21893,8 +21893,11 @@ static int32_t llama_chat_apply_template_internal(
// IBM Granite template // IBM Granite template
for (const auto & message : chat) { for (const auto & message : chat) {
std::string role(message->role); std::string role(message->role);
ss << "<|start_of_role|>" << role << "<|end_of_role|>" ss << "<|start_of_role|>" << role << "<|end_of_role|>";
<< message->content << "<|end_of_text|>\n"; if (role == "assistant_tool_call") {
ss << "<|tool_call|>";
}
ss << message->content << "<|end_of_text|>\n";
} }
if (add_ass) { if (add_ass) {
ss << "<|start_of_role|>assistant<|end_of_role|>\n"; ss << "<|start_of_role|>assistant<|end_of_role|>\n";