mirror of https://github.com/LostRuins/koboldcpp.git
synced 2025-09-12 18:09:42 +00:00

Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	examples/server/README.md

Commit 628dcd640e: 10 changed files with 1111 additions and 741 deletions
@@ -379,7 +379,7 @@ struct server_queue {
     std::condition_variable condition_tasks;
 
     // callback functions
-    std::function<void(server_task&)> callback_new_task;
+    std::function<void(server_task)> callback_new_task;
     std::function<void(void)> callback_update_slots;
 
     // Add a new task to the end of the queue

@@ -432,7 +432,7 @@ struct server_queue {
     }
 
     // Register function to process a new task
-    void on_new_task(std::function<void(server_task &)> callback) {
+    void on_new_task(std::function<void(server_task)> callback) {
         callback_new_task = std::move(callback);
     }
 

@@ -482,7 +482,7 @@ struct server_queue {
             lock.unlock();
 
             QUE_DBG("processing task, id = %d\n", task.id);
-            callback_new_task(task);
+            callback_new_task(std::move(task));
         }
 
         // all tasks in the current loop is processed, slots data is now ready

@@ -645,17 +645,12 @@ struct server_context {
     bool load_model(const common_params & params_) {
         params = params_;
 
-        // reserve one extra sequence (seq_id == 0) for extra features
-        params.n_parallel += 1;
-
         common_init_result llama_init = common_init_from_params(params);
 
        model = llama_init.model;
        ctx   = llama_init.context;
        loras = llama_init.lora_adapters;
 
-        params.n_parallel -= 1; // but be sneaky about it
-
        if (model == nullptr) {
            SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
            return false;

@@ -1298,7 +1293,7 @@ struct server_context {
         std::vector<float> embd_res(n_embd, 0.0f);
 
         for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
                 continue;
             }
 

@@ -1338,7 +1333,7 @@ struct server_context {
         res.stop = true;
 
         for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
                 continue;
             }
 

@@ -1511,7 +1506,7 @@ struct server_context {
     // Functions to process the task
     //
 
-    void process_single_task(const server_task & task) {
+    void process_single_task(server_task task) {
         switch (task.type) {
             case SERVER_TASK_TYPE_INFERENCE:
                 {

@@ -1569,7 +1564,7 @@ struct server_context {
                     json slot_data = get_formated_generation(slot);
                     slot_data["id"] = slot.id;
                     slot_data["id_task"] = slot.id_task;
-                    slot_data["state"] = slot.state;
+                    slot_data["is_processing"] = slot.is_processing();
                     slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
                     slot_data["next_token"] = {
                         {"has_next_token", slot.has_next_token},

@@ -1582,10 +1577,10 @@ struct server_context {
                         {"stopping_word", slot.stopping_word},
                     };
 
-                    if (slot_data["state"] == SLOT_STATE_IDLE) {
-                        n_idle_slots++;
-                    } else {
+                    if (slot.is_processing()) {
                         n_processing_slots++;
+                    } else {
+                        n_idle_slots++;
                     }
 
                     slots_data.push_back(slot_data);

@@ -1647,7 +1642,7 @@ struct server_context {
                 std::string filename = task.data.at("filename");
                 std::string filepath = task.data.at("filepath");
 
-                const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+                const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
 
                 const int64_t t_end = ggml_time_us();
                 const double t_save_ms = (t_end - t_start) / 1000.0;

@@ -1689,7 +1684,7 @@ struct server_context {
 
                 slot->cache_tokens.resize(slot->n_ctx);
                 size_t token_count = 0;
-                size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
+                size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
                 if (nread == 0) {
                     slot->cache_tokens.resize(0);
                     send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);

@@ -1732,7 +1727,7 @@ struct server_context {
 
                 // Erase token cache
                 const size_t n_erased = slot->cache_tokens.size();
-                llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+                llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
                 slot->cache_tokens.clear();
 
                 server_task_result result;

@@ -1809,8 +1804,8 @@ struct server_context {
 
                 SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
+                llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
+                llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
 
                 if (slot.params.cache_prompt) {
                     for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {

@@ -1837,7 +1832,7 @@ struct server_context {
 
                 slot.i_batch = batch.n_tokens;
 
-                common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);
+                common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
                 slot.n_past += 1;
 

@@ -1984,8 +1979,8 @@ struct server_context {
 
                             const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
 
-                            llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
-                            llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+                            llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
+                            llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);
 
                             for (size_t i = 0; i < n_match; i++) {
                                 slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];

@@ -2034,9 +2029,9 @@ struct server_context {
                     }
 
                     // keep only the common part
-                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
+                    if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
                         // could not partially delete (likely using a non-Transformer model)
-                        llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+                        llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
 
                         // there is no common part left
                         slot.n_past = 0;

@@ -2049,7 +2044,7 @@ struct server_context {
 
                 // add prompt tokens for processing in the current batch
                 while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                    common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
+                    common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
 
                     if (slot.params.cache_prompt) {
                         slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
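Note: the server hunks above drop the old scheme in which sequence id 0 was reserved and every slot's KV-cache sequence id was its slot id plus one; a slot's sequence id is now simply `slot.id`, and the slots endpoint reports a boolean `is_processing` instead of a numeric `state`. A minimal sketch of the id-mapping change (hypothetical names, not the actual server types):

```cpp
// Sketch only: illustrates the seq-id mapping change described above.
struct slot_info {
    int  id;          // slot index, 0 .. n_slots-1
    bool processing;  // stands in for the old numeric "state" field
};

// Before this commit, the KV-cache sequence id was offset by one because
// seq_id 0 was reserved for extra features; after it, the ids coincide.
int seq_id_for_slot_old(const slot_info & s) { return s.id + 1; }
int seq_id_for_slot_new(const slot_info & s) { return s.id; }
```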
@@ -260,13 +260,13 @@ async def step_wait_for_server_status(context, expecting_status: Literal['health
 async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):
     match expected_slot_status_string:
         case 'idle':
-            expected_slot_status = 0
+            expected_slot_status = False
         case 'busy':
-            expected_slot_status = 1
+            expected_slot_status = True
         case _:
             assert False, "unknown status"
 
-    expected_slots = [{'id': slot_id, 'state': expected_slot_status}
+    expected_slots = [{'id': slot_id, 'is_processing': expected_slot_status}
                       for slot_id in range(context.n_slots)]
     await request_slots_status(context, expected_slots)
 

@@ -1354,8 +1354,8 @@ async def wait_for_slots_status(context,
         if status_code == 503 and status_code == expected_http_status_code:
             return
         if status_code == 200 and status_code == expected_http_status_code:
-            n_slots_idle = sum(1 if slot["state"] == 0 else 0 for slot in slots)
-            n_slots_processing = sum(1 if slot["state"] != 0 else 0 for slot in slots)
+            n_slots_idle = sum(1 if not slot["is_processing"] else 0 for slot in slots)
+            n_slots_processing = sum(1 if slot["is_processing"] else 0 for slot in slots)
             if ((slots_idle is None or slots_idle == n_slots_idle)
                     and (slots_processing is None or slots_processing == n_slots_processing)):
                 return
@@ -1227,7 +1227,6 @@ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggm
 
     ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
     buffer->buft = buft;
-    buffer->iface.get_name = ggml_backend_cann_host_buffer_name;
     buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
 
     return buffer;
@@ -307,6 +307,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .nrows = 1,
     },
     [GGML_TYPE_Q8_0] = {
+        .from_float_to_mat = quantize_mat_q8_0,
         .vec_dot = ggml_vec_dot_q8_0_q8_0,
         .vec_dot_type = GGML_TYPE_Q8_0,
 #if defined (__ARM_FEATURE_MATMUL_INT8)

@@ -13718,6 +13719,13 @@ int ggml_cpu_get_sve_cnt(void) {
 }
 
 void ggml_cpu_init(void) {
+    // needed to initialize f16 tables
+    {
+        struct ggml_init_params params = { 0, NULL, false };
+        struct ggml_context * ctx = ggml_init(params);
+        ggml_free(ctx);
+    }
+
     ggml_critical_section_start();
 
     static bool is_first_call = true;

@@ -13725,24 +13733,21 @@ void ggml_cpu_init(void) {
     if (is_first_call) {
         // initialize GELU, Quick GELU, SILU and EXP F32 tables
         {
-            // FIXME: this may be called before ggml_init
-            //const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
+            const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
             for (int i = 0; i < (1 << 16); ++i) {
                 union {
                     uint16_t u16;
                     ggml_fp16_t fp16;
                 } u = {i};
-                // FIXME: this table is used in conversion functions outside of compute
-                // current code depends on ggml_init initializing this table
-                float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
+                float f = GGML_FP16_TO_FP32(u.fp16);
                 ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
                 ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
             }
 
-            //const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+            const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
-            //GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
+            GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
         }
 
 #if defined(__ARM_ARCH)
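Note: the `ggml_cpu_init` hunk above creates and immediately frees a throwaway context because `ggml_init` is what fills the global f16-to-f32 table that the GELU/Quick GELU loop then reads through `GGML_FP16_TO_FP32`. A sketch of the same lookup-table pattern, with a hypothetical `fp16_to_fp32` decoder standing in for the real macro:

```cpp
// Sketch of the 2^16-entry activation table built above (assumed helper:
// fp16_to_fp32 decodes a raw half-precision bit pattern to float).
#include <cmath>
#include <cstdint>

static float gelu(float x) {
    return 0.5f*x*(1.0f + std::tanh(0.7978845608f*(x + 0.044715f*x*x*x)));
}

static float table_gelu[1 << 16];

void build_gelu_table(float (*fp16_to_fp32)(uint16_t)) {
    for (int i = 0; i < (1 << 16); ++i) {  // one entry per fp16 bit pattern
        table_gelu[i] = gelu(fp16_to_fp32((uint16_t) i));
    }
}
```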
@@ -1297,11 +1297,17 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
                 cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
                 if (err != cudaErrorPeerAccessAlreadyEnabled) {
                     CUDA_CHECK(err);
+                } else {
+                    // reset the error
+                    cudaGetLastError();
                 }
             } else {
                 cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
                 if (err != cudaErrorPeerAccessNotEnabled) {
                     CUDA_CHECK(err);
+                } else {
+                    // reset the error
+                    cudaGetLastError();
                 }
             }
         }
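Note: `cudaDeviceEnablePeerAccess` reports `cudaErrorPeerAccessAlreadyEnabled` when access is already on, and the runtime keeps that code in its last-error slot until something reads it; the added `cudaGetLastError()` calls pop the benign error so a later `CUDA_CHECK` does not trip on it. A minimal sketch of the pattern (error handling elided):

```cpp
// Sketch: treat "already enabled" as benign, but clear the sticky error.
#include <cuda_runtime.h>

void enable_peer_access_checked(int peer_device) {
    cudaError_t err = cudaDeviceEnablePeerAccess(peer_device, 0);
    if (err == cudaErrorPeerAccessAlreadyEnabled) {
        (void) cudaGetLastError();  // reset the last-error state
    } else if (err != cudaSuccess) {
        // a real failure: handle or abort here
    }
}
```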
@@ -255,9 +255,49 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-    //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-    //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F16,

@@ -710,9 +750,49 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
-        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);

@@ -869,13 +949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[1]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[2]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 256) {
+            if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
             return support_simdgroup_mm; // TODO: over-restricted for vec-kernels

@@ -2822,6 +2896,7 @@ static void ggml_metal_encode_node(
                 GGML_ASSERT(ne11 % 32 == 0);
 
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                GGML_ASSERT(src1->type == src2->type);
 
                 GGML_ASSERT(ggml_are_same_shape (src1, src2));
 

@@ -2869,13 +2944,16 @@ static void ggml_metal_encode_node(
                 bool use_vec_kernel = false;
 
                 if (ne01 >= 4 || (ne00%128 != 0)) {
+                    switch (src1->type) {
+                        case GGML_TYPE_F16:
+                            {
                                 switch (ne00) {
                                     case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
                                     case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
                                     case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
                                     case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
                                     case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                                    //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
                                     default:
                                         {
                                             GGML_LOG_ERROR("unsupported size: %lld\n", ne00);

@@ -2883,12 +2961,137 @@ static void ggml_metal_encode_node(
                                             GGML_ABORT("add template specialization for this size");
                                         }
                                 }
+                            } break;
+                        case GGML_TYPE_Q4_0:
+                            {
+                                switch (ne00) {
+                                    case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
+                                    case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
+                                    case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                            GGML_LOG_ERROR("add template specialization for this size\n");
+                                            GGML_ABORT("add template specialization for this size");
+                                        }
+                                }
+                            } break;
+                        case GGML_TYPE_Q4_1:
+                            {
+                                switch (ne00) {
+                                    case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
+                                    case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
+                                    case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                            GGML_LOG_ERROR("add template specialization for this size\n");
+                                            GGML_ABORT("add template specialization for this size");
+                                        }
+                                }
+                            } break;
+                        case GGML_TYPE_Q5_0:
+                            {
+                                switch (ne00) {
+                                    case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
+                                    case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
+                                    case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                            GGML_LOG_ERROR("add template specialization for this size\n");
+                                            GGML_ABORT("add template specialization for this size");
+                                        }
+                                }
+                            } break;
+                        case GGML_TYPE_Q5_1:
+                            {
+                                switch (ne00) {
+                                    case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
+                                    case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
+                                    case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                            GGML_LOG_ERROR("add template specialization for this size\n");
+                                            GGML_ABORT("add template specialization for this size");
+                                        }
+                                }
+                            } break;
+                        case GGML_TYPE_Q8_0:
+                            {
+                                switch (ne00) {
+                                    case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
+                                    case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
+                                    case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                            GGML_LOG_ERROR("add template specialization for this size\n");
+                                            GGML_ABORT("add template specialization for this size");
+                                        }
+                                }
+                            } break;
+                        default:
+                            {
+                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                GGML_LOG_ERROR("add template specialization for this type\n");
+                                GGML_ABORT("add template specialization for this type");
+                            }
+                    }
                 } else {
                     use_vec_kernel = true;
 
                     switch (ne00) {
-                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                        //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                        case 128:
+                            {
+                                switch (src1->type) {
+                                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
+                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
+                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
+                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
+                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                            GGML_LOG_ERROR("add template specialization for this type\n");
+                                            GGML_ABORT("add template specialization for this type");
+                                        }
+                                }
+                            } break;
+                        case 256:
+                            {
+                                switch (src1->type) {
+                                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
+                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
+                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
+                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
+                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                            GGML_LOG_ERROR("add template specialization for this type\n");
+                                            GGML_ABORT("add template specialization for this type");
+                                        }
+                                }
+                            } break;
                         default:
                             {
                                 GGML_LOG_ERROR("unsupported size: %lld\n", ne00);

@@ -2942,10 +3145,16 @@ static void ggml_metal_encode_node(
                 GGML_ASSERT(nqptg % 8 == 0);
                 GGML_ASSERT(ncpsg % 32 == 0);
 
+                // 16*32*(nsg)
+                // the shared memory needed for the simdgroups to load the KV cache
+                // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
+                //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+
                 int64_t nsgmax = 2;
 
                 while (true) {
-                    const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+                    const size_t smem = FATTN_SMEM(nsgmax);
                     if (smem > device.maxThreadgroupMemoryLength) {
                         break;
                     }

@@ -2956,16 +3165,15 @@ static void ggml_metal_encode_node(
                 // simdgroups per threadgroup (a.k.a. warps)
                 const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
 
-                const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+                const size_t smem = FATTN_SMEM(nsg);
 
-                //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
                 GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-                [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+                [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
 
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
             } else {
-                // half1x4 kernel
+                // half4x4 kernel
                 const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
                 const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
 

@@ -2973,8 +3181,28 @@ static void ggml_metal_encode_node(
                 GGML_ASSERT(nqptg % 1 == 0);
                 GGML_ASSERT(ncpsg % 32 == 0);
 
+                // ne00 + 2*ncpsg*(nsg)
+                // for each query, we load it as f16 in shared memory (ne00)
+                // and store the attention scores (nqptg x ncpsg) as f32
+                //
+                // 2*ne00*(nsg)
+                // each simdgroup has a full f32 head vector in shared mem to accumulate results
+                //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+
+                int64_t nsgmax = 2;
+
+                while (true) {
+                    const size_t smem = FATTN_SMEM(nsgmax);
+                    if (smem > device.maxThreadgroupMemoryLength) {
+                        break;
+                    }
+                    nsgmax *= 2;
+                }
+                nsgmax /= 2;
+
                 // simdgroups per threadgroup (a.k.a. warps)
-                const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+                const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
 
                 int64_t nsg = 1;
                 while (nsg <= nsgt) {

@@ -2982,12 +3210,12 @@ static void ggml_metal_encode_node(
                 }
                 nsg /= 2;
 
-                const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+                const size_t smem = FATTN_SMEM(nsg);
 
-                //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
                 GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-                [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+                [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
             }
         } break;

@@ -3844,7 +4072,7 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
         }
     }
 
-    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
+    return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
 }
 
 static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {

@@ -3854,7 +4082,8 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
 }
 
 static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
+    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
+           buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
 
     UNUSED(dev);
 }
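Note: a worked example of the new `FATTN_SMEM` threadgroup-memory budget for the non-vec kernel; the inputs below (head size 128, 8 queries per threadgroup, 32 cache values per simdgroup, 4 simdgroups) are illustrative values, not taken from the diff:

```cpp
// Sketch: evaluates the FATTN_SMEM formula from the hunks above in plain C++.
#include <cstddef>
#include <cstdio>

constexpr size_t pad16(size_t x) { return (x + 15) / 16 * 16; }

// elements are f16, hence the final *2 bytes (sizeof(float)/2 in the macro)
constexpr size_t fattn_smem(size_t ne00, size_t nqptg, size_t ncpsg, size_t nsg) {
    return pad16((nqptg*(ne00 + 2*(ncpsg + nqptg)*nsg) + 16*32*nsg) * 2);
}

int main() {
    // head size 128, 8 queries/threadgroup, 32 cache values/simdgroup, 4 simdgroups:
    // (8*(128 + 2*40*4) + 2048)*2 = 11264 bytes, well under the usual 32 KiB limit
    std::printf("%zu bytes\n", fattn_smem(128, 8, 32, 4));
}
```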
File diff suppressed because it is too large
@@ -4,7 +4,7 @@
 #include "ggml-quants.h"
 #include "ggml-impl.h"
 #include "ggml-cpu-impl.h"
+#include "ggml-cpu.h"
 
 #include <math.h>
 #include <string.h>

@@ -9105,10 +9105,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __AVX__
 
-    const __m128i m4 = _mm_set1_epi8(0xF);
     const __m128i m3 = _mm_set1_epi8(3);
-    const __m128i m32s = _mm_set1_epi8(32);
-    const __m128i m2 = _mm_set1_epi8(2);
+    const __m128i m15 = _mm_set1_epi8(15);
 
     __m256 acc = _mm256_setzero_ps();
 

@@ -9120,12 +9118,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict qh = x[i].qh;
         const int8_t * restrict q8 = y[i].qs;
 
+        // handle the q6_k -32 offset separately using bsums
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
         const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
+        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
+        const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
+        const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
 
         __m128i sumi_0 = _mm_setzero_si128();
         __m128i sumi_1 = _mm_setzero_si128();
 
-        __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+        int is = 0;
 
         for (int j = 0; j < QK_K/128; ++j) {
 
             const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;

@@ -9133,26 +9139,26 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
             const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
             const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
-            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
-            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
-            const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
-            const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
-            const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
-            const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
+            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
+            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
+            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
+            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
 
             const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
             const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
 
-            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
-            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
-            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
-            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
-            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
-            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
-            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
-            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
+            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
+            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
+            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
+            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
+            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
+            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
+            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
+            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
 
             const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
             const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;

@@ -9163,15 +9169,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
             const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
 
-            __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
-            __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
-            __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
-            __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
-            __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
-            __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
-            __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
-            __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
-
             __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
             __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
             __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);

@@ -9181,32 +9178,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
             __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
 
-            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
-            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
-            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
-            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
-            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
-            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
-            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
-            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
-
-            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
-            const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
-            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+            is += 4;
 
             p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
-            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
             p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
-            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
             p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
-            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
+            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
             p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
-            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
+            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
 
             sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
             sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));

@@ -9215,8 +9200,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         }
 
-        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
-        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
+        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
+        const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
     }
 
     *s = hsum_float_8(acc);
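Note: the rewritten AVX path above leans on the identity sum(q8[i] * (q6[i] - 32)) = sum(q8[i] * q6[i]) - 32 * sum(q8[i]); the precomputed `bsums` of the q8 block supply the right-hand sums, so the q6_k offset (and the old `m32s`/`q8s_*` correction terms) collapses into one scaled subtraction at the end. A scalar sketch of the identity:

```cpp
// Sketch: de-biasing after the multiply-accumulate gives the same dot product.
int dot_debias_first(const int * q6, const int * q8, int n) {
    int s = 0;
    for (int i = 0; i < n; ++i) s += (q6[i] - 32) * q8[i];
    return s;
}

int dot_debias_after(const int * q6, const int * q8, int n) {
    int s = 0, bsum = 0;
    for (int i = 0; i < n; ++i) { s += q6[i] * q8[i]; bsum += q8[i]; }
    return s - 32 * bsum;  // identical result, one subtraction instead of n
}
```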
@@ -220,8 +220,10 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, voi
 
 
 void * ggml_aligned_malloc(size_t size) {
+    const int alignment = 64;
+
 #if defined(_MSC_VER) || defined(__MINGW32__)
-    return _aligned_malloc(size, TENSOR_ALIGNMENT);
+    return _aligned_malloc(size, alignment);
 #else
     if (size == 0) {
         GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");

@@ -229,8 +231,9 @@ void * ggml_aligned_malloc(size_t size) {
     }
     void * aligned_memory = NULL;
 #ifdef GGML_USE_CPU_HBM
-    int result = hbw_posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size);
+    int result = hbw_posix_memalign(&aligned_memory, alignment, size);
 #elif TARGET_OS_OSX
+    GGML_UNUSED(alignment);
     kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE);
     int result = EFAULT;
     switch (alloc_status) {

@@ -248,7 +251,7 @@ void * ggml_aligned_malloc(size_t size) {
             break;
     }
 #else
-    int result = posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size);
+    int result = posix_memalign(&aligned_memory, alignment, size);
 #endif
     if (result != 0) {
         // Handle allocation failure

@@ -392,6 +395,8 @@ void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
                 16)));
         }
     }
+#endif
+#if defined(__AVX2__)
     if (ggml_cpu_has_avx2()) {
         for (; i + 8 <= n; i += 8) {
             _mm256_storeu_ps(y + i,

@@ -1402,11 +1407,11 @@ static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const str
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_context * ggml_init(struct ggml_init_params params) {
-    static bool is_first_call = false;
+    static bool is_first_call = true;
 
     ggml_critical_section_start();
 
-    if (!is_first_call) {
+    if (is_first_call) {
         // initialize time system (required on Windows)
         ggml_time_init();
 

@@ -1417,7 +1422,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             } u = {i};
             ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
         }
-        is_first_call = true;
+
+        is_first_call = false;
     }
 
     ggml_critical_section_end();
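Note: the `ggml_aligned_malloc` hunks above swap the `TENSOR_ALIGNMENT` macro for a local 64-byte `alignment` constant across the `_aligned_malloc`, `hbw_posix_memalign`, and `posix_memalign` branches. A minimal sketch of the POSIX branch under that assumption:

```cpp
// Sketch: 64-byte aligned allocation via posix_memalign (POSIX systems).
#include <cstdlib>

void * aligned_malloc64(size_t size) {
    void * mem = nullptr;
    const int alignment = 64;  // matches the constant introduced above
    if (posix_memalign(&mem, alignment, size) != 0) {
        return nullptr;  // allocation failed
    }
    return mem;
}
```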
@@ -9182,7 +9182,7 @@ static bool llm_load_tensors(
 
     // print memory requirements per buffer type
     for (auto & buf : model.bufs) {
-        LLAMA_LOG_INFO("%s: %10s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
+        LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
     }
 
     // populate tensors_by_name

@@ -21893,8 +21893,11 @@ static int32_t llama_chat_apply_template_internal(
         // IBM Granite template
         for (const auto & message : chat) {
             std::string role(message->role);
-            ss << "<|start_of_role|>" << role << "<|end_of_role|>"
-               << message->content << "<|end_of_text|>\n";
+            ss << "<|start_of_role|>" << role << "<|end_of_role|>";
+            if (role == "assistant_tool_call") {
+                ss << "<|tool_call|>";
+            }
+            ss << message->content << "<|end_of_text|>\n";
         }
         if (add_ass) {
             ss << "<|start_of_role|>assistant<|end_of_role|>\n";
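Note: with the Granite change above, a message whose role is `assistant_tool_call` now gets a `<|tool_call|>` marker between the role header and the content. A sketch of what the loop renders for a hypothetical two-message chat (message contents invented for illustration):

```cpp
// Sketch: reproduces the Granite rendering loop from the hunk above.
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

int main() {
    const std::vector<std::pair<std::string, std::string>> chat = {
        {"user", "What is 2+2?"},
        {"assistant_tool_call", "{\"name\":\"calc\",\"arguments\":\"2+2\"}"},
    };
    std::ostringstream ss;
    for (const auto & [role, content] : chat) {
        ss << "<|start_of_role|>" << role << "<|end_of_role|>";
        if (role == "assistant_tool_call") {
            ss << "<|tool_call|>";  // new: marker before the tool-call payload
        }
        ss << content << "<|end_of_text|>\n";
    }
    ss << "<|start_of_role|>assistant<|end_of_role|>\n";  // add_ass == true case
    std::cout << ss.str();
}
```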