diff --git a/common/arg.cpp b/common/arg.cpp index 95f2279e4..cf4fe4e46 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -229,12 +229,13 @@ static bool common_download_file_single(const std::string & url, const std::stri curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); // Check if hf-token or bearer-token was specified if (!bearer_token.empty()) { std::string auth_header = "Authorization: Bearer " + bearer_token; http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); } + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); #if defined(_WIN32) // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of @@ -545,7 +546,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); curl_slist_ptr http_headers; std::string res_str; - std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag; + + std::string model_endpoint = get_model_endpoint(); + + std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag; curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); @@ -660,13 +664,8 @@ static void common_params_handle_model( } } - std::string hf_endpoint = "https://huggingface.co/"; - const char * hf_endpoint_env = getenv("HF_ENDPOINT"); - if (hf_endpoint_env) { - hf_endpoint = hf_endpoint_env; - if (hf_endpoint.back() != '/') hf_endpoint += '/'; - } - model.url = hf_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file; + std::string model_endpoint = get_model_endpoint(); + model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file; // make sure model path is present (for caching purposes) if (model.path.empty()) { // this is to avoid different repo having same file name, or same file name in different subdirs diff --git a/common/chat.cpp b/common/chat.cpp index 62ca26ad7..bbc5f087c 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1622,7 +1622,7 @@ static common_chat_params common_chat_templates_apply_jinja( } // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) - if (src.find("") != std::string::npos && params.json_schema.is_null()) { + if (src.find("") != std::string::npos && params.json_schema.is_null() && params.tools.is_array() && params.json_schema.is_null()) { return common_chat_params_init_hermes_2_pro(tmpl, params); } diff --git a/common/common.cpp b/common/common.cpp index a09153a37..e7b7c7ad5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -837,7 +837,7 @@ std::string fs_get_cache_directory() { if (getenv("LLAMA_CACHE")) { cache_directory = std::getenv("LLAMA_CACHE"); } else { -#ifdef __linux__ +#if defined(__linux__) || defined(__FreeBSD__) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); } else { @@ -847,7 +847,9 @@ std::string fs_get_cache_directory() { cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); #elif defined(_WIN32) cache_directory = std::getenv("LOCALAPPDATA"); -#endif // __linux__ +#else +# error Unknown architecture +#endif cache_directory = ensure_trailing_slash(cache_directory); cache_directory += "llama.cpp"; } @@ -1034,6 +1036,19 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } +std::string get_model_endpoint() { + const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); + // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. + const char * hf_endpoint_env = getenv("HF_ENDPOINT"); + const char * endpoint_env = model_endpoint_env ? model_endpoint_env : hf_endpoint_env; + std::string model_endpoint = "https://huggingface.co/"; + if (endpoint_env) { + model_endpoint = endpoint_env; + if (model_endpoint.back() != '/') model_endpoint += '/'; + } + return model_endpoint; +} + void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { llama_clear_adapter_lora(ctx); for (auto & la : lora) { diff --git a/common/common.h b/common/common.h index e49b81108..4736c0595 100644 --- a/common/common.h +++ b/common/common.h @@ -539,6 +539,8 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p // clear LoRA adapters from context, then apply new list of adapters void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); +std::string get_model_endpoint(); + // // Batch utils // diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c9ac2957f..2bf97475f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -735,6 +735,9 @@ class Model: if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406": # ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct res = "llama4" + if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": + # ref: https://huggingface.co/THUDM/glm-4-9b-hf + res = "glm4" if res is None: logger.warning("\n") @@ -1750,7 +1753,7 @@ class LlamaModel(Model): low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor - assert low_freq_wavelen != high_freq_wavelen + # assert low_freq_wavelen != high_freq_wavelen # Errors for Llama4 rope_factors = [] for freq in freqs: @@ -1806,10 +1809,6 @@ class Llama4Model(LlamaModel): self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): - name = name.replace("language_model.", "") - name = name.replace("feed_forward.", "mlp.") # a bit hacky for now - name = name.replace(".router.weight", ".gate.weight") # a bit hacky for now - # split the gate_up into gate and up if "gate_up_proj" in name: name_up = name.replace("gate_up_proj", "up_proj.weight") @@ -4901,6 +4900,22 @@ class JaisModel(Model): self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias) +@Model.register("Glm4ForCausalLM") +class Glm4Model(Model): + model_arch = gguf.MODEL_ARCH.GLM4 + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "yarn": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) + + @Model.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(Model): model_arch = gguf.MODEL_ARCH.CHATGLM @@ -5592,7 +5607,6 @@ def main() -> None: with torch.inference_mode(): output_type = ftype_map[args.outtype] model_architecture = hparams["architectures"][0] - try: model_class = Model.from_model_architecture(model_architecture) except NotImplementedError: diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index ce6104da4..160c9fe0e 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -114,6 +114,7 @@ models = [ {"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", }, {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", }, {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", }, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", }, ] diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h index c93d7877c..55b18dc89 100644 --- a/examples/llava/clip-impl.h +++ b/examples/llava/clip-impl.h @@ -1,5 +1,8 @@ #include "ggml.h" #include "gguf.h" +#include "clip.h" + +#include "clip.h" #include #include @@ -7,6 +10,7 @@ #include #include #include +#include // Internal header for clip.cpp @@ -125,6 +129,23 @@ static projector_type clip_projector_type_from_string(const std::string & str) { return PROJECTOR_TYPE_UNKNOWN; } +// RGB uint8 image +struct clip_image_u8 { + int nx; + int ny; + + std::vector buf; +}; + +// RGB float32 image (NHWC) +// Memory layout: RGBRGBRGB... +struct clip_image_f32 { + int nx; + int ny; + + std::vector buf; +}; + // // logging // @@ -183,6 +204,36 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, .. #define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) #define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, __VA_ARGS__) +// +// cpp wrappers +// + +// wrapper for clip_image_size +struct clip_image_size_deleter { + void operator()(clip_image_size * val) { clip_image_size_free(val); } +}; +typedef std::unique_ptr clip_image_size_ptr; + +// wrapper for clip_image_u8 +struct clip_image_u8_deleter { + void operator()(clip_image_u8 * val) { clip_image_u8_free(val); } +}; +typedef std::unique_ptr clip_image_u8_ptr; + +// wrapper for clip_image_f32 +struct clip_image_f32_deleter { + void operator()(clip_image_f32 * val) { clip_image_f32_free(val); } +}; +typedef std::unique_ptr clip_image_f32_ptr; + +struct clip_image_u8_batch { + std::vector entries; +}; + +struct clip_image_f32_batch { + std::vector entries; +}; + // // common utils // @@ -219,6 +270,20 @@ static void string_replace_all(std::string & s, const std::string & search, cons s = std::move(builder); } +// split string by a `std::string delim` instead of `char delim` +static std::vector string_split_str(std::string s, const std::string & delimiter) { + std::vector tokens; + size_t pos = 0; + std::string token; + while ((pos = s.find(delimiter)) != std::string::npos) { + token = s.substr(0, pos); + tokens.push_back(token); + s.erase(0, pos + delimiter.length()); + } + tokens.push_back(s); + return tokens; +} + // // gguf utils // @@ -276,3 +341,9 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0); } } + +// +// API used internally with mtmd +// + +projector_type clip_get_projector_type(const struct clip_ctx * ctx); diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index c73e3aaf8..5bd8159ce 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -46,23 +46,6 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac //#define CLIP_DEBUG_FUNCTIONS -// RGB uint8 image -struct clip_image_u8 { - int nx; - int ny; - - std::vector buf; -}; - -// RGB float32 image (NHWC) -// Memory layout: RGBRGBRGB... -struct clip_image_f32 { - int nx; - int ny; - - std::vector buf; -}; - #ifdef CLIP_DEBUG_FUNCTIONS static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) { std::ofstream file(filename, std::ios::binary); @@ -360,20 +343,20 @@ struct clip_ctx { bool use_rms_norm = false; int32_t ftype = 1; - struct gguf_context * ctx_gguf = nullptr; - struct ggml_context * ctx_data = nullptr; + gguf_context_ptr ctx_gguf; + ggml_context_ptr ctx_data; std::vector buf_compute_meta; std::vector backend_ptrs; std::vector backend_buft; - ggml_backend_t backend = nullptr; - ggml_backend_buffer_t buf = nullptr; + ggml_backend_ptr backend; + ggml_backend_buffer_ptr buf; ggml_backend_sched_ptr sched; - struct clip_image_size * load_image_size = nullptr; + clip_image_size load_image_size; clip_ctx(clip_context_params & ctx_params) { @@ -405,17 +388,9 @@ struct clip_ctx { ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false) ); } - - ~clip_ctx() { - ggml_free(ctx_data); - gguf_free(ctx_gguf); - ggml_backend_buffer_free(buf); - ggml_backend_free(backend); - clip_image_size_free(load_image_size); - } }; -static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) { +static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) { const auto & model = ctx->vision_model; const auto & hparams = model.hparams; @@ -431,7 +406,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im const int n_layer = hparams.n_layer; const float eps = hparams.eps; - GGML_ASSERT(imgs->size == 1); // batch_size == 1 + GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1 struct ggml_init_params params = { /*.mem_size =*/ ctx->buf_compute_meta.size(), @@ -439,7 +414,9 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im /*.no_alloc =*/ true, }; - struct ggml_context * ctx0 = ggml_init(params); + ggml_context_ptr ctx0_ptr(ggml_init(params)); + auto ctx0 = ctx0_ptr.get(); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); // input raw @@ -561,12 +538,10 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im // build the graph ggml_build_forward_expand(gf, embeddings); - ggml_free(ctx0); - return gf; } -static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) { +static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { if (!ctx->has_vision_encoder) { LOG_ERR("This gguf file seems to have no vision encoder\n"); return nullptr; @@ -579,23 +554,20 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im int image_size_width = image_size; int image_size_height = image_size; if (ctx->has_minicpmv_projector) { - if (load_image_size == nullptr) { - load_image_size = clip_image_size_init(); - } - LOG_DBG("%s: %d %d\n", __func__, load_image_size->width, load_image_size->height); - image_size_width = load_image_size->width; - image_size_height = load_image_size->height; + LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height); + image_size_width = load_image_size.width; + image_size_height = load_image_size.height; if (is_inf) { - image_size_width = imgs->data->nx; - image_size_height = imgs->data->ny; + image_size_width = imgs.entries[0]->nx; + image_size_height = imgs.entries[0]->ny; } } else if (ctx->has_qwen2vl_merger) { // use the image's native resolution when image is avaible if (is_inf) { // if (imgs->data->nx && imgs->data->ny) { - image_size_width = imgs->data->nx; - image_size_height = imgs->data->ny; + image_size_width = imgs.entries[0]->nx; + image_size_height = imgs.entries[0]->ny; } } const int patch_size = hparams.patch_size; @@ -611,7 +583,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im const bool use_window_attn = hparams.full_attn_layers.size() > 0; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; - const int batch_size = imgs->size; + const int batch_size = imgs.entries.size(); if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) { GGML_ASSERT(batch_size == 1); @@ -623,7 +595,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im /*.no_alloc =*/ true, }; - struct ggml_context * ctx0 = ggml_init(params); + ggml_context_ptr ctx0_ptr(ggml_init(params)); + auto ctx0 = ctx0_ptr.get(); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); @@ -1181,7 +1155,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); } } else { - GGML_ABORT("fatel error"); + GGML_ABORT("fatal error"); } } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { @@ -1213,12 +1187,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im // build the graph ggml_build_forward_expand(gf, embeddings); - ggml_free(ctx0); - return gf; } -static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) { +static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { return clip_image_build_graph_siglip(ctx, imgs); } else { @@ -1396,7 +1368,7 @@ struct clip_model_loader { /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - ctx_clip.ctx_data = ggml_init(params); + ctx_clip.ctx_data.reset(ggml_init(params)); if (!ctx_clip.ctx_data) { throw std::runtime_error(string_format("%s: failed to init ggml context\n", __func__)); } @@ -1410,7 +1382,7 @@ struct clip_model_loader { if (cur) { tensors_to_load.push_back(cur); // add tensors to context - struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data, cur); + struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur); ggml_set_name(data_tensor, cur->name); cur = data_tensor; } @@ -1583,11 +1555,11 @@ struct clip_model_loader { } // alloc memory and offload data - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend); - ctx_clip.buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data, buft); - ggml_backend_buffer_set_usage(ctx_clip.buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend.get()); + ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft)); + ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); for (auto & t : tensors_to_load) { - struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data, t->name); + struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); const size_t offset = tensor_offset[t->name]; fin.seekg(offset, std::ios::beg); if (!fin) { @@ -1612,10 +1584,20 @@ struct clip_model_loader { void alloc_compute_meta() { ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead()); + + // create a fake batch clip_image_f32_batch batch; - batch.size = 1; - batch.data = nullptr; - ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, &batch, nullptr, false); + clip_image_f32_ptr img(clip_image_f32_init()); + clip_image_size image_size; + image_size.width = clip_get_image_size(&ctx_clip); + image_size.height = clip_get_image_size(&ctx_clip); + int n_patches = clip_get_image_size(&ctx_clip) / image_size.width; + img->nx = n_patches; + img->ny = n_patches; + img->buf.resize(n_patches * image_size.width * image_size.height * 3); + batch.entries.push_back(std::move(img)); + + ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false); ggml_backend_sched_reserve(ctx_clip.sched.get(), gf); for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) { ggml_backend_t backend = ctx_clip.backend_ptrs[i]; @@ -1716,11 +1698,11 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p } void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) { - ctx_clip->load_image_size = load_image_size; + ctx_clip->load_image_size = *load_image_size; // copy } struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) { - return ctx_clip->load_image_size; + return &ctx_clip->load_image_size; } struct clip_image_size * clip_image_size_init() { @@ -1738,25 +1720,53 @@ struct clip_image_f32 * clip_image_f32_init() { return new clip_image_f32(); } +struct clip_image_f32_batch * clip_image_f32_batch_init() { + return new clip_image_f32_batch(); +} + +unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) { + if (nx) *nx = img->nx; + if (ny) *ny = img->ny; + return img->buf.data(); +} + void clip_image_size_free(struct clip_image_size * load_image_size) { if (load_image_size == nullptr) { return; } delete load_image_size; } -void clip_image_u8_free(struct clip_image_u8 * img) { delete img; } -void clip_image_f32_free(struct clip_image_f32 * img) { delete img; } -void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { - if (batch->size > 0) { - delete[] batch->data; - batch->size = 0; - } +void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; } +void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; } +void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; } +void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; } + +size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) { + return batch->entries.size(); } -void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { - if (batch->size > 0) { - delete[] batch->data; - batch->size = 0; + +size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx) { + if (idx < 0 || idx >= (int)batch->entries.size()) { + LOG_ERR("%s: invalid index %d\n", __func__, idx); + return 0; } + return batch->entries[idx]->nx; +} + +size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) { + if (idx < 0 || idx >= (int)batch->entries.size()) { + LOG_ERR("%s: invalid index %d\n", __func__, idx); + return 0; + } + return batch->entries[idx]->ny; +} + +clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) { + if (idx < 0 || idx >= (int)batch->entries.size()) { + LOG_ERR("%s: invalid index %d\n", __func__, idx); + return nullptr; + } + return batch->entries[idx].get(); } void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) { @@ -1917,14 +1927,15 @@ static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int ta } // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not -static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) { - dst->nx = src->nx; - dst->ny = src->ny; - dst->buf.resize(src->buf.size()); +static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) { + dst.nx = src.nx; + dst.ny = src.ny; + dst.buf.resize(src.buf.size()); - for (size_t i = 0; i < src->buf.size(); ++i) { + // TODO @ngxson : seems like this could be done more efficiently on cgraph + for (size_t i = 0; i < src.buf.size(); ++i) { int c = i % 3; // rgb - dst->buf[i] = (static_cast(src->buf[i]) / 255.0f - mean[c]) / std[c]; + dst.buf[i] = (static_cast(src.buf[i]) / 255.0f - mean[c]) / std[c]; } } @@ -1932,7 +1943,7 @@ inline int clip(int x, int lower, int upper) { return std::max(lower, std::min(x, upper)); } -static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) { +static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) { const int nx = img.nx; const int ny = img.ny; @@ -2070,13 +2081,13 @@ static std::pair select_best_resolution(const std::pair & or return best_fit; } -static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { - std::vector patches; +static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { + std::vector patches; int width = image.nx; int height = image.ny; for (int i = 0; i < height; i += patch_size) { for (int j = 0; j < width; j += patch_size) { - clip_image_u8 *patch = clip_image_u8_init(); + clip_image_u8_ptr patch(clip_image_u8_init()); patch->nx = std::min(patch_size, width - j); patch->ny = std::min(patch_size, height - i); patch->buf.resize(3 * patch->nx * patch->ny); @@ -2087,7 +2098,7 @@ static std::vector divide_to_patches_u8(const clip_image_u8 & im } } } - patches.push_back(patch); + patches.push_back(std::move(patch)); } } return patches; @@ -2168,7 +2179,7 @@ static std::pair uhd_best_grid(const int max_slice_nums, const int mul // -> https://arxiv.org/pdf/2403.11703 // -> https://github.com/thunlp/LLaVA-UHD // -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118 -static std::vector> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) { +static std::vector> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) { const std::pair original_size={img->nx,img->ny}; const int original_width = img->nx; const int original_height = img->ny; @@ -2176,30 +2187,30 @@ static std::vector> uhd_slice_image(const clip_imag const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); const int multiple = fmin(ceil(ratio), max_slice_nums); - std::vector> images; + std::vector> images; LOG_DBG("%s: multiple %d\n", __func__, multiple); - images.push_back(std::vector()); + images.push_back(std::vector()); if (multiple <= 1) { auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true); - clip_image_u8 * source_image = clip_image_u8_init(); + clip_image_u8_ptr source_image(clip_image_u8_init()); bicubic_resize(*img, *source_image, best_size.first, best_size.second); // source_image = image.resize(best_size, Image.Resampling.BICUBIC) - images[images.size()-1].push_back(source_image); + images.back().push_back(std::move(source_image)); } else if (multiple > 1) { auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size); - clip_image_u8 * source_image = clip_image_u8_init(); + clip_image_u8_ptr source_image(clip_image_u8_init()); bicubic_resize(*img, *source_image, best_size.first, best_size.second); // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) LOG_DBG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second); - images[images.size()-1].push_back(source_image); + images.back().push_back(std::move(source_image)); std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); LOG_DBG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); - clip_image_u8 * refine_image = clip_image_u8_init(); + clip_image_u8_ptr refine_image(clip_image_u8_init()); bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second); LOG_DBG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); @@ -2210,9 +2221,9 @@ static std::vector> uhd_slice_image(const clip_imag int grid_x = int(width / best_grid.first); int grid_y = int(height / best_grid.second); for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){ - images.push_back(std::vector()); + images.push_back(std::vector()); for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){ - clip_image_u8 * patch = clip_image_u8_init(); + clip_image_u8_ptr patch(clip_image_u8_init()); patch->nx = grid_x; patch->ny = grid_y; patch->buf.resize(3 * patch->nx * patch->ny); @@ -2225,10 +2236,9 @@ static std::vector> uhd_slice_image(const clip_imag patch->buf[j+2] = refine_image->buf[i+2]; } } - images[images.size()-1].push_back(patch); + images.back().push_back(std::move(patch)); } } - clip_image_u8_free(refine_image); } return images; } @@ -2236,8 +2246,8 @@ static std::vector> uhd_slice_image(const clip_imag int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { const int max_slice_nums=9; const int scale_resolution=448; - const int original_width = ctx_clip->load_image_size->width; - const int original_height = ctx_clip->load_image_size->height; + const int original_width = ctx_clip->load_image_size.width; + const int original_height = ctx_clip->load_image_size.height; const float log_ratio = log(1.0*original_width/original_height); const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); const int multiple = fmin(ceil(ratio), max_slice_nums); @@ -2247,64 +2257,44 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found -bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) { +bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { - if(clip_is_minicpmv(ctx)){ + if (clip_is_minicpmv(ctx)) { int max_slice_nums = 9; - std::vector> imgs = uhd_slice_image(img, max_slice_nums); - res_imgs->size = 0; - for (size_t i = 0; i < imgs.size(); ++i){ - res_imgs->size += imgs[i].size(); - } - res_imgs->data = new clip_image_f32[res_imgs->size]; - int idx = 0; + std::vector> imgs = uhd_slice_image(img, max_slice_nums); for (size_t i = 0; i < imgs.size(); ++i) { for (size_t j = 0; j < imgs[i].size(); ++j) { LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny); - clip_image_f32 * res = clip_image_f32_init(); - normalize_image_u8_to_f32(imgs[i][j], res, ctx->image_mean, ctx->image_std); - res_imgs->data[idx++] = *res; - clip_image_f32_free(res); - } - } - for (size_t i = 0; i < imgs.size(); ++i) { - for (size_t j = 0; j < imgs[i].size(); ++j) { - if (imgs[i][j] != nullptr) { - clip_image_u8_free(imgs[i][j]); - } + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i][j], *res, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(res)); } } return true; } else if (ctx->has_qwen2vl_merger) { - clip_image_u8 * resized = clip_image_u8_init(); - auto patch_size = clip_patch_size(ctx) * 2; + clip_image_u8 resized; + auto patch_size = clip_get_patch_size(ctx) * 2; int nx = ceil((float)img->nx / patch_size) * patch_size; int ny = ceil((float)img->ny / patch_size) * patch_size; - bicubic_resize(*img, *resized, nx, ny); + bicubic_resize(*img, resized, nx, ny); - res_imgs->data = new clip_image_f32[1]; - // clip_image_f32 * res = clip_image_f32_init(); - normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std); + clip_image_f32_ptr img_f32(clip_image_f32_init()); + // clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std); // res_imgs->data[0] = *res; - res_imgs->size = 1; - - // clip_image_f32_free(res); - clip_image_u8_free(resized); + res_imgs->entries.push_back(std::move(img_f32)); return true; } if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { - res_imgs->size = 1; - res_imgs->data = new clip_image_f32[res_imgs->size]; clip_image_u8 resized_image; int32_t sz=ctx->vision_model.hparams.image_size; bicubic_resize(*img, resized_image,sz,sz); - clip_image_f32 * res = clip_image_f32_init(); + clip_image_f32_ptr img_f32(clip_image_f32_init()); //clip_image_save_to_bmp(resized_image, "resized.bmp"); - normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std); - res_imgs->data[0] = *res; - clip_image_f32_free(res); + normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(img_f32)); return true; } @@ -2319,16 +2309,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli pad_to_square = false; } // free the previous res_imgs if any set - if (res_imgs->size > 0) { - clip_image_f32_batch_free(res_imgs); - } - res_imgs->data = nullptr; - res_imgs->size = 0; + res_imgs->entries.clear(); // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 - clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily + clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily if (pad_to_square && img->nx != img->ny) { int longer_side = std::max(img->nx, img->ny); temp->nx = longer_side; @@ -2371,28 +2357,18 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli // clip_image_u8_free(temp2); // } - std::vector patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) + std::vector patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) - clip_image_u8 *image_original_resize = clip_image_u8_init(); + clip_image_u8_ptr image_original_resize(clip_image_u8_init()); // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square - patches.insert(patches.begin(), image_original_resize); - // clip_image_f32_batch_init(patches.size()); - res_imgs->size = patches.size(); - res_imgs->data = new clip_image_f32[res_imgs->size]; - int num=0; - for (auto& patch : patches) { - normalize_image_u8_to_f32(patch, &res_imgs->data[num], ctx->image_mean, ctx->image_std); - num++; + patches.insert(patches.begin(), std::move(image_original_resize)); + for (auto & patch : patches) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*patch, *res, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(res)); } - for (size_t i = 0; i < patches.size(); i++) { - // LOG_DBG("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny); - clip_image_u8_free(patches[i]); - } - - clip_image_u8_free(temp); - return true; } else { temp->nx = img->nx; @@ -2408,7 +2384,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli const int nx2 = ctx->vision_model.hparams.image_size; const int ny2 = ctx->vision_model.hparams.image_size; - clip_image_f32 * res = clip_image_f32_init(); + clip_image_f32_ptr res(clip_image_f32_init()); res->nx = nx2; res->ny = ny2; res->buf.resize(3 * nx2 * ny2); @@ -2460,7 +2436,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli } } } - clip_image_u8_free(temp); // { // clip_image_u8 * temp2 = clip_image_u8_init(); @@ -2470,10 +2445,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli // } // res_imgs.push_back(res); - res_imgs->size = 1; - res_imgs->data = new clip_image_f32[res_imgs->size]; - res_imgs->data[0] = *res; - clip_image_f32_free(res); + res_imgs->entries.push_back(std::move(res)); return true; } @@ -2501,15 +2473,15 @@ size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float); } -int32_t clip_image_size(const struct clip_ctx * ctx) { +int32_t clip_get_image_size(const struct clip_ctx * ctx) { return ctx->vision_model.hparams.image_size; } -int32_t clip_patch_size(const struct clip_ctx * ctx) { +int32_t clip_get_patch_size(const struct clip_ctx * ctx) { return ctx->vision_model.hparams.patch_size; } -int32_t clip_hidden_size(const struct clip_ctx * ctx) { +int32_t clip_get_hidden_size(const struct clip_ctx * ctx) { return ctx->vision_model.hparams.hidden_size; } @@ -2561,6 +2533,8 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); n_patches = x_patch * y_patch; + } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { + n_patches = 256; } return n_patches; @@ -2658,19 +2632,23 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3 return false; } - clip_image_f32_batch imgs{}; - imgs.size = 1; - imgs.data = img; + clip_image_f32_batch imgs; + clip_image_f32_ptr img_copy(clip_image_f32_init()); + *img_copy = *img; + imgs.entries.push_back(std::move(img_copy)); + return clip_image_batch_encode(ctx, n_threads, &imgs, vec); } -bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) { +bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) { + const clip_image_f32_batch & imgs = *imgs_c_ptr; + if (!ctx->has_vision_encoder) { LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); return false; } - int batch_size = imgs->size; + int batch_size = imgs.entries.size(); if (ctx->has_llava_projector) { GGML_ASSERT(batch_size == 1); // TODO: support multiple images } @@ -2697,25 +2675,22 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima int image_size_width = image_size; int image_size_height = image_size; if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) { - image_size_width = imgs->data[0].nx; - image_size_height = imgs->data[0].ny; + image_size_width = imgs.entries[0]->nx; + image_size_height = imgs.entries[0]->ny; } const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int num_positions = num_patches + (model.class_embedding ? 1 : 0); - if(ctx->load_image_size==nullptr){ - ctx->load_image_size= clip_image_size_init(); - } - const int pos_w = ctx->load_image_size->width/patch_size; - const int pos_h = ctx->load_image_size->height/patch_size; + const int pos_w = ctx->load_image_size.width / patch_size; + const int pos_h = ctx->load_image_size.height / patch_size; { struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); float * data = (float *)malloc(ggml_nbytes(inp_raw)); - for (size_t i = 0; i < imgs->size; i++) { - const int nx = imgs->data[i].nx; - const int ny = imgs->data[i].ny; + for (size_t i = 0; i < imgs.entries.size(); i++) { + const int nx = imgs.entries[i]->nx; + const int ny = imgs.entries[i]->ny; if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) { GGML_ASSERT(nx == image_size && ny == image_size); } @@ -2726,7 +2701,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima for (int k = 0; k < 3; k++) { for (int y = 0; y < ny; y++) { for (int x = 0; x < nx; x++) { - data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k]; + data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k]; } } } @@ -3011,8 +2986,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i /* verbosity */ GGML_LOG_LEVEL_ERROR, }); - const auto & ctx_src = ctx_clip->ctx_gguf; - const auto & ctx_data = ctx_clip->ctx_data; + const auto & ctx_src = ctx_clip->ctx_gguf.get(); + const auto & ctx_data = ctx_clip->ctx_data.get(); auto * ctx_out = gguf_init_empty(); gguf_set_kv(ctx_out, ctx_src); @@ -3247,3 +3222,11 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, clip_image_encode(ctx, n_threads, &clip_img, vec); return true; } + +// +// API used internally with mtmd +// + +projector_type clip_get_projector_type(const struct clip_ctx * ctx) { + return ctx->proj_type; +} diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 210c384e8..fdc0fa578 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -30,15 +30,8 @@ struct clip_image_size { int height; }; -struct clip_image_u8_batch { - struct clip_image_u8 * data; - size_t size; -}; - -struct clip_image_f32_batch { - struct clip_image_f32 * data; - size_t size; -}; +struct clip_image_u8_batch; +struct clip_image_f32_batch; struct clip_context_params { bool use_gpu; @@ -55,9 +48,9 @@ CLIP_API void clip_free(struct clip_ctx * ctx); CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx); CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w); -CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx); -CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx); -CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx); +CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx); +CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx); +CLIP_API int32_t clip_get_hidden_size(const struct clip_ctx * ctx); // TODO: should be enum, not string CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); @@ -73,9 +66,13 @@ CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip); -CLIP_API struct clip_image_size * clip_image_size_init(); -CLIP_API struct clip_image_u8 * clip_image_u8_init (); -CLIP_API struct clip_image_f32 * clip_image_f32_init(); +CLIP_API struct clip_image_size * clip_image_size_init(); +CLIP_API struct clip_image_u8 * clip_image_u8_init (); +CLIP_API struct clip_image_f32 * clip_image_f32_init(); +CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava + +// nx, ny are the output image dimensions +CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny); CLIP_API void clip_image_size_free (struct clip_image_size * img_size); CLIP_API void clip_image_u8_free (struct clip_image_u8 * img); @@ -83,6 +80,12 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); +// use for accessing underlay data of clip_image_f32_batch +CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size() +CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx +CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny +CLIP_API clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data + /** * Build image from pixels decoded by other libraries instead of stb_image.h for better performance. * The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index 4f89c0e15..91a07e2a8 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -2,11 +2,11 @@ #include "log.h" #include "common.h" #include "sampling.h" -#include "clip.h" -#include "stb_image.h" #include "llama.h" #include "ggml.h" #include "console.h" +#include "chat.h" +#include "mtmd.h" #include #include @@ -57,13 +57,18 @@ static void sigint_handler(int signo) { #endif struct gemma3_context { - struct clip_ctx * ctx_clip = NULL; - common_init_result llama_init; + mtmd_context_ptr ctx_vision; + common_init_result llama_init; llama_model * model; llama_context * lctx; const llama_vocab * vocab; llama_batch batch; + int n_batch; + + // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another + // so here we don't need to keep track of chat history + common_chat_templates_ptr tmpls; int n_threads = 1; llama_pos n_past = 0; @@ -74,21 +79,24 @@ struct gemma3_context { vocab = llama_model_get_vocab(model); n_threads = params.cpuparams.n_threads; batch = llama_batch_init(params.n_batch, 0, 1); - init_clip_model(params); + n_batch = params.n_batch; + tmpls = common_chat_templates_init(model, params.chat_template); + init_vision_context(params); } - void init_clip_model(common_params & params) { + void init_vision_context(common_params & params) { const char * clip_path = params.mmproj.path.c_str(); - ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO); - if (!ctx_clip) { - LOG_ERR("Failed to load CLIP model from %s\n", clip_path); + ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{ + /* use_gpu */ true, + /* timings */ true, + /* n_threads */ params.cpuparams.n_threads, + /* verbosity */ GGML_LOG_LEVEL_INFO, + })); + if (!ctx_vision.get()) { + LOG_ERR("Failed to load vision model from %s\n", clip_path); exit(1); } } - - ~gemma3_context() { - clip_free(ctx_clip); - } }; struct decode_embd_batch { @@ -124,77 +132,6 @@ struct decode_embd_batch { } }; -static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) { - llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true); - common_batch_clear(ctx.batch); - for (llama_token & t : tokens) { - common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false); - } - if (logits_last) { - ctx.batch.logits[ctx.batch.n_tokens - 1] = true; - } - // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str()); - if (llama_decode(ctx.lctx, ctx.batch)) { - LOG_ERR("Failed to decode text\n"); - return 1; - } - return 0; -} - -static int eval_image(gemma3_context & ctx, std::string & fname) { - std::vector image_embd_v; - int n_embd = llama_model_n_embd(ctx.model); - int n_tokens = 256; - image_embd_v.resize(n_tokens * n_embd); - - bool ok; - struct clip_image_u8 * img_u8 = clip_image_u8_init(); - ok = clip_image_load_from_file(fname.c_str(), img_u8); - if (!ok) { - LOG_ERR("Unable to load image %s\n", fname.c_str()); - clip_image_u8_free(img_u8); - return 2; // non-fatal error - } - - clip_image_f32_batch batch_f32; - ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32); - if (!ok) { - LOG_ERR("Unable to preprocess image\n"); - clip_image_f32_batch_free(&batch_f32); - clip_image_u8_free(img_u8); - return 1; - } - - int64_t t0 = ggml_time_ms(); - LOG("Encoding image %s\n", fname.c_str()); - ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data()); - if (!ok) { - LOG_ERR("Unable to encode image\n"); - clip_image_f32_batch_free(&batch_f32); - clip_image_u8_free(img_u8); - return 1; - } - LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); - - clip_image_f32_batch_free(&batch_f32); - clip_image_u8_free(img_u8); - - // decode image embeddings - int64_t t1 = ggml_time_ms(); - eval_text(ctx, ""); - llama_set_causal_attn(ctx.lctx, false); - decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0); - if (llama_decode(ctx.lctx, batch_img.batch)) { - LOG_ERR("failed to decode image\n"); - return 1; - } - ctx.n_past += n_tokens; - llama_set_causal_attn(ctx.lctx, true); - eval_text(ctx, ""); - LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1); - return 0; -} - static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) { for (int i = 0; i < n_predict; i++) { if (i > n_predict || !g_is_generating) { @@ -224,6 +161,45 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_ return 0; } +static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector & images_fname, bool add_bos = false) { + std::vector bitmaps; + + common_chat_templates_inputs tmpl_inputs; + tmpl_inputs.messages = {msg}; + tmpl_inputs.add_generation_prompt = true; + tmpl_inputs.use_jinja = false; // jinja is buggy here + auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs); + LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str()); + + for (auto & fname : images_fname) { + mtmd_bitmap bitmap; + if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) { + LOG_ERR("Unable to load image %s\n", fname.c_str()); + return 2; // image not found + } + bitmaps.push_back(std::move(bitmap)); + } + + mtmd_input_text text; + text.text = formatted_chat.prompt; + text.add_special = add_bos; + text.parse_special = true; + mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps)); + if (chunks == nullptr) { + LOG_ERR("Unable to tokenize prompt\n"); + return 1; + } + + if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) { + LOG_ERR("Unable to eval prompt\n"); + return 1; + } + + ctx.n_past += mtmd_helper_get_n_tokens(chunks.get()); + + return 0; +} + int main(int argc, char ** argv) { ggml_time_init(); @@ -265,21 +241,15 @@ int main(int argc, char ** argv) { #endif } - if (eval_text(ctx, "")) { - return 1; - } - if (is_single_turn) { g_is_generating = true; - if (eval_text(ctx, "user\n")) { - return 1; + if (params.prompt.find("<__image__>") == std::string::npos) { + params.prompt += " <__image__>"; } - for (auto & fname : params.image) { - if (eval_image(ctx, fname)) { - return 1; - } - } - if (eval_text(ctx, params.prompt + "model\n", true)) { + common_chat_msg msg; + msg.role = "user"; + msg.content = params.prompt; + if (eval_message(ctx, msg, params.image, true)) { return 1; } if (generate_response(ctx, smpl, n_predict)) { @@ -293,9 +263,9 @@ int main(int argc, char ** argv) { LOG("\n /quit or /exit exit the program"); LOG("\n"); - if (eval_text(ctx, "user\n")) { - return 1; - } + bool is_first_msg = true; + std::vector images_fname; + std::string content; while (true) { g_is_generating = false; @@ -320,24 +290,31 @@ int main(int argc, char ** argv) { g_is_generating = true; if (line.find("/image") == 0) { std::string image = line.substr(7); - int res = eval_image(ctx, image); - if (res == 2) { - continue; // image not found - } - if (res) { - return 1; - } + images_fname.push_back(string_strip(image)); + content += "<__image__>"; + continue; + } else { + content += line; + } + common_chat_msg msg; + msg.role = "user"; + msg.content = content; + int ret = eval_message(ctx, msg, images_fname, is_first_msg); + if (ret == 2) { + // non-fatal error + images_fname.clear(); + content.clear(); continue; } - if (eval_text(ctx, line + "model\n", true)) { + if (ret) { return 1; } if (generate_response(ctx, smpl, n_predict)) { return 1; } - if (eval_text(ctx, "user\n")) { - return 1; - } + images_fname.clear(); + content.clear(); + is_first_msg = false; } } diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 887698c4a..fcb7749ad 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #if defined(LLAVA_LOG_OFF) # define LOG_INF(...) @@ -45,6 +46,17 @@ struct clip_image_grid_shape { int second; }; +// convenience cpp wrapper +struct clip_image_f32_batch_deleter { + void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } +}; +typedef std::unique_ptr clip_image_f32_batch_ptr; + +struct clip_image_size_deleter { + void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } +}; +typedef std::unique_ptr clip_image_size_ptr; + /** * Selects the best resolution from a list of possible resolutions based on the original size. * @@ -105,8 +117,8 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector struct ggml_context * ctx; } model; - const int32_t image_size = clip_image_size(ctx_clip); - const int32_t patch_size = clip_patch_size(ctx_clip); + const int32_t image_size = clip_get_image_size(ctx_clip); + const int32_t patch_size = clip_get_patch_size(ctx_clip); int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches) @@ -246,12 +258,9 @@ static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { // std::vector img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 - clip_image_f32_batch img_res_v; - img_res_v.size = 0; - img_res_v.data = nullptr; - if (!clip_image_preprocess(ctx_clip, img, &img_res_v)) { + clip_image_f32_batch_ptr img_res_v(clip_image_f32_batch_init()); + if (!clip_image_preprocess(ctx_clip, img, img_res_v.get())) { LOG_ERR("%s: unable to preprocess image\n", __func__); - delete[] img_res_v.data; return false; } @@ -259,66 +268,72 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); + const size_t n_imgs = clip_image_f32_batch_n_images(img_res_v.get()); + if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) { std::vector image_embd_v; - image_embd_v.resize(img_res_v.size); - struct clip_image_size * load_image_size = clip_image_size_init(); + image_embd_v.resize(n_imgs); + clip_image_size load_image_size; - for (size_t i = 0; i < img_res_v.size; i++) { + for (size_t i = 0; i < n_imgs; i++) { const int64_t t_img_enc_step_start_us = ggml_time_us(); - image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny)); - int patch_size=14; - load_image_size->width = img_res_v.data[i].nx; - load_image_size->height = img_res_v.data[i].ny; - clip_add_load_image_size(ctx_clip, load_image_size); + int nx = clip_image_f32_batch_nx(img_res_v.get(), i); + int ny = clip_image_f32_batch_ny(img_res_v.get(), i); + image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, nx, ny)); + int patch_size = 14; + load_image_size.width = nx; + load_image_size.height = ny; + clip_add_load_image_size(ctx_clip, &load_image_size); bool encoded = false; + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); if (clip_is_qwen2vl(ctx_clip)) { - encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); + encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); } else { - encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); + encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(img_res, patch_size), image_embd_v[i]); } if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); + LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); return false; } const int64_t t_img_enc_steop_batch_us = ggml_time_us(); - LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)img_res_v.size, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); + LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)n_imgs, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); } const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); + LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); int n_img_pos_out = 0; for (size_t i = 0; i < image_embd_v.size(); i++) { + int nx = clip_image_f32_batch_nx(img_res_v.get(), i); + int ny = clip_image_f32_batch_ny(img_res_v.get(), i); + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); std::memcpy( image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), image_embd_v[i], - clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny)); - n_img_pos_out += clip_n_patches_by_img(ctx_clip, &img_res_v.data[i]); + clip_embd_nbytes_by_img(ctx_clip, nx, ny)); + n_img_pos_out += clip_n_patches_by_img(ctx_clip, img_res); } *n_img_pos = n_img_pos_out; for (size_t i = 0; i < image_embd_v.size(); i++) { free(image_embd_v[i]); } image_embd_v.clear(); - load_image_size->width = img->nx; - load_image_size->height = img->ny; - clip_add_load_image_size(ctx_clip, load_image_size); - LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); - delete[] img_res_v.data; - img_res_v.size = 0; - img_res_v.data = nullptr; + load_image_size.width = img->nx; + load_image_size.height = img->ny; + clip_add_load_image_size(ctx_clip, &load_image_size); + LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size.width, load_image_size.height); } else if (clip_is_glm(ctx_clip)){ struct clip_image_size * load_image_size = clip_image_size_init(); - load_image_size->width = img_res_v.data[0].nx; - load_image_size->height = img_res_v.data[0].ny; + load_image_size->width = clip_image_f32_batch_nx(img_res_v.get(), 0); + load_image_size->height = clip_image_f32_batch_ny(img_res_v.get(), 0); clip_add_load_image_size(ctx_clip, load_image_size); - bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); - int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2); + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); + bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); + int pos = int(load_image_size->width/clip_get_patch_size(ctx_clip)/2); *n_img_pos = (pos * pos + 2); if (!encoded){ LOG_ERR("Unable to encode image \n"); @@ -328,8 +343,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding *n_img_pos = clip_n_patches(ctx_clip); - bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096 - delete[] img_res_v.data; + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); + bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096 if (!encoded) { LOG_ERR("Unable to encode image\n"); @@ -340,17 +355,18 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli // spatial_unpad llava-1.6 type embedding // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working std::vector image_embd_v; - image_embd_v.resize(img_res_v.size); - for (size_t i = 0; i < img_res_v.size; i++) { + image_embd_v.resize(n_imgs); + for (size_t i = 0; i < n_imgs; i++) { + clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 - const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside + const bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); + LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); return false; } } const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); + LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); const int32_t * image_grid = clip_image_grid(ctx_clip); const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip); @@ -360,12 +376,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); } - // free all img_res_v - not needed anymore - delete[] img_res_v.data; - img_res_v.size = 0; - img_res_v.data = nullptr; - - const int32_t image_size = clip_image_size(ctx_clip); + const int32_t image_size = clip_get_image_size(ctx_clip); struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size); diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp new file mode 100644 index 000000000..114c274bc --- /dev/null +++ b/examples/llava/mtmd.cpp @@ -0,0 +1,341 @@ +#include "clip.h" +#include "clip-impl.h" +#include "mtmd.h" + +#include "llama.h" + +#include +#include +#include +#include +#include +#include +#include + +struct mtmd_context { + struct clip_ctx * ctx_clip; + const struct llama_model * text_model; + std::vector image_embd_v; // image embedding vector + bool print_timings; + int n_threads; + std::string image_marker; + + // TODO @ngxson : add timings + + mtmd_context(const char * mmproj_fname, + const llama_model * text_model, + const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) { + clip_context_params ctx_clip_params; + ctx_clip_params.use_gpu = ctx_params.use_gpu; + ctx_clip_params.verbosity = ctx_params.verbosity; + ctx_clip = clip_init(mmproj_fname, ctx_clip_params); + if (!ctx_clip) { + throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname)); + } + this->text_model = text_model; + } + + ~mtmd_context() { + clip_free(ctx_clip); + } +}; + +struct mtmd_image_tokens_data { + clip_image_f32_batch batch_f32; // preprocessed image patches +}; + +struct mtmd_image_tokens { + uint32_t nx; // number of tokens in x direction + uint32_t ny; // number of tokens in y direction + uint32_t n_tokens() const { return nx * ny; } + clip_image_f32_batch batch_f32; // preprocessed image patches +}; + +mtmd_context * mtmd_init_from_file(const char * mmproj_fname, + const struct llama_model * text_model, + const struct mtmd_context_params ctx_params) { + try { + return new mtmd_context(mmproj_fname, text_model, ctx_params); + } catch (const std::exception & e) { + LOG_ERR("%s: error: %s\n", __func__, e.what()); + return nullptr; + } +} + +void mtmd_free(mtmd_context * ctx) { + if (ctx) { + delete ctx; + } +} + +// copied from common_tokenize +static std::vector mtmd_tokenize_text_internal( + const struct llama_vocab * vocab, + const std::string & text, + bool add_special, + bool parse_special) { + // upper limit for the number of tokens + int n_tokens = text.length() + 2 * add_special; + std::vector result(n_tokens); + n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + return result; +} + +mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, + const mtmd_input_text & text, + const std::vector & bitmaps) { + mtmd_input_chunks * output = new mtmd_input_chunks; + auto vocab = llama_model_get_vocab(ctx->text_model); + + std::string prompt_modified(text.text); + std::string marker_modified(ctx->image_marker); + projector_type proj_type = clip_get_projector_type(ctx->ctx_clip); + // a bit hacky here, but works for now + // for some models, we need to add prefix and suffix to the image embeddings + if (proj_type == PROJECTOR_TYPE_GEMMA3) { + // ... (image embeddings) ... + marker_modified = "" + ctx->image_marker + ""; + string_replace_all(prompt_modified, ctx->image_marker, marker_modified); + } + + std::vector parts = string_split_str(text.text, ctx->image_marker); + output->clear(); + output->reserve(parts.size()); + + size_t i_img = 0; + + for (const auto & part : parts) { + //printf("tokenizing part: %s\n", part.c_str()); + bool add_bos = &parts.front() == ∂ + auto tokens = mtmd_tokenize_text_internal(vocab, part, text.add_special && add_bos, text.parse_special); + if (tokens.empty()) { + continue; + } + mtmd_input_chunk chunk{ + MTMD_INPUT_CHUNK_TYPE_TEXT, + std::move(tokens), + {}, + }; + output->emplace_back(std::move(chunk)); + + if (&parts.back() != &part) { + // add image token to middle of 2 parts + + if (i_img >= bitmaps.size()) { + LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size()); + return nullptr; + } + + // shim layer + clip_image_u8_ptr img_u8(clip_image_u8_init()); + img_u8->nx = bitmaps[i_img].nx; + img_u8->ny = bitmaps[i_img].ny; + img_u8->buf.resize(bitmaps[i_img].data.size()); + std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3); + + // preprocess image + clip_image_f32_batch batch_f32; + bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32); + if (!ok) { + LOG_ERR("Unable to preprocess image\n"); + return nullptr; + } + + mtmd_image_tokens * image_tokens = new mtmd_image_tokens; + image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image + image_tokens->ny = 1; // TODO + image_tokens->batch_f32 = std::move(batch_f32); + + mtmd_input_chunk chunk{ + MTMD_INPUT_CHUNK_TYPE_IMAGE, + {}, + image_tokens, + }; + output->emplace_back(std::move(chunk)); + i_img++; + } + } + + return output; +} + +void mtmd_input_chunks_free(mtmd_input_chunks * chunks) { + for (auto & chunk : *chunks) { + if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) { + delete chunk.tokens_image; + } + } + delete chunks; +} + +int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) { + int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); + ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); + bool ok = clip_image_batch_encode( + ctx->ctx_clip, + ctx->n_threads, + &image_tokens->batch_f32, + ctx->image_embd_v.data()); + return ok ? 0 : 1; +} + +float * mtmd_get_output_embd(mtmd_context * ctx) { + return ctx->image_embd_v.data(); +} + +size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) { + size_t n_tokens = 0; + for (auto & chunk : *chunks) { + if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + n_tokens += chunk.tokens_text.size(); + } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + n_tokens += chunk.tokens_image->n_tokens(); + } else { + GGML_ASSERT(false && "chunk type not supported"); + } + } + return n_tokens; +} + +// helper struct to make working with embd batch easier +// note: this will be removed after llama_batch_ext refactoring +struct decode_embd_batch { + std::vector pos; + std::vector n_seq_id; + std::vector seq_id_0; + std::vector seq_ids; + std::vector logits; + llama_batch batch; + decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { + pos .resize(n_tokens); + n_seq_id.resize(n_tokens); + seq_ids .resize(n_tokens + 1); + logits .resize(n_tokens); + seq_id_0.resize(1); + seq_id_0[0] = seq_id; + seq_ids [n_tokens] = nullptr; + batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ embd, + /*pos =*/ pos.data(), + /*n_seq_id =*/ n_seq_id.data(), + /*seq_id =*/ seq_ids.data(), + /*logits =*/ logits.data(), + }; + for (int i = 0; i < n_tokens; i++) { + batch.pos [i] = pos_0 + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i] = seq_id_0.data(); + batch.logits [i] = false; + } + } +}; + +int32_t mtmd_helper_eval(mtmd_context * ctx, + llama_context * lctx, + mtmd_input_chunks * chunks, + llama_pos pos0, + llama_seq_id seq_id, + int32_t n_batch) { + int32_t ret; + llama_pos n_past = pos0; + llama_batch text_batch = llama_batch_init(n_batch, 0, 1); + + for (auto & chunk : *chunks) { + bool is_last = &chunk == &chunks->back(); + if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + // TODO @ngxson : may need to split into smaller batches + text_batch.n_tokens = chunk.tokens_text.size(); + for (size_t i = 0; i < chunk.tokens_text.size(); i++) { + text_batch.token [i] = chunk.tokens_text[i]; + text_batch.pos [i] = n_past++; + text_batch.n_seq_id[i] = 1; + text_batch.seq_id [i][0] = seq_id; + text_batch.logits [i] = false; + } + if (is_last) { + // always get logits for last input chunk + text_batch.logits[text_batch.n_tokens - 1] = true; + } + ret = llama_decode(lctx, text_batch); + if (ret != 0) { + LOG_ERR("failed to decode text\n"); + llama_batch_free(text_batch); + return ret; + } + + } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + GGML_ASSERT(!is_last && "logits for last image chunk is not yet support"); + GGML_ASSERT(chunk.tokens_image != nullptr); + int64_t t0 = ggml_time_ms(); + if (ctx->print_timings) { + LOG_INF("encoding image...\n"); + } + ret = mtmd_encode(ctx, chunk.tokens_image); + if (ret != 0) { + LOG_ERR("failed to encode image\n"); + llama_batch_free(text_batch); + return ret; + } + if (ctx->print_timings) { + LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); + } + + int32_t n_tokens = chunk.tokens_image->n_tokens(); + float * embd = mtmd_get_output_embd(ctx); + decode_embd_batch batch_img(embd, n_tokens, n_past, 0); + int64_t t1 = ggml_time_ms(); + ret = llama_decode(lctx, batch_img.batch); + if (ret != 0) { + LOG_ERR("failed to decode image\n"); + llama_batch_free(text_batch); + return ret; + } + if (ctx->print_timings) { + LOG_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1); + } + + n_past += n_tokens; + + } else { + GGML_ASSERT(false && "chunk type not supported"); + } + } + + llama_batch_free(text_batch); + return 0; +} + +int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output) { + clip_image_u8_ptr img_u8(clip_image_u8_init()); + bool ok = clip_image_load_from_bytes(buf, len, img_u8.get()); + if (!ok) { + LOG_ERR("Unable to load image from buffer\n"); + return 1; + } + unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny); + output.data.resize(output.nx * output.ny * 3); + std::memcpy(output.data.data(), data, output.nx * output.ny * 3); + return 0; +} + +int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output) { + clip_image_u8_ptr img_u8(clip_image_u8_init()); + bool ok = clip_image_load_from_file(fname, img_u8.get()); + if (!ok) { + LOG_ERR("Unable to load image %s\n", fname); + return 1; + } + unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny); + output.data.resize(output.nx * output.ny * 3); + std::memcpy(output.data.data(), data, output.nx * output.ny * 3); + return 0; +} diff --git a/examples/llava/mtmd.h b/examples/llava/mtmd.h new file mode 100644 index 000000000..598f6947b --- /dev/null +++ b/examples/llava/mtmd.h @@ -0,0 +1,146 @@ +#ifndef MTMD_H +#define MTMD_H + +#include "ggml.h" +#include "llama.h" +#include "clip.h" + +#include +#include +#include + +#ifdef LLAMA_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef LLAMA_BUILD +# define MTMD_API __declspec(dllexport) +# else +# define MTMD_API __declspec(dllimport) +# endif +# else +# define MTMD_API __attribute__ ((visibility ("default"))) +# endif +#else +# define MTMD_API +#endif + +#ifdef __cplusplus + +enum mtmd_input_chunk_type { + MTMD_INPUT_CHUNK_TYPE_TEXT, + MTMD_INPUT_CHUNK_TYPE_IMAGE, +}; + +struct mtmd_context; +struct mtmd_image_tokens; + +// represents raw image data, layout is RGBRGBRGB... +// length of data must be nx * ny * 3 +struct mtmd_bitmap { + uint32_t nx; + uint32_t ny; + std::vector data; +}; + +struct mtmd_input_chunk { + mtmd_input_chunk_type type; + std::vector tokens_text; + mtmd_image_tokens * tokens_image = nullptr; +}; + +using mtmd_input_chunks = std::vector; + +struct mtmd_context_params { + bool use_gpu = true; + bool print_timings = true; + int n_threads = 4; + enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO; + const char * image_marker = "<__image__>"; +}; + +struct mtmd_input_text { + std::string text; + bool add_special; + bool parse_special; +}; + +// initialize the mtmd context +// return nullptr on failure +MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname, + const llama_model * text_model, + const mtmd_context_params ctx_params); + +MTMD_API void mtmd_free(mtmd_context * ctx); + +// tokenize an input text prompt and an image +// the prompt must have the input image marker (default: "<__image__>") in it +// the marker will be replaced with the image tokens +// for example: +// "here is an image: <__image__>\ndescribe it in detail." +// this will gives 3 chunks: +// 1. "here is an image: " +// 2. (image tokens) +// 3. "\ndescribe it in detail." +// number of bitmaps must be equal to the number of image markers in the prompt +// this function is thread-safe (shared ctx) +MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, + const mtmd_input_text & text, + const std::vector & bitmaps); + +// free image chunk data +MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks); + +// returns 0 on success +MTMD_API int32_t mtmd_encode(mtmd_context * ctx, + const mtmd_image_tokens * image_tokens); + +// get output embeddings from the last encode pass +MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx); + +// +// helper functions (can be implemented based on other functions) +// + +// helper to count the total number of tokens from a list of chunks, useful to keep track of n_past +MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks); + +// helper function that automatically: +// 1. run llama_decode() on text chunks +// 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode() +// if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error +// otherwise, returns 0 on success +MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx, + llama_context * lctx, + mtmd_input_chunks * chunks, + llama_pos pos0, + llama_seq_id seq_id, + int32_t n_batch); + +// helper function to construct a mtmd_bitmap from a file +// returns 0 on success +// this function is thread-safe +MTMD_API int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output); + +// helper function to construct a mtmd_bitmap from a buffer +// the buffer must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.) +// returns 0 on success +// this function is thread-safe +MTMD_API int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output); + +// convenient unique_ptr wrappers +struct mtmd_context_deleter { + void operator()(mtmd_context * val) { mtmd_free(val); } +}; +using mtmd_context_ptr = std::unique_ptr; + +struct mtmd_input_chunks_deleter { + void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); } +}; +using mtmd_input_chunks_ptr = std::unique_ptr; + +#else + +static_assert(false && "C header is not yet supported by this library"); + +#endif + +#endif diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1bf1ee876..d87bda1a0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3907,6 +3907,21 @@ int main(int argc, char ** argv) { res_ok(res, {{ "success", true }}); }; + const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + json data = { + { + "template", common_chat_templates_source(ctx_server.chat_templates.get()), + }, + { + "model_info", { + { "llama.context_length", ctx_server.slots.back().n_ctx, }, + } + }, + }; + + res_ok(res, data); + }; + // handle completion-like requests (completion, chat, infill) // we can optionally provide a custom format for partial results and final results const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( @@ -4471,6 +4486,7 @@ int main(int argc, char ** argv) { svr->Get ("/metrics", handle_metrics); svr->Get ("/props", handle_props); svr->Post("/props", handle_props_change); + svr->Post("/api/show", handle_api_show); svr->Get ("/models", handle_models); // public endpoint (no API key check) svr->Get ("/v1/models", handle_models); // public endpoint (no API key check) svr->Post("/completion", handle_completions); // legacy diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index bccd37ad7..7c42e8eb0 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -513,17 +513,12 @@ extern "C" { GGML_OP_UNARY, - GGML_OP_MAP_UNARY, - GGML_OP_MAP_BINARY, - - GGML_OP_MAP_CUSTOM1_F32, - GGML_OP_MAP_CUSTOM2_F32, - GGML_OP_MAP_CUSTOM3_F32, - GGML_OP_MAP_CUSTOM1, GGML_OP_MAP_CUSTOM2, GGML_OP_MAP_CUSTOM3, + GGML_OP_CUSTOM, + GGML_OP_CROSS_ENTROPY_LOSS, GGML_OP_CROSS_ENTROPY_LOSS_BACK, GGML_OP_OPT_STEP_ADAMW, @@ -1735,24 +1730,29 @@ extern "C" { float p0, float p1); - // nearest interpolate + enum ggml_scale_mode { + GGML_SCALE_MODE_NEAREST = 0, + GGML_SCALE_MODE_BILINEAR = 1, + }; + + // interpolate // multiplies ne0 and ne1 by scale factor - // used in stable-diffusion GGML_API struct ggml_tensor * ggml_upscale( struct ggml_context * ctx, struct ggml_tensor * a, - int scale_factor); + int scale_factor, + enum ggml_scale_mode mode); - // nearest interpolate - // nearest interpolate to specified dimensions - // used in tortoise.cpp + // interpolate + // interpolate scale to specified dimensions GGML_API struct ggml_tensor * ggml_upscale_ext( struct ggml_context * ctx, struct ggml_tensor * a, int ne0, int ne1, int ne2, - int ne3); + int ne3, + enum ggml_scale_mode mode); // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0] GGML_API struct ggml_tensor * ggml_pad( @@ -1929,83 +1929,6 @@ extern "C" { // custom operators - typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); - typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *); - - typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *); - typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); - typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_unary_op_f32_t fun), - "use ggml_map_custom1 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_unary_op_f32_t fun), - "use ggml_map_custom1_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_binary_op_f32_t fun), - "use ggml_map_custom2 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_binary_op_f32_t fun), - "use ggml_map_custom2_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_custom1_op_f32_t fun), - "use ggml_map_custom1 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_custom1_op_f32_t fun), - "use ggml_map_custom1_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_custom2_op_f32_t fun), - "use ggml_map_custom2 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_custom2_op_f32_t fun), - "use ggml_map_custom2_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - ggml_custom3_op_f32_t fun), - "use ggml_map_custom3 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - ggml_custom3_op_f32_t fun), - "use ggml_map_custom3_inplace instead"); - - // custom operators v2 - typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata); typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata); @@ -2061,6 +1984,30 @@ extern "C" { int n_tasks, void * userdata); + typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata); + + GGML_API struct ggml_tensor * ggml_custom_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata); + // loss function GGML_API struct ggml_tensor * ggml_cross_entropy_loss( diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index ad497e4c9..e03917b29 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2041,41 +2041,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_rwkv_wkv7(params, tensor); } break; - case GGML_OP_MAP_UNARY: - { - ggml_unary_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_unary(params, tensor, fun); - } - break; - case GGML_OP_MAP_BINARY: - { - ggml_binary_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_binary(params, tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM1_F32: - { - ggml_custom1_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom1_f32(params, tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM2_F32: - { - ggml_custom2_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom2_f32(params, tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM3_F32: - { - ggml_custom3_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom3_f32(params, tensor, fun); - } - break; case GGML_OP_MAP_CUSTOM1: { ggml_compute_forward_map_custom1(params, tensor); @@ -2091,6 +2056,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm ggml_compute_forward_map_custom3(params, tensor); } break; + case GGML_OP_CUSTOM: + { + ggml_compute_forward_custom(params, tensor); + } + break; case GGML_OP_CROSS_ENTROPY_LOSS: { ggml_compute_forward_cross_entropy_loss(params, tensor); @@ -2342,11 +2312,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: case GGML_OP_GET_REL_POS: - case GGML_OP_MAP_UNARY: - case GGML_OP_MAP_BINARY: - case GGML_OP_MAP_CUSTOM1_F32: - case GGML_OP_MAP_CUSTOM2_F32: - case GGML_OP_MAP_CUSTOM3_F32: { n_tasks = 1; } break; @@ -2380,6 +2345,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = MIN(p.n_tasks, n_threads); } } break; + case GGML_OP_CUSTOM: + { + struct ggml_custom_op_params p; + memcpy(&p, node->op_params, sizeof(p)); + if (p.n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p.n_tasks, n_threads); + } + } break; case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_OPT_STEP_ADAMW: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index f63656be5..6050147be 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -6351,24 +6351,72 @@ static void ggml_compute_forward_upscale_f32( const float sf2 = (float)ne2/src0->ne[2]; const float sf3 = (float)ne3/src0->ne[3]; - // TODO: optimize + const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0); - for (int64_t i3 = 0; i3 < ne3; i3++) { - const int64_t i03 = i3 / sf3; - for (int64_t i2 = ith; i2 < ne2; i2 += nth) { - const int64_t i02 = i2 / sf2; - for (int64_t i1 = 0; i1 < ne1; i1++) { - const int64_t i01 = i1 / sf1; - for (int64_t i0 = 0; i0 < ne0; i0++) { - const int64_t i00 = i0 / sf0; + if (mode == GGML_SCALE_MODE_NEAREST) { + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3 / sf3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2 / sf2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const int64_t i01 = i1 / sf1; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const int64_t i00 = i0 / sf0; - const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); - *y = *x; + *y = *x; + } } } } + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True + const float pixel_offset = 0.5f; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3 / sf3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2 / sf2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset; + int64_t y0 = (int64_t)floorf(y); + int64_t y1 = y0 + 1; + + y0 = std::max(int64_t(0), std::min(y0, ne01 - 1)); + y1 = std::max(int64_t(0), std::min(y1, ne01 - 1)); + + float dy = y - (float)y0; + dy = std::max(0.0f, std::min(dy, 1.0f)); + + for (int64_t i0 = 0; i0 < ne0; i0++) { + const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset; + int64_t x0 = (int64_t)floorf(x); + int64_t x1 = x0 + 1; + + x0 = std::max(int64_t(0), std::min(x0, ne00 - 1)); + x1 = std::max(int64_t(0), std::min(x1, ne00 - 1)); + + float dx = x - (float)x0; + dx = std::max(0.0f, std::min(dx, 1.0f)); + + // fetch the four surrounding pixel values and interpolate + const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03); + const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03); + const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03); + const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03); + + const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy; + + float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + *y_dst = val; + } + } + } + } + } else { + GGML_ABORT("unsupported upscale mode"); } } @@ -8268,152 +8316,6 @@ void ggml_compute_forward_rwkv_wkv7( } } -// ggml_compute_forward_map_unary - -static void ggml_compute_forward_map_unary_f32( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_unary_op_f32_t fun) { - - const ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - fun(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -void ggml_compute_forward_map_unary( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_unary_op_f32_t fun) { - - const ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_map_unary_f32(params, dst, fun); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_map_binary - -static void ggml_compute_forward_map_binary_f32( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_binary_op_f32_t fun) { - - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(src1)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - fun(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - (float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - -void ggml_compute_forward_map_binary( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_binary_op_f32_t fun) { - - const ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_map_binary_f32(params, dst, fun); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_map_custom1 - -void ggml_compute_forward_map_custom1_f32( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_custom1_op_f32_t fun) { - - const ggml_tensor * a = dst->src[0]; - - if (params->ith != 0) { - return; - } - - fun(dst, a); -} - -// ggml_compute_forward_map_custom2 - -void ggml_compute_forward_map_custom2_f32( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_custom2_op_f32_t fun) { - - const ggml_tensor * a = dst->src[0]; - const ggml_tensor * b = dst->src[1]; - - if (params->ith != 0) { - return; - } - - fun(dst, a, b); -} - -// ggml_compute_forward_map_custom3 - -void ggml_compute_forward_map_custom3_f32( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_custom3_op_f32_t fun) { - - const ggml_tensor * a = dst->src[0]; - const ggml_tensor * b = dst->src[1]; - const ggml_tensor * c = dst->src[1]; - - if (params->ith != 0) { - return; - } - - fun(dst, a, b, c); -} - // ggml_compute_forward_map_custom1 void ggml_compute_forward_map_custom1( @@ -8459,6 +8361,18 @@ void ggml_compute_forward_map_custom3( p.fun(dst, a, b, c, params->ith, params->nth, p.userdata); } +// ggml_compute_forward_custom + +void ggml_compute_forward_custom( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + struct ggml_custom_op_params p; + memcpy(&p, dst->op_params, sizeof(p)); + + p.fun(dst, params->ith, params->nth, p.userdata); +} + // ggml_compute_forward_cross_entropy_loss static void ggml_compute_forward_cross_entropy_loss_f32( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index d43fbc1fc..410a37204 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -96,29 +96,10 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); -void ggml_compute_forward_map_unary( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_unary_op_f32_t fun); -void ggml_compute_forward_map_binary( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_binary_op_f32_t fun); -void ggml_compute_forward_map_custom1_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_custom1_op_f32_t fun); -void ggml_compute_forward_map_custom2_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_custom2_op_f32_t fun); -void ggml_compute_forward_map_custom3_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_custom3_op_f32_t fun); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_custom(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index d7db9209f..04d10cec2 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -855,13 +855,17 @@ static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) { tmp[i] = GGML_FP16_TO_FP32(x[i]); } - return vec_xl(0, tmp); + // note: keep type-cast here to prevent compiler bugs + // see: https://github.com/ggml-org/llama.cpp/issues/12846 + return vec_xl(0, (const float *)(tmp)); } static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) { float arr[4]; - vec_xst(y, 0, arr); + // note: keep type-cast here to prevent compiler bugs + // see: https://github.com/ggml-org/llama.cpp/issues/12846 + vec_xst(y, 0, (float *)(arr)); for (int i = 0; i < 4; i++) { x[i] = GGML_FP32_TO_FP16(arr[i]); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5f29eba74..7ec4019ab 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3221,6 +3221,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_PAD: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index caa6b9dba..a19cfb14e 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -16,6 +16,14 @@ #include #endif // __ARM_FEATURE_SVE +#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__) +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include +#endif + #if defined(__F16C__) #include #endif @@ -140,8 +148,14 @@ struct ggml_map_custom2_op_params { struct ggml_map_custom3_op_params { ggml_custom3_op_t fun; - int n_tasks; - void * userdata; + int n_tasks; + void * userdata; +}; + +struct ggml_custom_op_params { + ggml_custom_op_t fun; + int n_tasks; + void * userdata; }; // bitset @@ -311,13 +325,6 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); // for MUSA compilers , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843 // #if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__) - - // if YCM cannot find , make a symbolic link to it, for example: - // - // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ - // - #include - #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 44d83430a..f6dd5d2af 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1334,8 +1334,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_1D: return false; - case GGML_OP_POOL_2D: case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + case GGML_OP_POOL_2D: case GGML_OP_PAD: case GGML_OP_PAD_REFLECT_1D: case GGML_OP_TIMESTEP_EMBEDDING: diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 0423305bb..b36aa7a9d 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -1,4 +1,5 @@ #include "common.hpp" +#include "ggml.h" #include "element_wise.hpp" static void acc_f32(const float * x, const float * y, float * dst, const int ne, @@ -20,10 +21,11 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne, } } -static void gelu_f32(const float * x, float * dst, const int k, +template +static void gelu(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const T GELU_COEF_A = static_cast(0.044715f); + const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -32,12 +34,13 @@ static void gelu_f32(const float * x, float * dst, const int k, } float xi = x[i]; - dst[i] = 0.5f * xi * - (1.0f + - sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi))); + dst[i] = static_cast(0.5f) * xi * + (static_cast(1.0f) + + sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast(1.0f) + GELU_COEF_A * xi * xi))); } -static void silu_f32(const float * x, float * dst, const int k, +template +static void silu(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -45,10 +48,11 @@ static void silu_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i])); + dst[i] = x[i] / (static_cast(1.0f) + sycl::native::exp(-x[i])); } -static void gelu_quick_f32(const float *x, float *dst, int k, +template +static void gelu_quick(const T *x, T *dst, int k, const sycl::nd_item<3> &item_ct1) { const float GELU_QUICK_COEF = -1.702f; const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + @@ -56,20 +60,22 @@ static void gelu_quick_f32(const float *x, float *dst, int k, if (i >= k) { return; } - dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i]))); + dst[i] = x[i] * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i]))); } -static void tanh_f32(const float *x, float *dst, int k, +template +static void tanh(const T *x, T *dst, int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); if (i >= k) { return; } - dst[i] = sycl::tanh((float)(x[i])); + dst[i] = sycl::tanh((x[i])); } -static void relu_f32(const float * x, float * dst, const int k, +template +static void relu(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -77,10 +83,11 @@ static void relu_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = sycl::fmax((float)(x[i]), (float)0); + dst[i] = sycl::fmax((x[i]), static_cast(0)); } -static void sigmoid_f32(const float * x, float * dst, const int k, +template +static void sigmoid(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -88,10 +95,11 @@ static void sigmoid_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i])); + dst[i] = 1.0f / (static_cast(1.0f) + sycl::native::exp(-x[i])); } -static void sqrt_f32(const float * x, float * dst, const int k, +template +static void sqrt(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -102,7 +110,8 @@ static void sqrt_f32(const float * x, float * dst, const int k, dst[i] = sycl::sqrt(x[i]); } -static void sin_f32(const float * x, float * dst, const int k, +template +static void sin(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -113,7 +122,8 @@ static void sin_f32(const float * x, float * dst, const int k, dst[i] = sycl::sin(x[i]); } -static void cos_f32(const float * x, float * dst, const int k, +template +static void cos(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -124,7 +134,8 @@ static void cos_f32(const float * x, float * dst, const int k, dst[i] = sycl::cos(x[i]); } -static void hardsigmoid_f32(const float * x, float * dst, const int k, +template +static void hardsigmoid(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -132,10 +143,11 @@ static void hardsigmoid_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); + dst[i] = sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } -static void hardswish_f32(const float * x, float * dst, const int k, +template +static void hardswish(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -143,10 +155,11 @@ static void hardswish_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); + dst[i] = x[i] * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } -static void exp_f32(const float * x, float * dst, const int k, +template +static void exp(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -157,7 +170,8 @@ static void exp_f32(const float * x, float * dst, const int k, dst[i] = sycl::exp(x[i]); } -static void log_f32(const float * x, float * dst, const int k, +template +static void log(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -165,15 +179,16 @@ static void log_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - float xi = x[i]; + T xi = x[i]; if (xi <= 0) { - dst[i] = -INFINITY; + dst[i] = neg_infinity(); } else { dst[i] = sycl::log(xi); } } -static void neg_f32(const float * x, float * dst, const int k, +template +static void neg(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -184,7 +199,8 @@ static void neg_f32(const float * x, float * dst, const int k, dst[i] = -x[i]; } -static void step_f32(const float * x, float * dst, const int k, +template +static void step(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -192,21 +208,23 @@ static void step_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = x[i] > 0.0f; + dst[i] = x[i] > static_cast(0.0f); } -static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope, +template +static void leaky_relu(const T *x, T *dst, const int k, const float negative_slope, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); if (i >= k) { return; } - dst[i] = sycl::fmax((float)(x[i]), (float)0) + - sycl::fmin((float)(x[i]), 0.0f) * negative_slope; + dst[i] = sycl::fmax((x[i]), static_cast(0)) + + sycl::fmin((x[i]), static_cast(0.0f)) * negative_slope; } -static void sqr_f32(const float * x, float * dst, const int k, +template +static void sqr(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -217,7 +235,8 @@ static void sqr_f32(const float * x, float * dst, const int k, dst[i] = x[i] * x[i]; } -static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01, +template +static void upscale(const T *x, T *dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int ne13, const float sf0, const float sf1, const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { @@ -237,10 +256,11 @@ static void upscale_f32(const float *x, float *dst, const int nb00, const int n int i02 = i12 / sf2; int i03 = i13 / sf3; - dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); + dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); } -static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02, +template +static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02, const sycl::nd_item<3> &item_ct1) { int nidx = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); @@ -256,11 +276,23 @@ static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, item_ct1.get_group(0) * ne00 * ne01; dst[offset_dst] = x[offset_src]; } else { - dst[offset_dst] = 0.0f; + dst[offset_dst] = static_cast(0.0f); } } +template +static void clamp(const T * x, T * dst, const float min, const float max, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + + dst[i] = x[i] < static_cast(min) ? static_cast(min) : (x[i] > static_cast(max) ? static_cast(max) : x[i]); +} static void acc_f32_sycl(const float *x, const float *y, float *dst, const int n_elements, const int ne10, const int ne11, @@ -277,7 +309,8 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, }); } -static void gelu_f32_sycl(const float *x, float *dst, const int k, +template +static void gelu_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; stream->parallel_for( @@ -285,11 +318,12 @@ static void gelu_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - gelu_f32(x, dst, k, item_ct1); + gelu(x, dst, k, item_ct1); }); } -static void silu_f32_sycl(const float *x, float *dst, const int k, +template +static void silu_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; stream->parallel_for( @@ -297,11 +331,12 @@ static void silu_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - silu_f32(x, dst, k, item_ct1); + silu(x, dst, k, item_ct1); }); } -static void gelu_quick_f32_sycl(const float *x, float *dst, const int k, +template +static void gelu_quick_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; stream->parallel_for( @@ -309,11 +344,12 @@ static void gelu_quick_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - gelu_quick_f32(x, dst, k, item_ct1); + gelu_quick(x, dst, k, item_ct1); }); } -static void tanh_f32_sycl(const float *x, float *dst, const int k, +template +static void tanh_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; stream->parallel_for( @@ -321,11 +357,12 @@ static void tanh_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - tanh_f32(x, dst, k, item_ct1); + tanh(x, dst, k, item_ct1); }); } -static void relu_f32_sycl(const float *x, float *dst, const int k, +template +static void relu_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; stream->parallel_for( @@ -333,11 +370,12 @@ static void relu_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - relu_f32(x, dst, k, item_ct1); + relu(x, dst, k, item_ct1); }); } -static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k, +template +static void hardsigmoid_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; stream->parallel_for( @@ -345,11 +383,12 @@ static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - hardsigmoid_f32(x, dst, k, item_ct1); + hardsigmoid(x, dst, k, item_ct1); }); } -static void hardswish_f32_sycl(const float *x, float *dst, const int k, +template +static void hardswish_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; stream->parallel_for( @@ -357,11 +396,12 @@ static void hardswish_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - hardswish_f32(x, dst, k, item_ct1); + hardswish(x, dst, k, item_ct1); }); } -static void exp_f32_sycl(const float *x, float *dst, const int k, +template +static void exp_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; stream->parallel_for( @@ -369,11 +409,12 @@ static void exp_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - exp_f32(x, dst, k, item_ct1); + exp(x, dst, k, item_ct1); }); } -static void log_f32_sycl(const float *x, float *dst, const int k, +template +static void log_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; stream->parallel_for( @@ -381,11 +422,12 @@ static void log_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - log_f32(x, dst, k, item_ct1); + log(x, dst, k, item_ct1); }); } -static void neg_f32_sycl(const float *x, float *dst, const int k, +template +static void neg_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; stream->parallel_for( @@ -393,11 +435,12 @@ static void neg_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - neg_f32(x, dst, k, item_ct1); + neg(x, dst, k, item_ct1); }); } -static void step_f32_sycl(const float *x, float *dst, const int k, +template +static void step_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; stream->parallel_for( @@ -405,11 +448,12 @@ static void step_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - step_f32(x, dst, k, item_ct1); + step(x, dst, k, item_ct1); }); } -static void sigmoid_f32_sycl(const float *x, float *dst, const int k, +template +static void sigmoid_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE; stream->parallel_for( @@ -417,11 +461,12 @@ static void sigmoid_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - sigmoid_f32(x, dst, k, item_ct1); + sigmoid(x, dst, k, item_ct1); }); } -static void sqrt_f32_sycl(const float *x, float *dst, const int k, +template +static void sqrt_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE; stream->parallel_for( @@ -429,11 +474,12 @@ static void sqrt_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - sqrt_f32(x, dst, k, item_ct1); + sqrt(x, dst, k, item_ct1); }); } -static void sin_f32_sycl(const float *x, float *dst, const int k, +template +static void sin_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; stream->parallel_for( @@ -441,11 +487,12 @@ static void sin_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - sin_f32(x, dst, k, item_ct1); + sin(x, dst, k, item_ct1); }); } -static void cos_f32_sycl(const float *x, float *dst, const int k, +template +static void cos_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; stream->parallel_for( @@ -453,11 +500,12 @@ static void cos_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - cos_f32(x, dst, k, item_ct1); + cos(x, dst, k, item_ct1); }); } -static void leaky_relu_f32_sycl(const float *x, float *dst, const int k, +template +static void leaky_relu_sycl(const T *x, T *dst, const int k, const float negative_slope, queue_ptr stream) { const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; @@ -466,11 +514,12 @@ static void leaky_relu_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - leaky_relu_f32(x, dst, k, negative_slope, item_ct1); + leaky_relu(x, dst, k, negative_slope, item_ct1); }); } -static void sqr_f32_sycl(const float *x, float *dst, const int k, +template +static void sqr_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; stream->parallel_for( @@ -478,11 +527,12 @@ static void sqr_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - sqr_f32(x, dst, k, item_ct1); + sqr(x, dst, k, item_ct1); }); } -static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01, +template +static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int ne13, const float sf0, const float sf1, const float sf2, const float sf3, queue_ptr stream) { @@ -492,11 +542,12 @@ static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const i stream->parallel_for( sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); + upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); }); } -static void pad_f32_sycl(const float *x, float *dst, const int ne00, +template +static void pad_sycl(const T *x, T *dst, const int ne00, const int ne01, const int ne02, const int ne0, const int ne1, const int ne2, queue_ptr stream) { int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; @@ -505,252 +556,688 @@ static void pad_f32_sycl(const float *x, float *dst, const int ne00, sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1); + pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); }); } -inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - - silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); +template +static void clamp_sycl(const T *x, T *dst, const float min, + const float max, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + clamp(x, dst, min, max, k, item_ct1); + }); } -inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } +} - gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); +inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - log_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - - sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - - sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - step_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } - -inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); float negative_slope; memcpy(&negative_slope, dst->op_params, sizeof(float)); - - leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), negative_slope, main_stream); -} - -inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + #if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } +} + +inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0]; + const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1]; + const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2]; + const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3]; + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], + dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, + main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], + dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, + main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } +} + +inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], + dst->ne[1], dst->ne[2], main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], + dst->ne[1], dst->ne[2], main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } +} + +inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined(GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); + float min; + float max; + memcpy(&min, dst->op_params, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); - const float sf0 = (float)dst->ne[0]/dst->src[0]->ne[0]; - const float sf1 = (float)dst->ne[1]/dst->src[0]->ne[1]; - const float sf2 = (float)dst->ne[2]/dst->src[0]->ne[2]; - const float sf3 = (float)dst->ne[3]/dst->src[0]->ne[3]; - - upscale_f32_sycl(src0_dd, dst_dd, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], dst->src[0]->nb[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, - main_stream); -} - -inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - pad_f32_sycl(src0_dd, dst_dd, - dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], - dst->ne[0], dst->ne[1], dst->ne[2], main_stream); + switch (dst->type) { +#if defined(GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { @@ -773,6 +1260,7 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream); } + inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); @@ -795,126 +1283,133 @@ inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_sqrt(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_sin(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_cos(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_acc(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_gelu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_silu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_gelu_quick(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_tanh(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_relu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_sigmoid(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_hardsigmoid(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_hardswish(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_exp(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_log(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_neg(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_step(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_leaky_relu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_sqr(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_upscale(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_pad(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } +void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + ggml_sycl_op_clamp(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + + void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 464432645..76d91ba10 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -2,6 +2,13 @@ #define GGML_SYCL_ELEMENTWISE_HPP #include "common.hpp" +#include "ggml.h" +#include + +template +T neg_infinity() { + return -std::numeric_limits::infinity(); +} static __dpct_inline__ float op_repeat(const float a, const float b) { return b; @@ -24,6 +31,19 @@ static __dpct_inline__ float op_div(const float a, const float b) { return a / b; } +template +struct typed_data { + const T * src; + T * dst; +}; + +template +typed_data cast_data(ggml_tensor * dst) { + return { + /* .src = */ static_cast(dst->src[0]->data), + /* .dst = */ static_cast(dst->data) + }; +} void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst); @@ -65,6 +85,10 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +// --------- + void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 89715eaea..3e48a9244 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1617,17 +1617,6 @@ static void scale_f32(const float * x, float * dst, const float scale, const int dst[i] = scale * x[i]; } -static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - - dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); -} template static void pool2d_nchw_kernel( @@ -1768,18 +1757,6 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale, }); } -static void clamp_f32_sycl(const float *x, float *dst, const float min, - const float max, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - clamp_f32(x, dst, min, max, k, item_ct1); - }); -} static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, const int nrows, queue_ptr stream) { @@ -2258,26 +2235,6 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst SYCL_CHECK(0); } -inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - float min; - float max; - memcpy(&min, dst->op_params, sizeof(float)); - memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); - - clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(dst->src[0]), ctx.stream()); - /* - DPCT1010:88: SYCL uses exceptions to report errors and does not use the - error codes. The call was replaced with 0. You need to rewrite this code. - */ - SYCL_CHECK(0); -} - static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) { static bool peer_access_enabled = false; @@ -3218,10 +3175,6 @@ static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) ggml_sycl_op_scale(ctx, dst); } -static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_clamp(ctx, dst); -} - static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_diag_mask_inf(ctx, dst); } @@ -3700,7 +3653,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_ #ifdef GGML_SYCL_GRAPH if (!g_ggml_sycl_disable_graph) { - if (!sycl_ctx->exec_graph && !dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph)) { + const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph); + if (!graph_support) { GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device); ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); return GGML_STATUS_SUCCESS; @@ -3711,8 +3665,10 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_ ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); model_sycl_graph.end_recording(); - if (!sycl_ctx->exec_graph) { - auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}}); + const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph); + if (!sycl_ctx->exec_graph || !graph_update_support) { + auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) : + model_sycl_graph.finalize(); sycl_ctx->exec_graph = std::make_unique< sycl_ex::command_graph>(exec_graph); } else { @@ -3900,7 +3856,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: - return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32); +#if defined (GGML_SYCL_F16) + return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type); +#else + return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); +#endif default: return false; } @@ -4022,13 +3982,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32); case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_LOG: - return (op->src[0]->type == GGML_TYPE_F32); +#if defined (GGML_SYCL_F16) + return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type)); +#else + return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); +#endif case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: @@ -4055,12 +4020,13 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_IM2COL: // TODO: add support for the new F32 operations return op->src[0]->type == GGML_TYPE_F16; + case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_POOL_2D: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: - case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index bac96d7e3..755bb51ca 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5766,7 +5766,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_UPSCALE: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) { return ctx->device->pipeline_upscale_f32; } return nullptr; @@ -9421,9 +9421,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_COS: case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_UPSCALE: + return op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_ACC: case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_PAD: case GGML_OP_DIAG_MASK_INF: @@ -9791,7 +9792,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_CONCAT) { tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]); } else if (tensor->op == GGML_OP_SCALE) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d44efeb32..564ee84ca 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -995,23 +995,18 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "UNARY", - "MAP_UNARY", - "MAP_BINARY", - - "MAP_CUSTOM1_F32", - "MAP_CUSTOM2_F32", - "MAP_CUSTOM3_F32", - "MAP_CUSTOM1", "MAP_CUSTOM2", "MAP_CUSTOM3", + "CUSTOM", + "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1094,23 +1089,18 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "unary(x)", - "f(x)", - "f(x,y)", - - "custom_f32(x)", - "custom_f32(x,y)", - "custom_f32(x,y,z)", + "map_custom(x)", + "map_custom(x,y)", + "map_custom(x,y,z)", "custom(x)", - "custom(x,y)", - "custom(x,y,z)", "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", "adamw(x)", }; -static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4197,7 +4187,8 @@ static struct ggml_tensor * ggml_upscale_impl( int ne0, int ne1, int ne2, - int ne3) { + int ne3, + enum ggml_scale_mode mode) { GGML_ASSERT(a->ne[0] <= ne0); GGML_ASSERT(a->ne[1] <= ne1); GGML_ASSERT(a->ne[2] <= ne2); @@ -4205,6 +4196,8 @@ static struct ggml_tensor * ggml_upscale_impl( struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + ggml_set_op_params_i32(result, 0, mode); + result->op = GGML_OP_UPSCALE; result->src[0] = a; @@ -4214,8 +4207,9 @@ static struct ggml_tensor * ggml_upscale_impl( struct ggml_tensor * ggml_upscale( struct ggml_context * ctx, struct ggml_tensor * a, - int scale_factor) { - return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]); + int scale_factor, + enum ggml_scale_mode mode) { + return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode); } struct ggml_tensor * ggml_upscale_ext( @@ -4224,8 +4218,9 @@ struct ggml_tensor * ggml_upscale_ext( int ne0, int ne1, int ne2, - int ne3) { - return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3); + int ne3, + enum ggml_scale_mode mode) { + return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode); } // ggml_pad @@ -4855,179 +4850,6 @@ struct ggml_tensor * ggml_unary_inplace( return ggml_unary_impl(ctx, a, op, true); } -// ggml_map_unary - -static struct ggml_tensor * ggml_map_unary_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_unary_op_f32_t fun, - bool inplace) { - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_UNARY; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_map_unary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_unary_op_f32_t fun) { - return ggml_map_unary_impl_f32(ctx, a, fun, false); -} - -struct ggml_tensor * ggml_map_unary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_unary_op_f32_t fun) { - return ggml_map_unary_impl_f32(ctx, a, fun, true); -} - -// ggml_map_binary - -static struct ggml_tensor * ggml_map_binary_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_binary_op_f32_t fun, - bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_BINARY; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_map_binary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_binary_op_f32_t fun) { - return ggml_map_binary_impl_f32(ctx, a, b, fun, false); -} - -struct ggml_tensor * ggml_map_binary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_binary_op_f32_t fun) { - return ggml_map_binary_impl_f32(ctx, a, b, fun, true); -} - -// ggml_map_custom1_f32 - -static struct ggml_tensor * ggml_map_custom1_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_f32_t fun, - bool inplace) { - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_CUSTOM1_F32; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_map_custom1_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_f32_t fun) { - return ggml_map_custom1_impl_f32(ctx, a, fun, false); -} - -struct ggml_tensor * ggml_map_custom1_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_f32_t fun) { - return ggml_map_custom1_impl_f32(ctx, a, fun, true); -} - -// ggml_map_custom2_f32 - -static struct ggml_tensor * ggml_map_custom2_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_f32_t fun, - bool inplace) { - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_CUSTOM2_F32; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_map_custom2_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_f32_t fun) { - return ggml_map_custom2_impl_f32(ctx, a, b, fun, false); -} - -struct ggml_tensor * ggml_map_custom2_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_f32_t fun) { - return ggml_map_custom2_impl_f32(ctx, a, b, fun, true); -} - -// ggml_map_custom3_f32 - -static struct ggml_tensor * ggml_map_custom3_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_f32_t fun, - bool inplace) { - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_CUSTOM3_F32; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - - return result; -} - -struct ggml_tensor * ggml_map_custom3_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_f32_t fun) { - return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false); -} - -struct ggml_tensor * ggml_map_custom3_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_f32_t fun) { - return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true); -} - // ggml_map_custom1 static struct ggml_tensor * ggml_map_custom1_impl( @@ -5046,7 +4868,7 @@ static struct ggml_tensor * ggml_map_custom1_impl( /*.n_tasks =*/ n_tasks, /*.userdata =*/ userdata }; - ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_MAP_CUSTOM1; result->src[0] = a; @@ -5091,7 +4913,7 @@ static struct ggml_tensor * ggml_map_custom2_impl( /*.n_tasks =*/ n_tasks, /*.userdata =*/ userdata }; - ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_MAP_CUSTOM2; result->src[0] = a; @@ -5140,7 +4962,7 @@ static struct ggml_tensor * ggml_map_custom3_impl( /*.n_tasks =*/ n_tasks, /*.userdata =*/ userdata }; - ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_MAP_CUSTOM3; result->src[0] = a; @@ -5172,6 +4994,66 @@ struct ggml_tensor * ggml_map_custom3_inplace( return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); } +struct ggml_tensor * ggml_custom_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata) { + + GGML_ASSERT(n_args < GGML_MAX_SRC); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3); + + struct ggml_custom_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, ¶ms, sizeof(params)); + + result->op = GGML_OP_CUSTOM; + for (int i = 0; i < n_args; i++) { + result->src[i] = args[i]; + } + + return result; +} + +struct ggml_tensor * ggml_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata) { + + GGML_ASSERT(n_args < GGML_MAX_SRC - 1); + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + struct ggml_custom_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, ¶ms, sizeof(params)); + + result->op = GGML_OP_CUSTOM; + result->src[0] = a; + for (int i = 0; i < n_args; i++) { + result->src[i + 1] = args[i]; + } + + return result; +} // ggml_cross_entropy_loss struct ggml_tensor * ggml_cross_entropy_loss( diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0410654dd..162070e6e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -280,6 +280,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK = auto() DEEPSEEK2 = auto() CHATGLM = auto() + GLM4 = auto() BITNET = auto() T5 = auto() T5ENCODER = auto() @@ -487,6 +488,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.DEEPSEEK: "deepseek", MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.CHATGLM: "chatglm", + MODEL_ARCH.GLM4: "glm4", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", @@ -1561,6 +1563,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.GLM4 : [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_POST_NORM, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 50bef12e3..35154e9b5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -13,7 +13,7 @@ class TensorNameMap: "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.word_embeddings", # falcon "word_embeddings", # bloom - "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 + "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -30,6 +30,7 @@ class TensorNameMap: "rwkv.embeddings", # rwkv6 "model.embeddings", # rwkv7 "model.word_embeddings", # bailingmoe + "language_model.model.embed_tokens", # llama4 ), # Token type embeddings @@ -67,6 +68,7 @@ class TensorNameMap: "output_layer", # chatglm "head", # rwkv "head.out", # wavtokenizer + "language_model.lm_head", # llama4 ), # Output norm @@ -89,6 +91,7 @@ class TensorNameMap: "rwkv.ln_out", # rwkv6 "model.ln_out", # rwkv7 "backbone.final_layer_norm", # wavtokenizer + "language_model.model.norm", # llama4 ), # Rope frequencies @@ -130,6 +133,7 @@ class TensorNameMap: "transformer.layers.{bid}.attn_norm", # openelm "rwkv.blocks.{bid}.ln1", # rwkv6 "model.layers.{bid}.ln1", # rwkv7 + "language_model.model.layers.{bid}.input_layernorm", # llama4 ), # Attention norm 2 @@ -169,6 +173,7 @@ class TensorNameMap: "model.layers.{bid}.attention.wq", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok "transformer.h.{bid}.attn.attention.q_proj", # exaone + "language_model.model.layers.{bid}.self_attn.q_proj", # llama4 ), # Attention key @@ -183,6 +188,7 @@ class TensorNameMap: "model.layers.{bid}.attention.wk", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok "transformer.h.{bid}.attn.attention.k_proj", # exaone + "language_model.model.layers.{bid}.self_attn.k_proj", # llama4 ), # Attention value @@ -196,6 +202,7 @@ class TensorNameMap: "model.layers.{bid}.attention.wv", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok "transformer.h.{bid}.attn.attention.v_proj", # exaone + "language_model.model.layers.{bid}.self_attn.v_proj", # llama4 ), # Attention output @@ -222,6 +229,7 @@ class TensorNameMap: "encoder.layers.{bid}.self_attention.dense", # chatglm "transformer.layers.{bid}.attn.out_proj", # openelm "transformer.h.{bid}.attn.attention.out_proj", # exaone + "language_model.model.layers.{bid}.self_attn.o_proj", # llama4 ), # Attention output norm @@ -233,7 +241,8 @@ class TensorNameMap: ), MODEL_TENSOR.ATTN_POST_NORM: ( - "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 + "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge + "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 ), # Rotary embeddings @@ -259,6 +268,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm + "language_model.model.layers.{bid}.post_attention_layernorm", # llama4 ), # Post feed-forward norm @@ -269,6 +279,7 @@ class TensorNameMap: # Post feed-forward norm MODEL_TENSOR.FFN_POST_NORM: ( "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 + "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -278,6 +289,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe + "language_model.model.layers.{bid}.feed_forward.router", # llama4 ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -306,7 +318,7 @@ class TensorNameMap: "h.{bid}.mlp.c_fc", # gpt2 "transformer.h.{bid}.mlp.fc1", # phi2 "model.layers.{bid}.mlp.fc1", # phi2 - "model.layers.{bid}.mlp.gate_up_proj", # phi3 + "model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414 "model.layers.layers.{bid}.mlp.up_proj", # plamo "model.layers.{bid}.feed_forward.w3", # internlm2 "encoder.layers.{bid}.mlp.fc11", # nomic-bert @@ -315,6 +327,7 @@ class TensorNameMap: "model.layers.{bid}.residual_mlp.w3", # arctic "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone + "language_model.model.layers.{bid}.feed_forward.up_proj", # llama4 ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -323,11 +336,13 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) + "language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4 ), MODEL_TENSOR.FFN_UP_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 + "language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 ), # AWQ-activation gate @@ -348,6 +363,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic "transformer.h.{bid}.mlp.c_fc_0", # exaone + "language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4 ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -356,11 +372,13 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) + "language_model.model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 ), MODEL_TENSOR.FFN_GATE_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 + "language_model.model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 ), # Feed-forward down @@ -389,6 +407,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "model.layers.h.{bid}.mlp.c_proj", # exaone + "language_model.model.layers.{bid}.feed_forward.down_proj", # llama4 ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -398,11 +417,13 @@ class TensorNameMap: "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) + "language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4 ), MODEL_TENSOR.FFN_DOWN_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 + "language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 ), MODEL_TENSOR.ATTN_Q_NORM: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 264f8c5b9..a6fddc7fd 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -54,6 +54,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ -1152,6 +1153,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, + { + LLM_ARCH_GLM4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_BITNET, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 201935281..2c2099b3c 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -58,6 +58,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, + LLM_ARCH_GLM4, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, @@ -256,6 +257,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_POST_ATTN_NORM, + LLM_TENSOR_POST_MLP_NORM, LLM_TENSOR_SSM_IN, LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_X, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a7fdb2923..1194d5f46 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1210,6 +1210,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_9B; break; + case 61: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -3571,6 +3580,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); } } break; + case LLM_ARCH_GLM4: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4543,8 +4591,8 @@ struct llm_build_llama : public llm_graph_context { if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) { // Llama4TextL2Norm - Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6); - Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6); + Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); cb(Qcur, "Qcur_normed", il); cb(Kcur, "Kcur_normed", il); } @@ -10957,6 +11005,157 @@ struct llm_build_chatglm : public llm_graph_context { } }; +struct llm_build_glm4 : public llm_graph_context { + llm_build_glm4(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv == nullptr) { + Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + } else { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Post-attention norm (new!) + cur = build_norm(cur, + model.layers[il].attn_post_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Add the input (residual connection after post-attention norm) + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + // Pre-MLP norm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // MLP + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // Post-MLP norm + cur = build_norm(cur, + model.layers[il].ffn_post_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "post_mlp_norm", il); + } + + // Add residual connection after post-MLP norm + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + // Final norm + cur = build_norm(inpL, + model.output_norm, + NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // Output projection + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -12838,6 +13037,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_GLM4: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_BITNET: { llm = std::make_unique(*this, params, gf); @@ -13035,6 +13238,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: + case LLM_ARCH_GLM4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_CHAMELEON: diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index bd2180b51..5d931e88c 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1807,6 +1807,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_PORO; clean_spaces = false; } else if ( + tokenizer_pre == "glm4" || tokenizer_pre == "chatglm-bpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; special_bos_id = LLAMA_TOKEN_NULL;