Merge commit 'c917b67f06' into concedo_experimental

# Conflicts:
#	.devops/tools.sh
#	Makefile
#	ggml/src/ggml-cuda/mmq.cuh
#	tests/test-double-float.cpp
#	tests/test-quantize-fns.cpp
#	tests/test-quantize-perf.cpp
Concedo 2024-07-14 11:38:20 +08:00
commit 602661ba49
25 changed files with 1339 additions and 1504 deletions

View file

@@ -1203,11 +1203,10 @@ class RefactModel(Model):
         # TODO: how to determine special FIM tokens automatically?
         special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
-                                          special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot'])
+                                          special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
         special_vocab._set_special_token("prefix", 1)
         special_vocab._set_special_token("suffix", 3)
         special_vocab._set_special_token("middle", 2)
-        special_vocab._set_special_token("fsep",   4) # is this correct?
         special_vocab.add_to_gguf(self.gguf_writer)

     def set_gguf_parameters(self):

View file

@@ -99,7 +99,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
     char src1_str[128] = {0};
     if (src1) {
-        sprintf(src1_str, "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
+        snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
     }

     printf("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
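Note on the change above: snprintf bounds the write to the destination's size, where sprintf would happily overflow a 128-byte stack buffer on a long tensor name. A minimal standalone sketch of the same pattern (buffer size and contents are illustrative, not from the patch):

#include <cstdio>

int main() {
    char buf[32] = {0};
    // snprintf writes at most sizeof(buf) bytes including the terminating
    // NUL, so an oversized name is truncated instead of smashing the stack.
    snprintf(buf, sizeof(buf), "%s{%s}", "a_very_long_tensor_name", "4096,4096");
    printf("%s\n", buf);
    return 0;
}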

View file

@@ -347,7 +347,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
                 char hex_result[17];
                 for (int offset = 0; offset < 8; offset++) {
                     unsigned int shift_bits_by = (8 * (8 - offset - 1));
-                    sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
+                    snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
                 }

                 if (hash_params.manifest_is_usable) {
@@ -384,7 +384,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
                 char hex_result[41] = {0};
                 for (int offset = 0; offset < 20; offset++) {
-                    sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
+                    snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
                 }

                 if (hash_params.manifest_is_usable) {
@@ -421,7 +421,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
                 char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
                 for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
-                    sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
+                    snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
                 }

                 if (hash_params.manifest_is_usable) {
@@ -460,7 +460,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
                 char hex_result[17];
                 for (int offset = 0; offset < 8; offset++) {
                     unsigned int shift_bits_by = (8 * (8 - offset - 1));
-                    sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
+                    snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
                 }

                 if (hash_params.manifest_is_usable) {
@@ -490,7 +490,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
                 char hex_result[41];
                 for (int offset = 0; offset < 20; offset++) {
-                    sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
+                    snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
                 }

                 if (hash_params.manifest_is_usable) {
@@ -520,7 +520,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
                 char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
                 for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
-                    sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
+                    snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
                 }

                 if (hash_params.manifest_is_usable) {
@@ -552,7 +552,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
                 generate_uuidv5(result, uuid);

                 char string_buffer[37] = {0};
-                sprintf(string_buffer, "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+                snprintf(string_buffer, sizeof(string_buffer), "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
                     uuid[0], uuid[1], uuid[2], uuid[3],
                     uuid[4], uuid[5], uuid[6], uuid[7],
                     uuid[8], uuid[9], uuid[10], uuid[11],
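Each converted loop passes the space remaining past the current write position, sizeof(hex_result) - (2*offset), so even the last two hex digits plus the NUL still fit. The same loop, self-contained (digest bytes invented for the demo):

#include <cstdio>

int main() {
    const unsigned char digest[8] = {0xde, 0xad, 0xbe, 0xef, 0x01, 0x02, 0x03, 0x04};
    char hex_result[17]; // 8 bytes -> 16 hex chars + NUL
    for (int offset = 0; offset < 8; offset++) {
        // the remaining space shrinks as we advance through the buffer;
        // the final write still has room for "%02x" plus the terminator
        snprintf(hex_result + 2*offset, sizeof(hex_result) - 2*offset, "%02x", digest[offset]);
    }
    printf("%s\n", hex_result);
    return 0;
}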

View file

@@ -290,8 +290,13 @@ int main(int argc, char ** argv) {

     // Should not run without any tokens
     if (embd_inp.empty()) {
-        embd_inp.push_back(llama_token_bos(model));
-        LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+        if (add_bos) {
+            embd_inp.push_back(llama_token_bos(model));
+            LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+        } else {
+            LOG_TEE("error: input is empty\n");
+            return -1;
+        }
     }

     // Tokenize negative prompt

View file

@@ -155,7 +155,7 @@ static void test_roundtrip_on_chunk(
     }

     if (use_reference) {
-        qfns.from_float_reference(input_scratch, quantized_scratch, chunk_size);
+        qfns.from_float_ref(input_scratch, quantized_scratch, chunk_size);
     } else {
         qfns.from_float(input_scratch, quantized_scratch, chunk_size);
     }

View file

@@ -2006,6 +2006,11 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);

+        // track if this is an embedding or non-embedding batch
+        // if we've added sampled tokens above, we are in non-embedding mode
+        // -1: none, 0: non-embedding, 1: embedding
+        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
+
         // next, batch any pending prompts without exceeding n_batch
         if (params.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
@@ -2176,6 +2181,14 @@ struct server_context {
                     }
                 }

+                // check that we are in the right batch_type, if not defer the slot
+                bool slot_type = slot.embedding ? 1 : 0;
+                if (batch_type == -1) {
+                    batch_type = slot_type;
+                } else if (batch_type != slot_type) {
+                    continue;
+                }
+
                 // keep only the common part
                 int p0 = (int) system_tokens.size() + slot.n_past;
                 if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
@@ -2277,6 +2290,9 @@ struct server_context {
             {"n_tokens", batch.n_tokens},
         });

+        // make sure we're in the right embedding mode
+        llama_set_embeddings(ctx, batch_type == 1);
+
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -2991,6 +3007,11 @@ int main(int argc, char ** argv) {
     };

     const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

         json data = json::parse(req.body);
@@ -3086,6 +3107,11 @@ int main(int argc, char ** argv) {
     };

     const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
@@ -3158,6 +3184,11 @@ int main(int argc, char ** argv) {
     };

     const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

         json data = json::parse(req.body);
@@ -3244,13 +3275,8 @@ int main(int argc, char ** argv) {
         return res.set_content(data.dump(), "application/json; charset=utf-8");
     };

-    const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!params.embedding) {
-            res.status = 501;
-            res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
-            return;
-        }

         const json body = json::parse(req.body);
         bool is_openai = false;
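Taken together, these server hunks keep embedding and completion work out of the same llama_batch: the first slot added fixes the batch type, mismatched slots are deferred to a later batch, and llama_set_embeddings() is flipped once per batch before decoding. A condensed sketch of that gating logic (slot struct and function names simplified, not the real server types):

#include <vector>

struct Slot { bool embedding; };

// -1: batch type undecided, 0: non-embedding, 1: embedding
int pick_slots_for_batch(const std::vector<Slot> & slots, std::vector<int> & picked) {
    int batch_type = -1;
    for (int i = 0; i < (int) slots.size(); ++i) {
        const int slot_type = slots[i].embedding ? 1 : 0;
        if (batch_type == -1) {
            batch_type = slot_type;      // first slot decides the batch type
        } else if (batch_type != slot_type) {
            continue;                    // defer mismatched slots to a later batch
        }
        picked.push_back(i);
    }
    // caller would then do: llama_set_embeddings(ctx, batch_type == 1);
    return batch_type;
}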

View file

@@ -122,8 +122,26 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     for (size_t i = 0; i < messages.size(); ++i) {
         const auto & curr_msg = messages[i];

         std::string role = json_value(curr_msg, "role", std::string(""));
-        std::string content = json_value(curr_msg, "content", std::string(""));
+
+        std::string content;
+        if (curr_msg.contains("content")) {
+            if (curr_msg["content"].is_string()) {
+                content = curr_msg["content"].get<std::string>();
+            } else if (curr_msg["content"].is_array()) {
+                for (const auto & part : curr_msg["content"]) {
+                    if (part.contains("text")) {
+                        content += "\n" + part["text"].get<std::string>();
+                    }
+                }
+            } else {
+                throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
+            }
+        } else {
+            throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
+        }
+
         chat.push_back({role, content});
     }
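The new format_chat logic accepts both the plain-string and the OpenAI array-of-parts shape of "content". The same extraction as a standalone sketch with nlohmann::json (function name and error text are mine, shortened):

#include <nlohmann/json.hpp>
#include <stdexcept>
#include <string>

using json = nlohmann::json;

std::string extract_content(const json & msg) {
    if (!msg.contains("content")) {
        throw std::runtime_error("missing 'content'");
    }
    const json & c = msg["content"];
    if (c.is_string()) {
        return c.get<std::string>();               // "content": "hello"
    }
    if (c.is_array()) {
        std::string out;
        for (const auto & part : c) {              // "content": [{"type":"text","text":"..."}]
            if (part.contains("text")) {
                out += "\n" + part["text"].get<std::string>();
            }
        }
        return out;
    }
    throw std::runtime_error("invalid 'content' type");
}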

View file

@ -721,7 +721,7 @@ extern "C" {
GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor); GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor);
GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
GGML_API GGML_CALL int ggml_blck_size(enum ggml_type type); GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type);
GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
@ -2417,10 +2417,10 @@ extern "C" {
#endif #endif
typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
typedef void (*ggml_from_float_to_mat_t)
(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
const void * GGML_RESTRICT y, size_t by, int nrc); const void * GGML_RESTRICT y, size_t by, int nrc);
typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr,
int64_t k, int64_t bx);
typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
const void * GGML_RESTRICT y, int nr, int nc); const void * GGML_RESTRICT y, int nr, int nc);
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
@ -2428,18 +2428,18 @@ extern "C" {
typedef struct { typedef struct {
const char * type_name; const char * type_name;
int blck_size; int64_t blck_size;
int64_t blck_size_interleave; // interleave elements in blocks
size_t type_size; size_t type_size;
bool is_quantized; bool is_quantized;
ggml_to_float_t to_float; ggml_to_float_t to_float;
ggml_from_float_t from_float; ggml_from_float_t from_float;
ggml_from_float_t from_float_reference; ggml_from_float_t from_float_ref;
ggml_from_float_to_mat_t from_float_to_mat;
ggml_vec_dot_t vec_dot; ggml_vec_dot_t vec_dot;
enum ggml_type vec_dot_type; enum ggml_type vec_dot_type;
int64_t nrows; // number of rows to process simultaneously; int64_t nrows; // number of rows to process simultaneously
int64_t ncols; // number of columns to process simultaneously; int64_t ncols; // number of columns to process simultaneously
int64_t interleave_blcksize; // interleave elements in blocks of interleave_blcksize;
ggml_from_float_to_mat_t from_float_to_mat;
ggml_gemv_t gemv; ggml_gemv_t gemv;
ggml_gemm_t gemm; ggml_gemm_t gemm;
} ggml_type_traits_t; } ggml_type_traits_t;
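A side effect of widening blck_size to int64_t: row-size arithmetic stays in 64-bit for very large tensors. A small standalone illustration of the expression ggml_row_size() computes (the block/type sizes here are q4_0-like stand-ins, not the real ggml tables):

#include <cstdint>
#include <cstdio>

// row_size = ne / blck_size * type_size. With 32-bit factors the product
// can wrap once a row passes 2 GiB; with int64_t factors it is well-defined.
int64_t row_size(int64_t ne, int64_t blck_size, int64_t type_size) {
    return ne / blck_size * type_size;
}

int main() {
    const int64_t ne = 1LL << 34; // 2^34 elements in one (hypothetical) row
    printf("%lld bytes\n", (long long) row_size(ne, 32, 18)); // ~9.7e9 > 2^31
    return 0;
}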

View file

@@ -20,19 +20,19 @@

 // Functions to create the interleaved data layout formats

-// interleave 4 block_q4_0s in blocks of interleave_blcksize
+// interleave 4 block_q4_0s in blocks of blck_size_interleave
 // returns an interleaved block_q4_0x4
 // in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
-// first, then interleave quants from 4 block_q4_0s in blocks of interleave_blcksize
+// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
 //
 // - in                  : an array of block_q4_0 pointers
-// - interleave_blcksize : the block_q4_0 quants bytes are interleaved in blocks of
-//                         interleave_blcksize bytes
+// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of
+//                          blck_size_interleave bytes
 // - xor_mask            : the mask to convert the nibbles in block_q4_0 quants bytes
 //                         from bias offset form to pure sign form (this saves subtract
 //                         operations durin unpacking)
 //
-static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) {
+static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
     block_q4_0x4 out;

     for (int i = 0; i < 4; i++) {
@@ -40,9 +40,9 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_b
     }

     for (int i = 0; i < QK4_0 * 2; i++) {
-        int src_offset = (i / (4 * interleave_blcksize)) * interleave_blcksize;
-        int src_id = (i % (4 * interleave_blcksize)) / interleave_blcksize;
-        src_offset += (i % interleave_blcksize);
+        int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
+        int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);

         out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
     }
@@ -50,11 +50,11 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_b
     return out;
 }

-// interleave 8 block_q4_0s in blocks of interleave_blcksize
+// interleave 8 block_q4_0s in blocks of blck_size_interleave
 // returns an interleaved block_q4_0x8
 // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
-// first, then interleave quants from 8 block_q4_0s in blocks of interleave_blcksize
-static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) {
+// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
+static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
     block_q4_0x8 out;

     for (int i = 0; i < 8; i++) {
@@ -62,9 +62,9 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int interleave_b
     }

     for (int i = 0; i < QK4_0 * 4; i++) {
-        int src_offset = (i / (8 * interleave_blcksize)) * interleave_blcksize;
-        int src_id = (i % (8 * interleave_blcksize)) / interleave_blcksize;
-        src_offset += (i % interleave_blcksize);
+        int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
+        int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);

         out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
     }
@@ -135,7 +135,7 @@ void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k)
     }
 #else
     // scalar
-    const int interleave_blcksize = 4;
+    const int blck_size_interleave = 4;
     float srcv[4][QK8_0];
     float id[4];
@@ -155,12 +155,12 @@ void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k)
         }

         for (int j = 0; j < QK8_0 * 4; j++) {
-            int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize;
-            int src_id = (j % (4 * interleave_blcksize)) / interleave_blcksize;
-            src_offset += (j % interleave_blcksize);
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);

             float x0 = srcv[src_id][src_offset] * id[src_id];
-            y[i].qs[j] = roundf(x0);;
+            y[i].qs[j] = roundf(x0);
         }
     }
 #endif
@@ -253,7 +253,7 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k)
     }
 #else
     // scalar
-    const int interleave_blcksize = 8;
+    const int blck_size_interleave = 8;
     float srcv[4][QK8_0];
     float id[4];
@@ -273,26 +273,30 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k)
         }

         for (int j = 0; j < QK8_0 * 4; j++) {
-            int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize;
-            int src_id = (j % (4 * interleave_blcksize)) / interleave_blcksize;
-            src_offset += (j % interleave_blcksize);
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);

             float x0 = srcv[src_id][src_offset] * id[src_id];
-            y[i].qs[j] = roundf(x0);;
+            y[i].qs[j] = roundf(x0);
         }
     }
 #endif
 }

-void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t interleave_blcksize) {
+void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
     assert(nrow == 4);
     UNUSED(nrow);
-    if (interleave_blcksize == 4) quantize_q8_0_4x4(x, vy, n_per_row);
-    else if (interleave_blcksize == 8) quantize_q8_0_4x8(x, vy, n_per_row);
-    else assert(false);
+    if (blck_size_interleave == 4) {
+        quantize_q8_0_4x4(x, vy, n_per_row);
+    } else if (blck_size_interleave == 8) {
+        quantize_q8_0_4x8(x, vy, n_per_row);
+    } else {
+        assert(false);
+    }
 }

-static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int interleave_blcksize) {
+static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
     assert(n_per_row % QK4_0 == 0);
     const int nb = n_per_row / QK4_0;
@@ -311,15 +315,15 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds

     for (int64_t x = 0; x < nb; x++) {
         for (int i = 0; i < nrows_interleaved; i++ ) {
-            quantize_row_q4_0_reference(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
+            quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
         }

         if (nrows_interleaved == 8) {
-            *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, interleave_blcksize, 0x88);
+            *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88);
             out_ptr = (block_q4_0x8 *) out_ptr + 1;
         }
         else if (nrows_interleaved == 4) {
-            *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, interleave_blcksize, 0x88);
+            *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88);
             out_ptr = (block_q4_0x4 *) out_ptr + 1;
         }
     }
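The renamed blck_size_interleave drives the index math above: output byte i is copied from row src_id at src_offset. A tiny standalone check of that mapping for 4 rows and a 4-byte interleave (values chosen only to print the pattern):

#include <cstdio>

int main() {
    const int blck_size_interleave = 4;
    const int nrows = 4;
    // For each output byte i, recover which input row (src_id) and which
    // byte within that row (src_offset) it came from, as in the diff.
    for (int i = 0; i < 16; i++) {
        int src_offset = (i / (nrows * blck_size_interleave)) * blck_size_interleave;
        int src_id     = (i % (nrows * blck_size_interleave)) / blck_size_interleave;
        src_offset    += (i % blck_size_interleave);
        printf("out[%2d] <- row %d, byte %d\n", i, src_id, src_offset);
    }
    return 0; // bytes 0-3 from row 0, 4-7 from row 1, 8-11 from row 2, ...
}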

View file

@ -16,7 +16,7 @@ extern "C" {
void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t interleave_blcksize); void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);

View file

@@ -394,7 +394,7 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)

 // backend registry

-#define GGML_REG_MAX_BACKENDS 16
+#define GGML_REG_MAX_BACKENDS 64

 struct ggml_backend_reg {
     char name[128];

View file

@@ -8,11 +8,12 @@
 #  include <Accelerate/Accelerate.h>
 #elif defined(GGML_BLAS_USE_MKL)
 #  include <mkl.h>
+#elif defined(GGML_BLAS_USE_BLIS)
+#  include <blis.h>
+#elif defined(GGML_BLAS_USE_NVPL)
+#  include <nvpl_blas.h>
 #else
 #  include <cblas.h>
-#  ifdef BLIS_ENABLE_CBLAS
-#    include <blis.h>
-#  endif
 #endif

 struct ggml_backend_blas_context {
@@ -140,10 +141,14 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
     openblas_set_num_threads(ctx->n_threads);
 #endif

-#if defined(BLIS_ENABLE_CBLAS)
+#if defined(GGML_BLAS_USE_BLIS)
     bli_thread_set_num_threads(ctx->n_threads);
 #endif

+#if defined(GGML_BLAS_USE_NVPL)
+    nvpl_blas_set_num_threads(ctx->n_threads);
+#endif
+
     for (int64_t i13 = 0; i13 < ne13; i13++) {
         for (int64_t i12 = 0; i12 < ne12; i12++) {
             const int64_t i03 = i13/r3;
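After this change, one GGML_BLAS_USE_* macro selects both the vendor header and the matching thread-count call, instead of sniffing BLIS through BLIS_ENABLE_CBLAS. A compile-time dispatch sketch of the same pattern, with the vendor calls stubbed so it builds standalone (the printf bodies are mine, not real BLAS APIs):

#include <cstdio>

#if defined(GGML_BLAS_USE_BLIS)
static void set_blas_threads(int n) { printf("blis: %d threads\n", n); }
#elif defined(GGML_BLAS_USE_NVPL)
static void set_blas_threads(int n) { printf("nvpl: %d threads\n", n); }
#else
static void set_blas_threads(int n) { printf("cblas default: %d threads\n", n); }
#endif

int main() {
    set_blas_threads(8); // which branch runs is decided at compile time
    return 0;
}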

View file

@@ -104,7 +104,7 @@
 #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
-#define __trap abort
+#define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
 #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
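Defining __trap() as a do/while(0) statement lets it be called like a function ending in ';', and __builtin_unreachable() (a GCC/Clang builtin) tells the compiler control never flows past the abort. The same pattern in isolation (macro and function names are mine):

#include <cstdio>
#include <cstdlib>

// Function-like never-returning macro: do/while(0) makes it a single
// statement; __builtin_unreachable() silences "missing return" warnings.
#define my_trap() do { abort(); __builtin_unreachable(); } while (0)

int checked_div(int a, int b) {
    if (b == 0) {
        my_trap(); // no return needed after this point
    }
    return a / b;
}

int main() {
    printf("%d\n", checked_div(10, 2));
    return 0;
}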

View file

@@ -70,6 +70,10 @@ struct mma_int_A_I16K8 {
         }
 #endif // defined(INT8_MMA_AVAILABLE)
     }
+
+    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
+        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
+    }
 };

 struct mma_int_B_J8K4 {

File diff suppressed because it is too large.

View file

@@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }

-template <bool need_sum>
+template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
     const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {

-    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+    constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
+
+    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;

     if (ix0 >= kx0_padded) {
         return;
     }

+    const float4 * x4 = (const float4 *) x;
+
     const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;

     block_q8_1_mmq * y = (block_q8_1_mmq *) vy;

-    const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
+    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
     const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;               // block index in channel
     const int64_t iqs = ix0 % (4*QK8_1);                                        // quant index in block

-    const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
-    float amax = fabsf(xi);
-
-    amax = warp_reduce_max(amax);
-
-    float sum;
-    if (need_sum) {
-        sum = warp_reduce_sum(xi);
-    }
-
-    const float d = amax / 127;
-    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
-
-    y[ib].qs[iqs] = q;
-
-    if (iqs % QK8_1 != 0) {
-        return;
-    }
-
-    if (need_sum) {
-        y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
-    } else {
-        ((float *) y[ib].ds)[iqs/QK8_1] = d;
-    }
-}
+    // Load 4 floats per thread and calculate max. abs. value between them:
+    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    float amax = fabsf(xi.x);
+    amax = fmaxf(amax, fabsf(xi.y));
+    amax = fmaxf(amax, fabsf(xi.z));
+    amax = fmaxf(amax, fabsf(xi.w));
+
+    // Exchange max. abs. value between vals_per_scale/4 threads.
+#pragma unroll
+    for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+    }
+
+    float sum;
+    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+        sum = xi.x + xi.y + xi.z + xi.w;
+
+        // Exchange calculate sum across vals_per_sum/4 threads.
+#pragma unroll
+        for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+        }
+    }
+
+    const float d_inv = 127.0f / amax;
+    char4 q;
+    q.x = roundf(xi.x*d_inv);
+    q.y = roundf(xi.y*d_inv);
+    q.z = roundf(xi.z*d_inv);
+    q.w = roundf(xi.w*d_inv);
+
+    // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
+    char4 * yqs4 = (char4 *) y[ib].qs;
+    yqs4[iqs/4] = q;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+        if (iqs % 16 != 0 || iqs >= 96) {
+            return;
+        }
+
+        y[ib].d2s6[2 + iqs/16] = sum;
+
+        if (iqs % 64 != 0) {
+            return;
+        }
+
+        const float d = 1.0f / d_inv;
+
+        y[ib].d2s6[iqs/64] = d;
+
+        return;
+    }
+
+    if (iqs % 32 != 0) {
+        return;
+    }
+
+    const float d = 1.0f / d_inv;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+        y[ib].ds4[iqs/32] = make_half2(d, sum);
+    } else {
+        y[ib].d4[iqs/32]  = d;
+    }
+}
@@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda(
     GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);

-    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
     const dim3 num_blocks(block_num_x, kx1, channels);
-    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    if (mmq_need_sum(type_x)) {
-        quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
-    } else {
-        quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+    switch (mmq_get_q8_1_ds_layout(type_x)) {
+        case MMQ_Q8_1_DS_LAYOUT_D4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_DS4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_D2S6:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
     }
 }
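The rewritten kernel swaps a full-warp reduction for xor-shuffle "butterflies" over only vals_per_scale/4 lanes. A host-side C++ emulation of that butterfly max over an 8-lane group (lane values invented; __shfl_xor_sync itself is CUDA-only):

#include <algorithm>
#include <cstdio>

int main() {
    // After butterfly steps with masks 4, 2, 1, every lane of the 8-lane
    // group holds the group's maximum, mirroring
    // __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE) on the device.
    float lane[8] = {3, 7, 1, 9, 4, 2, 8, 5};
    for (int mask = 4; mask > 0; mask >>= 1) {
        float next[8];
        for (int l = 0; l < 8; l++) {
            next[l] = std::max(lane[l], lane[l ^ mask]); // partner lane = l ^ mask
        }
        std::copy(next, next + 8, lane);
    }
    printf("all lanes now hold %g\n", lane[0]); // prints 9
    return 0;
}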

View file

@@ -6,6 +6,10 @@
 #include <cstdint>

 #define CUDA_QUANTIZE_BLOCK_SIZE     256
+#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
+
+static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
+static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");

 typedef void (*quantize_cuda_t)(
     const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,

View file

@ -189,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
} }
#define VDR_Q2_K_Q8_1_MMVQ 1 #define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 2 #define VDR_Q2_K_Q8_1_MMQ 4
// contiguous v/x values // contiguous v/x values
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
return dm2f.x*sumf_d - dm2f.y*sumf_m; return dm2f.x*sumf_d - dm2f.y*sumf_m;
} }
// contiguous u/y values // contiguous v/x + u/y values
template <int ns8>
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) { const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
float sumf_d = 0.0f; float sumf = 0.0f;
float sumf_m = 0.0f; float sumf_d8 = 0.0f;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]); const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
int sumi_d = 0; int sumi_d0 = 0;
int sumi_m = 0;
const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
int sumi_d1 = 0;
const int vi0 = v[i0/(QI8_1/2)];
#pragma unroll #pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) { for (int i = i0; i < i0 + QI8_1/2; ++i) {
const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303; sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
sumi_d = ggml_cuda_dp4a(vi, u[i], sumi_d); // SIMD dot product }
sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m); sumf_d8 += dm2f0.x * sumi_d0;
#pragma unroll
for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
}
sumf_d8 += dm2f1.x * sumi_d1;
if (i0/QI8_1 < ns8) {
const float2 s8f = __half22float2(s8[i0/QI8_1]);
sumf -= dm2f0.y*s8f.x;
sumf -= dm2f1.y*s8f.y;
} else {
int sumi_m0 = 0;
#pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) {
sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
}
sumf_d8 -= dm2f0.y * sumi_m0;
int sumi_m1 = 0;
#pragma unroll
for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
}
sumf_d8 -= dm2f1.y * sumi_m1;
}
} }
sumf_d += dm2f.x * sumi_d; return sumf + d8*sumf_d8;
sumf_m += dm2f.y * sumi_m;
}
return d8*(sumf_d - sumf_m);
} }
#define VDR_Q3_K_Q8_1_MMVQ 1 #define VDR_Q3_K_Q8_1_MMVQ 1
@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
return d3 * sumf; return d3 * sumf;
} }
// contiguous u/y values // contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
const float & d3, const float & d8) { const float & d3, const float & d8) {
@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
#pragma unroll #pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) { for (int i = i0; i < i0 + QI8_1/2; ++i) {
const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404); sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
} }
sumi += sumi_sc * scales[i0 / (QI8_1/2)]; sumi += sumi_sc * scales[i0 / (QI8_1/2)];
@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
return dm4f.x*sumf_d - dm4f.y*sumf_m; return dm4f.x*sumf_d - dm4f.y*sumf_m;
} }
// contiguous u/y values // contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
return dm5f.x*sumf_d - dm5f.y*sumf_m; return dm5f.x*sumf_d - dm5f.y*sumf_m;
} }
// contiguous u/y values // contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
return d*sumf; return d*sumf;
} }
// contiguous u/y values // contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
const float & d6, const float * __restrict__ d8) { const float & d6, const float * __restrict__ d8) {
float sumf_d = 0.0f; float sumf_d = 0.0f;
const int sc_packed = get_int_b4(sc, 0);
const int8_t * sc_reg = (const int8_t *) &sc_packed;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
@ -471,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
} }
sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
} }
return d6 * sumf_d; return d6 * sumf_d;
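These kernels now feed pre-expanded v[i] words straight into ggml_cuda_dp4a. What a dp4a-style operation computes, as a scalar C++ reference (function name mine; packing assumed little-endian signed bytes):

#include <cstdint>
#include <cstdio>

// Treat each 32-bit word as four signed int8 values, multiply pairwise,
// and accumulate into c, like the dp4a instruction the CUDA helper wraps.
int dp4a_ref(int a, int b, int c) {
    const int8_t * a8 = (const int8_t *) &a;
    const int8_t * b8 = (const int8_t *) &b;
    for (int k = 0; k < 4; k++) {
        c += (int) a8[k] * (int) b8[k];
    }
    return c;
}

int main() {
    // 0x01020304 holds bytes {4,3,2,1} on little-endian; dotted with itself:
    printf("%d\n", dp4a_ref(0x01020304, 0x01020304, 0)); // 16+9+4+1 = 30
    return 0;
}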

View file

@@ -193,16 +193,16 @@ enum ggml_metal_kernel_type {
     //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
     //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
-    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
-    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
-    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -651,14 +651,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,   cpy_f32_f16,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,   cpy_f32_f32,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,   cpy_f16_f16,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,   cpy_f16_f32,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,  cpy_f32_q8_0,  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,  cpy_f32_q4_0,  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,  cpy_f32_q4_1,  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,  cpy_f32_q5_0,  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,  cpy_f32_q5_1,  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,   cpy_f16_f16,   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,   cpy_f16_f32,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             switch (op->src[0]->type) {
                 case GGML_TYPE_F32:
                     switch (op->type) {
-                        case GGML_TYPE_F16:
                         case GGML_TYPE_F32:
+                        case GGML_TYPE_F16:
                         case GGML_TYPE_Q8_0:
                         case GGML_TYPE_Q4_0:
                         case GGML_TYPE_Q4_1:
@@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
                     }
                 case GGML_TYPE_F16:
                     switch (op->type) {
-                        case GGML_TYPE_F16:
                         case GGML_TYPE_F32:
+                        case GGML_TYPE_F16:
                             return true;
                         default:
                             return false;
@@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
-                return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
+                return op->ne[3] == 1;
             }
         default:
             return false;
@@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute(
                         GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);

                         switch (dstt) {
-                            case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;  break;
                             case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;  break;
+                            case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;  break;
                             case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
                             case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
                             case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute(
                     case GGML_TYPE_F16:
                         {
                             switch (dstt) {
-                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
                                 case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
                                 default: GGML_ASSERT(false && "not implemented");
                             };
                         } break;

View file

@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32(
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
} }
#define N_F32_F32 4 #define N_MV_T_T 4
void kernel_mul_mv_f32_f32_impl( template<typename T0, typename T04, typename T1, typename T14>
void kernel_mul_mv_impl(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -1243,9 +1244,8 @@ void kernel_mul_mv_f32_f32_impl(
uint r3, uint r3,
uint3 tgpig, uint3 tgpig,
uint tiisg) { uint tiisg) {
const int64_t r0 = tgpig.x; const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F32_F32; const int64_t rb = tgpig.y*N_MV_T_T;
const int64_t im = tgpig.z; const int64_t im = tgpig.z;
const uint i12 = im%ne12; const uint i12 = im%ne12;
@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl(
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
device const float * x = (device const float *) (src0 + offset0); device const T0 * x = (device const T0 *) (src0 + offset0);
if (ne00 < 128) { if (ne00 < 128) {
for (int row = 0; row < N_F32_F32; ++row) { for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row; int r1 = rb + row;
if (r1 >= ne11) { if (r1 >= ne11) {
break; break;
} }
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
float sumf = 0; float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) { for (int i = tiisg; i < ne00; i += 32) {
sumf += (float) x[i] * (float) y[i]; sumf += (T0) x[i] * (T1) y[i];
} }
float all_sum = simd_sum(sumf); float all_sum = simd_sum(sumf);
@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl(
} }
} }
} else { } else {
device const float4 * x4 = (device const float4 *)x; device const T04 * x4 = (device const T04 *) x;
for (int row = 0; row < N_F32_F32; ++row) { for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row; int r1 = rb + row;
if (r1 >= ne11) { if (r1 >= ne11) {
break; break;
} }
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
device const float4 * y4 = (device const float4 *) y; device const T14 * y4 = (device const T14 *) y;
float sumf = 0; float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) { for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
} }
float all_sum = simd_sum(sumf); float all_sum = simd_sum(sumf);
if (tiisg == 0) { if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
} }
} }
} }
} }
[[host_name("kernel_mul_mv_f32_f32")]] template<typename T0, typename T04, typename T1, typename T14>
kernel void kernel_mul_mv_f32_f32( kernel void kernel_mul_mv(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32(
constant uint & r3, constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) { uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); kernel_mul_mv_impl<T0, T04, T1, T14>(
src0,
src1,
dst,
ne00,
ne01,
ne02,
nb00,
nb01,
nb02,
ne10,
ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
ne1,
r2,
r3,
tgpig,
tiisg);
} }
#define N_F16_F16 4 typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
kernel void kernel_mul_mv_f16_f16( template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
device const char * src0, template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
device const char * src1, template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x; template<typename T, typename T4>
const int64_t rb = tgpig.y*N_F16_F16; kernel void kernel_mul_mv_1row(
const int64_t im = tgpig.z;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
device const half * x = (device const half *) (src0 + offset0);
if (ne00 < 128) {
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (half) x[i] * (half) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
device const half4 * x4 = (device const half4 *)x;
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
device const half4 * y4 = (device const half4 *) y;
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}
void kernel_mul_mv_f16_f32_1row_impl(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl(
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
device const half * x = (device const half *) (src0 + offset0); device const T * x = (device const T *) (src0 + offset0);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
float sumf = 0; float sumf = 0;
@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl(
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
} }
} else { } else {
device const half4 * x4 = (device const half4 *) x; device const T4 * x4 = (device const T4 *) x;
device const float4 * y4 = (device const float4 *) y;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
[[host_name("kernel_mul_mv_f16_f32_1row")]]
kernel void kernel_mul_mv_f16_f32_1row(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
#define N_F16_F32 4
void kernel_mul_mv_f16_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
int64_t ne00,
int64_t ne01,
int64_t ne02,
uint64_t nb00,
uint64_t nb01,
uint64_t nb02,
int64_t ne10,
int64_t ne11,
int64_t ne12,
uint64_t nb10,
uint64_t nb11,
uint64_t nb12,
int64_t ne0,
int64_t ne1,
uint r2,
uint r3,
uint3 tgpig,
uint tiisg) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F16_F32;
const int64_t im = tgpig.z;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
device const half * x = (device const half *) (src0 + offset0);
if (ne00 < 128) {
for (int row = 0; row < N_F16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
device const half4 * x4 = (device const half4 *)x;
for (int row = 0; row < N_F16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
        device const float4 * y4 = (device const float4 *) y;
        float sumf = 0;
        for (int i = tiisg; i < ne00/4; i += 32) {
-           for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+           for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
        }
        float all_sum = simd_sum(sumf);
        if (tiisg == 0) {
-           for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+           for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
        }
    }
}
}
[[host_name("kernel_mul_mv_f16_f32")]] typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
kernel void kernel_mul_mv_f16_f32(
device const char * src0, template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
// Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mv_f16_f32_l4(
+template<typename T, typename T4>
+kernel void kernel_mul_mv_l4(
    device const char * src0,
    device const char * src1,
    device float * dst,
@@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4(
    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
-   device const half4 * x4 = (device const half4 *) (src0 + offset0);
+   device const T4    * x4 = (device const T4    *) (src0 + offset0);
    for (int r1 = 0; r1 < nrows; ++r1) {
        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
        float sumf = 0;
        for (int i = tiisg; i < ne00/4; i += 32) {
-           for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+           for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
        }
        float all_sum = simd_sum(sumf);
@@ -1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4(
    }
}

+typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
static float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float y = (i0 / 2 - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y));
@@ -2765,9 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
-kernel void kernel_cpy_f16_f16(
-    device const half * src0,
-    device half * dst,
+template<typename T0, typename T1>
+kernel void kernel_cpy(
+    device const void * src0,
+    device void * dst,
    constant int64_t & ne00,
    constant int64_t & ne01,
    constant int64_t & ne02,
@@ -2798,138 +2627,20 @@ kernel void kernel_cpy_f16_f16(
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-   device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+   device T1   * dst_data = (device T1   *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-       device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-       dst_data[i00] = src[0];
+       device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+       dst_data[i00] = (T1) src[0];
    }
}
+typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;

-kernel void kernel_cpy_f16_f32(
device const half * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
-   const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-   const int64_t i3 = n / (ne2*ne1*ne0);
-   const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
+template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
+template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
+template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
kernel void kernel_cpy_f32_f16(
device const float * src0,
device half * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
kernel void kernel_cpy_f32_f32(
device const float * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
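The cpy kernels all rest on the same index arithmetic: a source element (i03,i02,i01,i00) is flattened to a linear index n over the source extents, then n is peeled back into destination coordinates (i3,i2,i1,i0) using the destination extents. A plain C++ sketch of that decomposition (my own illustration, not code from the diff):

#include <cassert>
#include <cstdint>

struct Coord { int64_t i3, i2, i1, i0; };

static Coord split(int64_t n, int64_t ne2, int64_t ne1, int64_t ne0) {
    Coord c;
    c.i3 =  n / (ne2*ne1*ne0);
    c.i2 = (n - c.i3*ne2*ne1*ne0) / (ne1*ne0);
    c.i1 = (n - c.i3*ne2*ne1*ne0 - c.i2*ne1*ne0) / ne0;
    c.i0 =  n - c.i3*ne2*ne1*ne0 - c.i2*ne1*ne0 - c.i1*ne0;
    return c;
}

int main() {
    // with destination extents ne0=4, ne1=3, ne2=2, linear index 17 maps to:
    Coord c = split(/*n=*/17, /*ne2=*/2, /*ne1=*/3, /*ne0=*/4);
    assert(c.i3 == 0 && c.i2 == 1 && c.i1 == 1 && c.i0 == 1); // 0*24 + 1*12 + 1*4 + 1 = 17
}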
kernel void kernel_cpy_f32_q8_0(
    device const float * src0,
@@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
}

template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
-kernel void kernel_get_rows(
+kernel void kernel_get_rows_q(
    device const void * src0,
-   device const char * src1,
+   device const void * src1,
    device float * dst,
    constant int64_t & ne00,
    constant uint64_t & nb01,
@@ -5745,27 +5456,24 @@ kernel void kernel_get_rows(
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint tiitg[[thread_index_in_threadgroup]],
    uint3 tptg [[threads_per_threadgroup]]) {
-   //const int64_t i = tgpig;
-   //const int64_t r = ((device int32_t *) src1)[i];
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;
-   const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+   const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
    const int64_t i02 = i11;
    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
        float4x4 temp;
-       dequantize_func(
-           ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+       dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
    }
}
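What the get_rows kernels compute is a row gather: src1 holds int32 row ids, and each output row is the dequantized copy of the selected src0 row (the quantized variant dequantizes 16-wide blocks into float4x4 tiles; the float variant copies directly). A minimal CPU sketch of the float case, with illustrative names:

#include <cstdint>
#include <vector>

static std::vector<float> get_rows_f(const std::vector<float> & src0, int64_t ne00,
                                     const std::vector<int32_t> & rows) {
    std::vector<float> dst(rows.size() * ne00);
    for (size_t j = 0; j < rows.size(); ++j) {
        const float * src_row = src0.data() + rows[j]*ne00; // the r*nb01 offset, in elements
        for (int64_t i = 0; i < ne00; ++i) {
            dst[j*ne00 + i] = src_row[i];
        }
    }
    return dst;
}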
-kernel void kernel_get_rows_f32(
+template<typename T>
+kernel void kernel_get_rows_f(
    device const void * src0,
-   device const char * src1,
+   device const void * src1,
    device float * dst,
    constant int64_t & ne00,
    constant uint64_t & nb01,
@@ -5781,47 +5489,19 @@ kernel void kernel_get_rows_f32(
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;
-   const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+   const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
    const int64_t i02 = i11;
    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
        (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
-           ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+           ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
    }
}
kernel void kernel_get_rows_f16(
device const void * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1,
constant uint64_t & nb2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
const int64_t i02 = i11;
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
    }
}
kernel void kernel_get_rows_i32(
    device const void * src0,
-   device const char * src1,
+   device const void * src1,
    device int32_t * dst,
    constant int64_t & ne00,
    constant uint64_t & nb01,
@@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32(
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;
-   const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+   const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
    const int64_t i02 = i11;
    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
        (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
-           ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+           ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
    }
}
@@ -5860,8 +5540,8 @@ kernel void kernel_get_rows_i32(
#define SG_MAT_ROW 8

// each block_q contains 16*nl weights
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-void kernel_mul_mm_impl(device const uchar * src0,
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_mul_mm(device const uchar * src0,
    device const uchar * src1,
    device float * dst,
    constant int64_t & ne00,
@@ -5881,7 +5561,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
    uint tiitg[[thread_index_in_threadgroup]],
    uint sgitg[[simdgroup_index_in_threadgroup]]) {

-   threadgroup half  * sa = (threadgroup half  *)(shared_memory);
+   threadgroup T     * sa = (threadgroup T     *)(shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
@@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

-   simdgroup_half8x8  ma[4];
+   simdgroup_T8x8     ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
@@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        // load data and store to threadgroup memory
-       half4x4 temp_a;
+       T4x4 temp_a;
        dequantize_func(x, il, temp_a);
        threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // load matrices from threadgroup memory and conduct outer products
-       threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+       threadgroup T     * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

        #pragma unroll(4)
@@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl(
    }
}
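For orientation: kernel_mul_mm stages BLOCK_SIZE_K-deep tiles of the dequantized A matrix (as T) and the float B matrix in threadgroup memory (sa/sb above), then accumulates 8x8 simdgroup outer products into c_res. A scalar C++ sketch of that blocking idea, with illustrative sizes rather than the kernel's actual constants:

#include <vector>

static void mul_mm_tiled(const std::vector<float> & A,   // M x K, row-major
                         const std::vector<float> & B,   // K x N, row-major
                         std::vector<float> & C,         // M x N, row-major
                         int M, int N, int K, int BK = 32) {
    for (int k0 = 0; k0 < K; k0 += BK) {              // loop_k in the kernel
        const int kend = k0 + BK < K ? k0 + BK : K;
        for (int i = 0; i < M; ++i) {
            for (int j = 0; j < N; ++j) {
                float acc = C[i*N + j];               // the c_res accumulator
                for (int k = k0; k < kend; ++k) {
                    acc += A[i*K + k] * B[k*N + j];   // done as 8x8 outer products on GPU
                }
                C[i*N + j] = acc;
            }
        }
    }
}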
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const uchar * src0,
device const uchar * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
src0,
src1,
dst,
ne00,
ne02,
nb01,
nb02,
ne12,
nb10,
nb11,
nb12,
ne0,
ne1,
r2,
r3,
shared_memory,
tgpig,
tiitg,
sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)> template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id( kernel void kernel_mul_mm_id(
device const uchar * src0s, device const uchar * src0s,
@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id(
// get rows // get rows
// //
typedef void (get_rows_t)( typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
device const void * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1,
constant uint64_t & nb2,
uint3, uint, uint3);
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>; template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>; typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>; template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>; template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>; template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>; template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>; template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>; template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
//
// matrix-matrix multiplication
//

-typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;

-template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<float4x4,      1,     dequantize_f32>;
-template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half4x4,       1,     dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_0,    2,     dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_1,    2,     dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_0,    2,     dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_1,    2,     dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q8_0,    2,     dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q2_K,    QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q3_K,    QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_K,    QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_K,    QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q6_K,    QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq3_s,   QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq2_s,   QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_m,   QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2,     dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4,      1,     dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4,       1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0,    2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1,    2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0,    2,     dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1,    2,     dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0,    2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K,    QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K,    QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K,    QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K,    QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K,    QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s,   QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s,   QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m,   QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
//
// indirect matrix-matrix multiplication
//
@@ -6436,7 +6065,7 @@ void mmv_fn(
    impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
}

-typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
+typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;

template<mul_mv_impl_fn_t impl_fn>
kernel void kernel_mul_mv_id(
@@ -6514,10 +6143,10 @@ kernel void kernel_mul_mv_id(
        sgitg);
}

-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;

-template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
View file
@@ -659,7 +659,7 @@ static inline __m128i packNibbles( __m256i bytes ) {
#endif //__loongarch_asx

// reference implementation for deterministic creation of model files
-void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
+void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
    static const int qk = QK4_0;

    assert(k % qk == 0);
@@ -697,11 +697,11 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict
}

void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
-   quantize_row_q4_0_reference(x, y, k);
+   quantize_row_q4_0_ref(x, y, k);
}

-void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
+void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
    const int qk = QK4_1;

    assert(k % qk == 0);
@@ -739,10 +739,10 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict
}

void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) {
-   quantize_row_q4_1_reference(x, y, k);
+   quantize_row_q4_1_ref(x, y, k);
}

-void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
+void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
    static const int qk = QK5_0;

    assert(k % qk == 0);
@@ -787,10 +787,10 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict
}

void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) {
-   quantize_row_q5_0_reference(x, y, k);
+   quantize_row_q5_0_ref(x, y, k);
}

-void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
+void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
    const int qk = QK5_1;

    assert(k % qk == 0);
@@ -835,11 +835,11 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict
}

void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
-   quantize_row_q5_1_reference(x, y, k);
+   quantize_row_q5_1_ref(x, y, k);
}

// reference implementation for deterministic creation of model files
-void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
+void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;
@@ -1145,12 +1145,12 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
#else
    GGML_UNUSED(nb);
    // scalar
-   quantize_row_q8_0_reference(x, y, k);
+   quantize_row_q8_0_ref(x, y, k);
#endif
}
// reference implementation for deterministic creation of model files
-void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
+void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
    assert(QK8_1 == 32);
    assert(k % QK8_1 == 0);
    const int nb = k / QK8_1;
@@ -1509,7 +1509,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
#else
    GGML_UNUSED(nb);
    // scalar
-   quantize_row_q8_1_reference(x, y, k);
+   quantize_row_q8_1_ref(x, y, k);
#endif
}
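Everything in this file is the mechanical `_reference` -> `_ref` rename: the public quantize_row_* wrappers dispatch to the reference path when no SIMD implementation (or no importance weights) applies. For readers unfamiliar with what the reference path computes, here is a hedged C++ sketch of a q8_0-style block quantizer (my own simplified rendition, not the ggml source; ggml stores the scale as fp16, field names here are illustrative):

#include <cmath>
#include <cstdint>

struct block_q8_0_sketch {
    float  d;        // per-block scale; float here for simplicity
    int8_t qs[32];   // quantized values
};

static void quantize_row_q8_0_sketch(const float * x, block_q8_0_sketch * y, int64_t k) {
    const int64_t nb = k / 32;          // k must be a multiple of the block size
    for (int64_t i = 0; i < nb; ++i) {
        float amax = 0.0f;              // absolute max of the block
        for (int j = 0; j < 32; ++j) amax = fmaxf(amax, fabsf(x[i*32 + j]));
        const float d  = amax / 127.0f; // map absmax onto the int8 range
        const float id = d ? 1.0f/d : 0.0f;
        y[i].d = d;
        for (int j = 0; j < 32; ++j) y[i].qs[j] = (int8_t) roundf(x[i*32 + j]*id);
    }
}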
@@ -1900,7 +1900,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
//========================- 2-bit (de)-quantization

-void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) {
+void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;
@@ -2003,7 +2003,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6
}

void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) {
-   quantize_row_q2_K_reference(x, vy, k);
+   quantize_row_q2_K_ref(x, vy, k);
}
static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
@@ -2227,7 +2227,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
    if (!quant_weights) {
-       quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row);
+       quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row);
    }
    else {
        char * qrow = (char *)dst;
@@ -2242,7 +2242,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nr
//========================= 3-bit (de)-quantization

-void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) {
+void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;
@@ -2369,7 +2369,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6
}

void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
-   quantize_row_q3_K_reference(x, vy, k);
+   quantize_row_q3_K_ref(x, vy, k);
}
static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {
@@ -2459,7 +2459,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
    if (!quant_weights) {
-       quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row);
+       quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row);
    }
    else {
        char * qrow = (char *)dst;
@@ -2474,7 +2474,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr
// ====================== 4-bit (de)-quantization

-void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) {
+void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;
@@ -2573,7 +2573,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6
void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK_K == 0);
    block_q4_K * restrict y = vy;
-   quantize_row_q4_K_reference(x, y, k);
+   quantize_row_q4_K_ref(x, y, k);
}
static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) {
@@ -2652,7 +2652,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
    if (!quant_weights) {
-       quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row);
+       quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row);
    }
    else {
        char * qrow = (char *)dst;
@@ -2667,7 +2667,7 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nr
// ====================== 5-bit (de)-quantization

-void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) {
+void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;
@@ -2784,7 +2784,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6
void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK_K == 0);
    block_q5_K * restrict y = vy;
-   quantize_row_q5_K_reference(x, y, k);
+   quantize_row_q5_K_ref(x, y, k);
}
static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) {
@@ -2883,7 +2883,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
    if (!quant_weights) {
-       quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row);
+       quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row);
    }
    else {
        char * qrow = (char *)dst;
@@ -2898,7 +2898,7 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nr
// ====================== 6-bit (de)-quantization

-void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) {
+void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;
@@ -3002,7 +3002,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6
void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK_K == 0);
    block_q6_K * restrict y = vy;
-   quantize_row_q6_K_reference(x, y, k);
+   quantize_row_q6_K_ref(x, y, k);
}
static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {
@@ -3092,7 +3092,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
    if (!quant_weights) {
-       quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row);
+       quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row);
    }
    else {
        char * qrow = (char *)dst;
@@ -3109,7 +3109,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
    static_assert(QK4_0 == 32, "QK4_0 must be 32");

    if (!quant_weights) {
-       quantize_row_q4_0_reference(x, y, n_per_row);
+       quantize_row_q4_0_ref(x, y, n_per_row);
        return;
    }
@@ -3135,7 +3135,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    if (!quant_weights) {
-       quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row);
+       quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row);
        return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
    }
    size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
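For the no-quant_weights path above, the returned size is pure block arithmetic. As a sanity check of what ggml_row_size() works out to for Q4_0 (using the commonly documented q4_0 layout of one fp16 scale plus 16 bytes of packed nibbles per 32 values; my arithmetic, not code from this diff):

#include <cassert>
#include <cstddef>

// per 32-value block: one fp16 scale (2 bytes) + 32 4-bit quants (16 bytes)
static size_t q4_0_row_size(size_t n_per_row) {
    assert(n_per_row % 32 == 0);              // rows must be whole blocks
    return n_per_row/32 * (2 + 32/2);         // 18 bytes per block
}

int main() {
    assert(q4_0_row_size(4096) == 2304);      // 128 blocks * 18 B; vs 16384 B as fp32, ~7.1x smaller
}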
@@ -3152,7 +3152,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
    static_assert(QK4_1 == 32, "QK4_1 must be 32");

    if (!quant_weights) {
-       quantize_row_q4_1_reference(x, y, n_per_row);
+       quantize_row_q4_1_ref(x, y, n_per_row);
        return;
    }
@@ -3180,7 +3180,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    if (!quant_weights) {
-       quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row);
+       quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row);
        return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
    }
    size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
@@ -3197,7 +3197,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
    static_assert(QK5_0 == 32, "QK5_0 must be 32");

    if (!quant_weights) {
-       quantize_row_q5_0_reference(x, y, n_per_row);
+       quantize_row_q5_0_ref(x, y, n_per_row);
        return;
    }
@@ -3234,7 +3234,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    if (!quant_weights) {
-       quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row);
+       quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row);
        return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
    }
    size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
@@ -3251,7 +3251,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
    static_assert(QK5_1 == 32, "QK5_1 must be 32");

    if (!quant_weights) {
-       quantize_row_q5_1_reference(x, y, n_per_row);
+       quantize_row_q5_1_ref(x, y, n_per_row);
        return;
    }
@@ -3287,7 +3287,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    if (!quant_weights) {
-       quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row);
+       quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row);
        return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
    }
    size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
@@ -3303,7 +3303,7 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr
size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    (void)quant_weights; // not used
    const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
-   quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row);
+   quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row);
    return nrow * row_size;
}
@@ -3591,7 +3591,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y,
//===================================== Q8_K ==============================================

-void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) {
+void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;
@@ -3642,7 +3642,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6
}

void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
-   quantize_row_q8_K_reference(x, y, k);
+   quantize_row_q8_K_ref(x, y, k);
}

//===================================== Dot ptoducts =================================
@@ -13531,10 +13531,10 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t
void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK_K == 0);
    block_iq3_xxs * restrict y = vy;
-   quantize_row_iq3_xxs_reference(x, y, k);
+   quantize_row_iq3_xxs_ref(x, y, k);
}
-void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
+void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
}
@@ -13747,10 +13747,10 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n
void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK_K == 0);
    block_iq3_s * restrict y = vy;
-   quantize_row_iq3_s_reference(x, y, k);
+   quantize_row_iq3_s_ref(x, y, k);
}
-void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
+void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    quantize_iq3_s(x, y, 1, k, NULL);
}
@@ -14488,7 +14488,7 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k
    }
}

-void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
+void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
    assert(k % QK4_NL == 0);
    quantize_row_iq4_nl(x, y, k);
}
@@ -14516,10 +14516,10 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t
void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK_K == 0);
    block_iq4_xs * restrict y = vy;
-   quantize_row_iq4_xs_reference(x, y, k);
+   quantize_row_iq4_xs_ref(x, y, k);
}

-void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
+void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    quantize_iq4_xs(x, y, 1, k, NULL);
}
@@ -14706,7 +14706,7 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t n
    return nrow * nblock * sizeof(block_iq2_s);
}

-void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
+void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    quantize_iq2_s(x, y, 1, k, NULL);
}
@@ -14714,7 +14714,7 @@ void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restri
void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK_K == 0);
    block_iq2_s * restrict y = vy;
-   quantize_row_iq2_s_reference(x, y, k);
+   quantize_row_iq2_s_ref(x, y, k);
}

static bool validate_float(float f, size_t i) {
View file
@@ -12,25 +12,25 @@ extern "C" {
#endif

// Quantization
-void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
-void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
-void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
-void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
-void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
-void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
-void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq3_s_reference  (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq2_s_reference  (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_s_ref  (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_s_ref  (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
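
The (ggml_from_float_t) casts applied to these _ref functions elsewhere in this change exist because the declarations above take a typed block pointer, while the traits table stores a type-erased signature. Roughly, with the typedef taken from ggml's headers and shown here only for orientation:

    typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

    void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); // typed
    void quantize_row_q4_0    (const float * GGML_RESTRICT x, void       * GGML_RESTRICT y, int64_t k); // erased

Strictly speaking, calling through such a cast pointer is undefined behavior in ISO C, but the two signatures are ABI-compatible on the platforms ggml targets, which is why the existing code gets away with it.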


@@ -3768,37 +3768,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
             stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
         SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));

-        const ggml_tensor_extra_gpu *src0_extra =
-            (const ggml_tensor_extra_gpu *)src0->extra;
-        const ggml_tensor_extra_gpu *src1_extra =
-            (const ggml_tensor_extra_gpu *)src1->extra;
-        const ggml_tensor_extra_gpu *dst_extra =
-            (const ggml_tensor_extra_gpu *)dst->extra;
-
-        ggml_tensor_extra_gpu src0_row_extra;
-        ggml_tensor_extra_gpu src1_row_extra;
-        ggml_tensor_extra_gpu dst_row_extra;
-
         ggml_tensor src0_row = *src0;
         ggml_tensor src1_row = *src1;
         ggml_tensor dst_row = *dst;

-        src1_row.backend = GGML_BACKEND_TYPE_GPU;
-        dst_row.backend = GGML_BACKEND_TYPE_GPU;
-
-        src0_row.extra = &src0_row_extra;
-        src1_row.extra = &src1_row_extra;
-        dst_row.extra = &dst_row_extra;
-
-        char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
-                                  ? (char *)src0->data
-                                  : (char *)src0_extra->data_device[ctx.device];
-        char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
-                                  ? (char *)src1->data
-                                  : (char *)src1_extra->data_device[ctx.device];
-        char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
-                                  ? (char *)dst->data
-                                  : (char *)dst_extra->data_device[ctx.device];
+        char *src0_original = (char *)src0->data;
+        char *src1_original = (char *)src1->data;
+        char *dst_original = (char *)dst->data;

         src0_row.ne[2] = 1;
         src0_row.ne[3] = 1;
@@ -3827,12 +3803,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
                 const int64_t i1 = id;
                 const int64_t i2 = i12;

-                src0_row_extra.data_device[ctx.device] =
-                    src0_original + i02*nb02;
-                src1_row_extra.data_device[ctx.device] =
-                    src1_original + + i11*nb11 + i12*nb12;
-                dst_row_extra.data_device[ctx.device] =
-                    dst_original + i1*nb1 + i2*nb2;
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + + i11*nb11 + i12*nb12;
+                dst_row.data = dst_original + i1*nb1 + i2*nb2;

                 ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
             }
@@ -3841,8 +3814,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
         ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));

-        src1_row_extra.data_device[ctx.device] = src1_contiguous.get();
-        dst_row_extra.data_device[ctx.device] = dst_contiguous.get();
+        src1_row.data = src1_contiguous.get();
+        dst_row.data = dst_contiguous.get();

         for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
@@ -3898,7 +3871,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
                 });
             }

-            src0_row_extra.data_device[ctx.device] = src0_original + i02*nb02;
+            src0_row.data = src0_original + i02*nb02;

             GGML_ASSERT(nb11 == sizeof(float)*ne10);
             GGML_ASSERT(nb1 == sizeof(float)*ne0);
@@ -5221,6 +5194,10 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
                     return false;
                 }
             }
+            ggml_type src0_type = op->src[0]->type;
+            if (src0_type == GGML_TYPE_BF16) {
+                return false;
+            }
             return true;
         } break;
     case GGML_OP_GET_ROWS:
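
Net effect of the mul_mat_id hunks above: per-expert slices are now addressed by repointing tensor->data on shallow copies, instead of routing through per-backend ggml_tensor_extra_gpu side structures. A condensed sketch of the resulting pattern, with variable names following the hunks; this is orientation code, not a drop-in excerpt:

    // Shallow-copy once: shapes and strides are reused, and only the data
    // pointer is patched per expert inside the loop.
    struct ggml_tensor src0_row = *src0;
    src0_row.ne[2] = 1; // view a single expert slice
    src0_row.ne[3] = 1;

    char * src0_original = (char *) src0->data;
    for (int64_t i02 = 0; i02 < n_as; i02++) {
        src0_row.data = src0_original + i02*nb02; // nb02 == src0->nb[2]
        // ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
    }

The added GGML_TYPE_BF16 check in ggml_backend_sycl_supports_op simply declines matrix multiplication on bf16 weights until the SYCL kernels handle that type.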


@@ -600,7 +600,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = false,
         .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
         .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
-        .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+        .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row,
         .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
         .vec_dot_type = GGML_TYPE_F16,
         .nrows = 1,
@@ -612,7 +612,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q4_0,
         .from_float = quantize_row_q4_0,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref,
         .vec_dot = ggml_vec_dot_q4_0_q8_0,
         .vec_dot_type = GGML_TYPE_Q8_0,
 #if defined (__ARM_FEATURE_MATMUL_INT8)
@@ -628,7 +628,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q4_1,
         .from_float = quantize_row_q4_1,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
         .vec_dot = ggml_vec_dot_q4_1_q8_1,
         .vec_dot_type = GGML_TYPE_Q8_1,
 #if defined (__ARM_FEATURE_MATMUL_INT8)
@@ -644,7 +644,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = false,
         .to_float = NULL,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = NULL,
         .vec_dot_type = GGML_TYPE_COUNT,
         .nrows = 1,
@@ -656,7 +656,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = false,
         .to_float = NULL,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = NULL,
         .vec_dot_type = GGML_TYPE_COUNT,
         .nrows = 1,
@@ -668,7 +668,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q5_0,
         .from_float = quantize_row_q5_0,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref,
         .vec_dot = ggml_vec_dot_q5_0_q8_0,
         .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
@@ -680,7 +680,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q5_1,
         .from_float = quantize_row_q5_1,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref,
         .vec_dot = ggml_vec_dot_q5_1_q8_1,
         .vec_dot_type = GGML_TYPE_Q8_1,
         .nrows = 1,
@@ -692,7 +692,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q8_0,
         .from_float = quantize_row_q8_0,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref,
+        .from_float_to_mat = quantize_mat_q8_0,
         .vec_dot = ggml_vec_dot_q8_0_q8_0,
         .vec_dot_type = GGML_TYPE_Q8_0,
 #if defined (__ARM_FEATURE_MATMUL_INT8)
@@ -700,7 +701,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
 #else
         .nrows = 1,
 #endif
-        .from_float_to_mat = quantize_mat_q8_0,
     },
     [GGML_TYPE_Q8_1] = {
         .type_name = "q8_1",
@@ -708,7 +708,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .type_size = sizeof(block_q8_1),
         .is_quantized = true,
         .from_float = quantize_row_q8_1,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
         .vec_dot_type = GGML_TYPE_Q8_1,
         .nrows = 1,
     },
@@ -719,7 +719,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q2_K,
         .from_float = quantize_row_q2_K,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref,
         .vec_dot = ggml_vec_dot_q2_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -731,7 +731,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q3_K,
         .from_float = quantize_row_q3_K,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref,
         .vec_dot = ggml_vec_dot_q3_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -743,7 +743,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q4_K,
         .from_float = quantize_row_q4_K,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref,
         .vec_dot = ggml_vec_dot_q4_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -755,7 +755,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q5_K,
         .from_float = quantize_row_q5_K,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref,
         .vec_dot = ggml_vec_dot_q5_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -767,7 +767,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q6_K,
         .from_float = quantize_row_q6_K,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref,
         .vec_dot = ggml_vec_dot_q6_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -779,7 +779,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -791,7 +791,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq2_xs,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = ggml_vec_dot_iq2_xs_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -803,7 +803,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs,
         .from_float = quantize_row_iq3_xxs,
-        .from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref,
         .vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -815,7 +815,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq3_s,
         .from_float = quantize_row_iq3_s,
-        .from_float_reference = (ggml_from_float_t)quantize_row_iq3_s_reference,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref,
         .vec_dot = ggml_vec_dot_iq3_s_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -827,7 +827,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq2_s,
         .from_float = quantize_row_iq2_s,
-        .from_float_reference = (ggml_from_float_t)quantize_row_iq2_s_reference,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_ref,
         .vec_dot = ggml_vec_dot_iq2_s_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -839,7 +839,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq1_s,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = ggml_vec_dot_iq1_s_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -851,7 +851,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq1_m,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = ggml_vec_dot_iq1_m_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -863,7 +863,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq4_nl,
         .from_float = quantize_row_iq4_nl,
-        .from_float_reference = (ggml_from_float_t)quantize_row_iq4_nl_reference,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
         .vec_dot = ggml_vec_dot_iq4_nl_q8_0,
         .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
@@ -875,7 +875,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
         .from_float = quantize_row_iq4_xs,
-        .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_ref,
         .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -894,7 +894,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = false,
         .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
         .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
-        .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row,
+        .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row,
         .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
         .vec_dot_type = GGML_TYPE_BF16,
         .nrows = 1,
@@ -902,48 +902,48 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0_4_4] = {
         .type_name = "q4_0_4x4",
         .blck_size = QK4_0,
+        .blck_size_interleave = 4,
         .type_size = sizeof(block_q4_0),
         .is_quantized = true,
         .to_float = NULL,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = NULL,
         .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
         .ncols = 4,
-        .interleave_blcksize = 4,
         .gemv = ggml_gemv_q4_0_4x4_q8_0,
         .gemm = ggml_gemm_q4_0_4x4_q8_0,
     },
     [GGML_TYPE_Q4_0_4_8] = {
         .type_name = "q4_0_4x8",
         .blck_size = QK4_0,
+        .blck_size_interleave = 8,
         .type_size = sizeof(block_q4_0),
         .is_quantized = true,
         .to_float = NULL,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = NULL,
         .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
         .ncols = 4,
-        .interleave_blcksize = 8,
         .gemv = ggml_gemv_q4_0_4x8_q8_0,
         .gemm = ggml_gemm_q4_0_4x8_q8_0,
     },
     [GGML_TYPE_Q4_0_8_8] = {
         .type_name = "q4_0_8x8",
         .blck_size = QK4_0,
+        .blck_size_interleave = 8,
         .type_size = sizeof(block_q4_0),
         .is_quantized = true,
         .to_float = NULL,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = NULL,
         .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
         .ncols = 8,
-        .interleave_blcksize = 8,
         .gemv = ggml_gemv_q4_0_8x8_q8_0,
         .gemm = ggml_gemm_q4_0_8x8_q8_0,
     }
@@ -3135,7 +3135,7 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
     return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);
 }

-GGML_CALL int ggml_blck_size(enum ggml_type type) {
+GGML_CALL int64_t ggml_blck_size(enum ggml_type type) {
     return type_traits[type].blck_size;
 }
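
Widening ggml_blck_size to int64_t keeps block arithmetic in 64-bit end to end, and it is also what forces the PRId64 format-string fix in the gguf_init_from_file hunk further down. A sketch of the affected call-site shapes (the helper names here are illustrative):

    #include <inttypes.h>
    #include <stdio.h>
    #include "ggml.h"

    // Element-count arithmetic stays in int64_t, with no casts or narrowing.
    static int64_t n_blocks(enum ggml_type type, int64_t ne) {
        return ne / ggml_blck_size(type);
    }

    static void print_blck_size(enum ggml_type type) {
        // %d no longer matches the return type; PRId64 does.
        printf("block size: %" PRId64 "\n", ggml_blck_size(type));
    }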
@@ -12239,12 +12239,11 @@ static void ggml_compute_forward_mul_mat(
     const enum ggml_type type = src0->type;

     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
-    ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
+    ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
+    ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat;
     int64_t const vec_dot_num_rows = type_traits[type].nrows;
     int64_t const matmul_num_cols = type_traits[type].ncols;
-    int64_t const interleave_blcksize = type_traits[type].interleave_blcksize;
-    ggml_from_float_to_mat_t const from_float_to_mat
-        = type_traits[vec_dot_type].from_float_to_mat;
+    int64_t const blck_size_interleave = type_traits[type].blck_size_interleave;
     ggml_gemv_t const gemv = type_traits[type].gemv;
     ggml_gemm_t const gemm = type_traits[type].gemm;
@@ -12318,12 +12317,12 @@ UseGgmlGemm1:;
                 for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
                     from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                                       (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
-                                      4, ne10, interleave_blcksize);
+                                      4, ne10, blck_size_interleave);
                 }
                 i11_processed = ne11 - ne11 % 4;
             }
             for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
-                from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                            (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
                            ne10);
             }
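
For orientation, the two conversion hooks used above have roughly these shapes (parameter names here are illustrative, not copied from the tree):

    typedef void (*ggml_from_float_t)(const float * x, void * y, int64_t k);
    typedef void (*ggml_from_float_to_mat_t)(const float * x, void * y,
                                             int64_t nrows, int64_t k,
                                             int64_t blck_size_interleave);

from_float quantizes one src1 row into vec_dot_type; from_float_to_mat quantizes four rows at a time into the interleaved layout selected by blck_size_interleave, and the i11_processed remainder loop falls back to the row-by-row path.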
@@ -12469,7 +12468,7 @@ static void ggml_compute_forward_mul_mat_id(
     ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
-    ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
+    ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
     int64_t const matmul_num_cols = type_traits[type].ncols;
     ggml_gemv_t const gemv = type_traits[type].gemv;
@@ -12512,7 +12511,7 @@ static void ggml_compute_forward_mul_mat_id(
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
-                    from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                                (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
                                ne10);
                 }
@@ -21166,7 +21165,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                 (int64_t) info->ne[3];

             if (ne % ggml_blck_size(info->type) != 0) {
-                fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n",
+                fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
                         __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
                 fclose(file);
                 gguf_free(ctx);


@@ -5955,13 +5955,6 @@ static bool llm_load_tensors(
     auto & hparams = model.hparams;

-#ifdef GGML_USE_SYCL
-    // disable MoE with SYCL until mul_mat_id is updated
-    if (hparams.n_expert > 0) {
-        n_gpu_layers = 0;
-    }
-#endif
-
     model.split_mode = split_mode;
     model.main_gpu = main_gpu;
     model.n_gpu_layers = n_gpu_layers;
@@ -21500,7 +21493,7 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
             size--;
         }
         if (length < (int32_t)size) {
-            return (int32_t) -size;
+            return -(int32_t) size;
         }
         memcpy(buf, token, size);
         return (int32_t) size;
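
The llama_token_to_piece change is more than cosmetic. With size being an unsigned size_t, the old form negates in unsigned arithmetic first (wrapping modulo 2^64) and then relies on implementation-defined narrowing to int32_t; the new form narrows first and negates in signed arithmetic, which is well-defined whenever the length fits. A standalone illustration, assuming a typical LP64 platform:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        size_t size = 5;
        // old order: -size wraps to SIZE_MAX - 4 before the cast; the
        // narrowing to int32_t is implementation-defined (it happens to
        // yield -5 on common ABIs, but only by truncation)
        printf("%zu\n", -size);          // 18446744073709551611 on LP64
        printf("%d\n", (int32_t) -size);
        // new order: narrow first, then negate, well-defined when in range
        printf("%d\n", -(int32_t) size); // -5
        return 0;
    }

Callers already treat a negative return as "buffer too small, required length is the magnitude", so the observable API behavior is unchanged on mainstream platforms.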