diff --git a/common/arg.cpp b/common/arg.cpp index a28a5261a..adad49eac 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -219,6 +219,53 @@ struct common_hf_file_res { std::string mmprojFile; }; +static void write_etag(const std::string & path, const std::string & etag) { + const std::string etag_path = path + ".etag"; + write_file(etag_path, etag); + LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str()); +} + +static std::string read_etag(const std::string & path) { + std::string none; + const std::string etag_path = path + ".etag"; + + if (std::filesystem::exists(etag_path)) { + std::ifstream etag_in(etag_path); + if (!etag_in) { + LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str()); + return none; + } + std::string etag; + std::getline(etag_in, etag); + return etag; + } + + // no etag file, but maybe there is an old .json + // remove this code later + const std::string metadata_path = path + ".json"; + + if (std::filesystem::exists(metadata_path)) { + std::ifstream metadata_in(metadata_path); + try { + nlohmann::json metadata_json; + metadata_in >> metadata_json; + LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), + metadata_json.dump().c_str()); + if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) { + std::string etag = metadata_json.at("etag"); + write_etag(path, etag); + if (!std::filesystem::remove(metadata_path)) { + LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str()); + } + return etag; + } + } catch (const nlohmann::json::exception & e) { + LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + } + } + return none; +} + #ifdef LLAMA_USE_CURL bool common_has_curl() { @@ -375,36 +422,15 @@ static bool common_download_head(CURL * curl, static bool common_download_file_single_online(const std::string & url, const std::string & path, const std::string & bearer_token) { - // If the file exists, check its JSON metadata companion file. - std::string metadata_path = path + ".json"; static const int max_attempts = 3; static const int retry_delay_seconds = 2; for (int i = 0; i < max_attempts; ++i) { - nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead - std::string etag; - std::string last_modified; + std::string etag; // Check if the file already exists locally const auto file_exists = std::filesystem::exists(path); if (file_exists) { - // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). 
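
The hunks above replace the per-model `<path>.json` metadata companion with a plain `<path>.etag` sidecar that stores only the server ETag (a single line, no JSON), with read_etag() transparently migrating any legacy `.json` file it still finds. For reference, the sidecar convention on its own reduces to roughly the following standalone C++17 sketch; the helper names here are illustrative, not the patch's API:

#include <filesystem>
#include <fstream>
#include <string>

// Save the ETag next to the downloaded model, e.g. "model.gguf" -> "model.gguf.etag".
static void save_etag_sidecar(const std::string & model_path, const std::string & etag) {
    std::ofstream out(model_path + ".etag", std::ios::trunc);
    out << etag; // one line, nothing else
}

// Return the cached ETag, or an empty string when no sidecar exists yet.
static std::string load_etag_sidecar(const std::string & model_path) {
    const std::string sidecar = model_path + ".etag";
    if (!std::filesystem::exists(sidecar)) {
        return ""; // caller treats the local file as unverified
    }
    std::ifstream in(sidecar);
    std::string etag;
    std::getline(in, etag);
    return etag;
}
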
- std::ifstream metadata_in(metadata_path); - if (metadata_in.good()) { - try { - metadata_in >> metadata; - LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), - metadata.dump().c_str()); - if (metadata.contains("etag") && metadata.at("etag").is_string()) { - etag = metadata.at("etag"); - } - if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { - last_modified = metadata.at("lastModified"); - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); - } - } - // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again) + etag = read_etag(path); } else { LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); } @@ -442,11 +468,6 @@ static bool common_download_file_single_online(const std::string & url, headers.etag.c_str()); should_download = true; should_download_from_scratch = true; - } else if (!last_modified.empty() && last_modified != headers.last_modified) { - LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, - last_modified.c_str(), headers.last_modified.c_str()); - should_download = true; - should_download_from_scratch = true; } } @@ -477,15 +498,9 @@ static bool common_download_file_single_online(const std::string & url, } } } - - // Write the updated JSON metadata file. - metadata.update({ - { "url", url }, - { "etag", headers.etag }, - { "lastModified", headers.last_modified } - }); - write_file(metadata_path, metadata.dump(4)); - LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); + if (head_request_ok) { + write_etag(path, headers.etag); + } // start the download LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", @@ -668,51 +683,6 @@ static void print_progress(size_t current, size_t total) { // TODO isatty std::cout.flush(); } -struct common_file_metadata { - std::string etag; - std::string last_modified; -}; - -static std::optional read_metadata(const std::string & path) { - if (!std::filesystem::exists(path)) { - return std::nullopt; - } - - nlohmann::json metadata_json; - common_file_metadata metadata; - - std::ifstream metadata_in(path); - try { - metadata_in >> metadata_json; - LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, path.c_str(), - metadata_json.dump().c_str()); - if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) { - metadata.etag = metadata_json.at("etag"); - } - if (metadata_json.contains("lastModified") && metadata_json.at("lastModified").is_string()) { - metadata.last_modified = metadata_json.at("lastModified"); - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, path.c_str(), e.what()); - return std::nullopt; - } - - return metadata; -} - -static void write_metadata(const std::string & path, - const std::string & url, - const common_file_metadata & metadata) { - nlohmann::json metadata_json = { - { "url", url }, - { "etag", metadata.etag }, - { "lastModified", metadata.last_modified } - }; - - write_file(path, metadata_json.dump(4)); - LOG_DBG("%s: file metadata saved: %s\n", __func__, path.c_str()); -} - static bool common_pull_file(httplib::Client & cli, const std::string & resolve_path, const std::string & path_tmp, @@ -779,8 +749,6 @@ static bool 
common_pull_file(httplib::Client & cli, static bool common_download_file_single_online(const std::string & url, const std::string & path, const std::string & bearer_token) { - // If the file exists, check its JSON metadata companion file. - std::string metadata_path = path + ".json"; static const int max_attempts = 3; static const int retry_delay_seconds = 2; @@ -792,12 +760,11 @@ static bool common_download_file_single_online(const std::string & url, } cli.set_default_headers(default_headers); - common_file_metadata last; const bool file_exists = std::filesystem::exists(path); + + std::string last_etag; if (file_exists) { - if (auto opt = read_metadata(metadata_path)) { - last = *opt; - } + last_etag = read_etag(path); } else { LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); } @@ -813,14 +780,9 @@ static bool common_download_file_single_online(const std::string & url, } } - common_file_metadata current; - if (head_ok) { - if (head->has_header("ETag")) { - current.etag = head->get_header_value("ETag"); - } - if (head->has_header("Last-Modified")) { - current.last_modified = head->get_header_value("Last-Modified"); - } + std::string etag; + if (head_ok && head->has_header("ETag")) { + etag = head->get_header_value("ETag"); } size_t total_size = 0; @@ -838,16 +800,10 @@ static bool common_download_file_single_online(const std::string & url, } bool should_download_from_scratch = false; - if (head_ok) { - if (!last.etag.empty() && last.etag != current.etag) { - LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, - last.etag.c_str(), current.etag.c_str()); - should_download_from_scratch = true; - } else if (!last.last_modified.empty() && last.last_modified != current.last_modified) { - LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, - last.last_modified.c_str(), current.last_modified.c_str()); - should_download_from_scratch = true; - } + if (!last_etag.empty() && !etag.empty() && last_etag != etag) { + LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, + last_etag.c_str(), etag.c_str()); + should_download_from_scratch = true; } if (file_exists) { @@ -875,9 +831,8 @@ static bool common_download_file_single_online(const std::string & url, } // start the download - LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", - __func__, show_masked_url(parts).c_str(), path_temporary.c_str(), - current.etag.c_str(), current.last_modified.c_str()); + LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n", + __func__, show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size); if (!was_pull_successful) { if (i + 1 < max_attempts) { @@ -887,7 +842,6 @@ static bool common_download_file_single_online(const std::string & url, } else { LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts); } - continue; } @@ -895,7 +849,9 @@ static bool common_download_file_single_online(const std::string & url, LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); return false; } - write_metadata(metadata_path, url, current); + if (!etag.empty()) { + write_etag(path, etag); + } break; } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 1b763a628..746f43966 100644 --- a/ggml/src/ggml-cuda/cpy.cu 
+++ b/ggml/src/ggml-cuda/cpy.cu @@ -329,7 +329,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY { - CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + if (src0->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else { + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + } } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); @@ -400,7 +404,13 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - return nullptr; + // Prioritize CUDA graph compatibility over direct memory copy optimization. + // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs. + if (src0->type == GGML_TYPE_F32) { + return (void*) cpy_flt>; + } else { + return nullptr; + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3efed50d4..979bf8ead 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2654,6 +2654,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased"; const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased"; const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; + const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out"; + const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d"; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2682,7 +2684,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && - strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) { + strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && + strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && + strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation // by means of matching node names. 
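
The hunk above extends the allow-list used when deciding whether CUDA graphs stay enabled for batch size > 1: qualifying nodes are recognized by name prefix, so the two new prefixes must agree with the cb(...) names added in the llama-model.cpp hunks further down. Only the prefix-matching piece is sketched below, assuming the prefix list introduced by this patch:

#include <cstring>
#include <string>
#include <vector>

// Returns true when a node name starts with one of the whitelisted prefixes,
// i.e. this node alone does not force CUDA graphs off for batched decoding.
static bool node_name_is_allowed(const char * node_name) {
    static const std::vector<std::string> allowed_prefixes = {
        "ffn_moe_gate_biased", "ffn_moe_up_biased", "ffn_moe_down_biased",
        "nemotron_h_block_out", "mamba2_y_add_d",
    };
    for (const auto & prefix : allowed_prefixes) {
        if (std::strncmp(node_name, prefix.c_str(), prefix.size()) == 0) {
            return true;
        }
    }
    return false;
}
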
See // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 0bf7fe9f9..819f31c8a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -495,22 +495,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_ case GGML_TYPE_F16: case GGML_TYPE_BF16: { - if (ne00 == 4) { + if (ne00 < 32) { nsg = 1; nr0 = 32; - nr1 = 4; - suffix = "_c4"; - } else if (ne00 % 4 == 0) { - nsg = N_SG_F; - nr0 = N_R0_F; nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; - suffix = "_4"; + suffix = "_short"; } else { - nsg = N_SG_F; - nr0 = N_R0_F; + nsg = std::min(4, (ne00 + 127) / 128); + nr0 = 2; nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; + smem = 32*sizeof(float)*nr0; + suffix = ne00 % 4 == 0 ? "_4" : ""; } } break; case GGML_TYPE_Q4_0: @@ -727,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra case GGML_TYPE_F16: case GGML_TYPE_BF16: { - if (ne00 % 4 == 0) { - nsg = N_SG_F; - nr0 = N_R0_F; - nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; - suffix = "_4"; - } else { - nsg = N_SG_F; - nr0 = N_R0_F; - nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; - } + nsg = std::min(4, (ne00 + 127) / 128); + nr0 = 2; + nr1 = 1; + smem = 32*sizeof(float)*nr0; + suffix = ne00 % 4 == 0 ? "_4" : ""; } break; case GGML_TYPE_Q4_0: { diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index d355c6dfc..88c98423e 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -8,9 +8,6 @@ // // TODO: for optimal performance, become function of the device and work size -#define N_R0_F 2 -#define N_SG_F 4 - #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 @@ -352,6 +349,7 @@ typedef struct { uint64_t nb13; int32_t ne0; int32_t ne1; + int32_t nr0; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mv; @@ -427,6 +425,7 @@ typedef struct { int32_t ne0; int32_t ne1; uint64_t nb1; + int32_t nr0; } ggml_metal_kargs_mul_mv_id; // NORM diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index d7267a6ae..e85a223c0 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1565,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { } else { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + ggml_metal_kargs_mul_mv args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -1582,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { /*.nb13 =*/ nb13, /*.ne0 =*/ ne0, /*.ne1 =*/ ne1, + /*.nr0 =*/ nr0, /*.r2 =*/ r2, /*.r3 =*/ r3, }; - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); - - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); - ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); @@ -1758,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { 
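
In the Metal backend, the F32/F16/BF16 mat-vec path now derives its pipeline variant and threadgroup memory from the row length ne00 at run time instead of the fixed N_R0_F/N_SG_F constants. The selection boils down to the following host-side sketch, a simplified rendering of ggml_metal_library_get_pipeline_mul_mv() as changed by this patch:

#include <algorithm>
#include <cstddef>
#include <string>

struct mv_pipeline_choice {
    int         nsg;    // simdgroups per threadgroup
    int         nr0;    // src0 rows handled per simdgroup (now passed to the kernel as args.nr0)
    int         nr1;
    size_t      smem;   // threadgroup memory for the cross-simdgroup reduction
    std::string suffix; // kernel name suffix: "", "_4" or "_short"
};

static mv_pipeline_choice choose_mul_mv_pipeline(int ne00) {
    if (ne00 < 32) {
        // very short rows: one simdgroup, one row per thread, no reduction memory needed
        return { 1, 32, 1, 0, "_short" };
    }
    mv_pipeline_choice c;
    c.nsg    = std::min(4, (ne00 + 127) / 128); // more simdgroups for longer rows, capped at 4
    c.nr0    = 2;
    c.nr1    = 1;
    c.smem   = 32 * sizeof(float) * c.nr0;      // 32 lanes x nr0 partial sums
    c.suffix = ne00 % 4 == 0 ? "_4" : "";       // vectorized variant when ne00 is a multiple of 4
    return c;
}
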
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); } } else { + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); + + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + ggml_metal_kargs_mul_mv_id args = { /*.nei0 =*/ ne20, /*.nei1 =*/ ne21, @@ -1778,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { /*.ne0 =*/ ne0, /*.ne1 =*/ ne1, /*.nb1 =*/ nb1, + /*.nr0 =*/ nr0, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); - - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); - - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); - if (ggml_is_quantized(op->src[0]->type)) { GGML_ASSERT(ne00 >= nsg*nr0); } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0271fd5b2..96df6f0ce 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3531,7 +3531,25 @@ void kernel_mul_mv_t_t_impl( helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } -template +template +void kernel_mul_mv_t_t_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + } +} + +template kernel void kernel_mul_mv_t_t( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3541,17 +3559,17 @@ kernel void kernel_mul_mv_t_t( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_t_t_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(kernel_mul_mv_t_t) mul_mv_t_t; +typedef decltype(kernel_mul_mv_t_t) mul_mv_t_t; -template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; #endif template @@ -3637,7 +3655,25 @@ void kernel_mul_mv_t_t_4_impl( 
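
The kernels themselves keep nr0 as a template parameter; the new kernel_mul_mv_t_t_disp wrappers simply switch on the runtime args.nr0 and forward to the matching instantiation (only nr0 == 2 is compiled in this patch, the other cases are left commented out). The same runtime-to-compile-time dispatch pattern, reduced to plain C++ for illustration:

#include <cstdio>

template <int NR0>
static void mul_mv_impl(int nrows) {
    // stand-in for kernel_mul_mv_t_t_impl<..., NR0>: NR0 is a compile-time tile size
    std::printf("%d rows, %d per simdgroup\n", nrows, NR0);
}

static void mul_mv_dispatch(int nr0, int nrows) {
    switch (nr0) {
        case 2: mul_mv_impl<2>(nrows); break; // the only instantiation enabled by the patch
        default: break;                       // other tile sizes would be added here
    }
}
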
helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } -template +template +void kernel_mul_mv_t_t_4_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + }; +} + +template kernel void kernel_mul_mv_t_t_4( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3647,23 +3683,21 @@ kernel void kernel_mul_mv_t_t_4( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_t_t_4_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(kernel_mul_mv_t_t_4) mul_mv_t_t_4; +typedef decltype(kernel_mul_mv_t_t_4) mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; #endif -#define N_MV_T_T 4 - -template -void kernel_mul_mv_c4_impl( +template +void kernel_mul_mv_t_t_short_impl( args_t args, device const char * src0, device const char * src1, @@ -3671,7 +3705,7 @@ void kernel_mul_mv_c4_impl( uint3 tgpig, ushort tiisg) { const int r0 = tgpig.x*32 + tiisg; - const int rb = tgpig.y*N_MV_T_T; + const int r1 = tgpig.y; const int im = tgpig.z; if (r0 >= args.ne01) { @@ -3683,33 +3717,32 @@ void kernel_mul_mv_c4_impl( const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - device const T04 * x = (device const T04 *) (src0 + offset0); + device const T0 * x = (device const T0 *) (src0 + offset0); device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; - } + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + device const T1 * y = (device const T1 *) (src1 + offset1); - device const T14 * y = (device const T14 *) (src1 + offset1); + float res = 0.0f; - dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]); + for (int i = 0; i < args.ne00; ++i) { + res += (float) x[i] * (float) 
y[i]; } + + dst_f32[(uint64_t)r1*args.ne0 + r0] = res; } -template -kernel void kernel_mul_mv_c4( +template +kernel void kernel_mul_mv_t_t_short( constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_c4_impl( + kernel_mul_mv_t_t_short_impl( args, src0, src1, @@ -3718,14 +3751,14 @@ kernel void kernel_mul_mv_c4( tiisg); } -typedef decltype(kernel_mul_mv_c4) mul_mv_c4_t; +typedef decltype(kernel_mul_mv_t_t_short) mul_mv_t_t_short_t; -template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_f16_f16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; #endif static float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -8458,7 +8491,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_m // matrix-vector multiplication // -typedef void (kernel_mul_mv_impl_t)( +typedef void (kernel_mul_mv_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -8466,7 +8499,7 @@ typedef void (kernel_mul_mv_impl_t)( uint3 tgpig, ushort tiisg); -typedef void (kernel_mul_mv2_impl_t)( +typedef void (kernel_mul_mv2_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -8476,7 +8509,7 @@ typedef void (kernel_mul_mv2_impl_t)( ushort tiisg, ushort sgitg); -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -8487,10 +8520,10 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, tgpig, tiisg); + disp_fn(args, src0, src1, dst, tgpig, tiisg); } -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -8501,12 +8534,12 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_disp_fn_t; -template +template kernel void kernel_mul_mv_id( constant ggml_metal_kargs_mul_mv_id & args, device const char * src0s, @@ -8553,11 +8586,12 @@ kernel void kernel_mul_mv_id( /*.nb13 =*/ args.nb12, // ne12 == 1 /*.ne0 =*/ args.ne0, /*.ne1 =*/ 1, // args.ne1, + /*.nr0 =*/ args.nr0, /*.r2 =*/ 1, /*.r3 =*/ 1, }; - impl_fn( + disp_fn( args0, /* src0 */ src0_cur, /* src1 */ src1_cur, @@ -8569,19 +8603,19 @@ kernel void kernel_mul_mv_id( sgitg); } -typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; -typedef 
decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_4_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_4_t; -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #endif -template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; #endif template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3c5f90724..dc90b5f9c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11813,6 +11813,7 @@ struct llm_graph_context_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + cb(y, "mamba2_y_add_d", il); y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // grouped RMS norm @@ -14767,6 +14768,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba { ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); + ggml_build_forward_expand(gf, inpL); auto * inp = build_inp_mem_hybrid(); @@ -14798,7 +14800,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba { // add residual cur = ggml_add(ctx0, cur, inpSA); - cb(cur, "block_out", il); + cb(cur, "nemotron_h_block_out", il); // input for next layer inpL = cur; diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index c1e6841d3..4575db67a 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/src/lib/utils/thinking.ts b/tools/server/webui/src/lib/utils/thinking.ts index 11ce87123..bed13fcec 100644 --- a/tools/server/webui/src/lib/utils/thinking.ts +++ b/tools/server/webui/src/lib/utils/thinking.ts @@ -1,7 +1,8 @@ /** - * Parses thinking content from a message that may contain tags + * Parses thinking content from a message that may contain tags or [THINK] tags * Returns an object with thinking content and cleaned message content - * Handles both complete ... blocks and incomplete blocks (streaming) + * Handles both complete blocks and incomplete blocks (streaming) + * Supports formats: ... 
and [THINK]...[/THINK]
  * @param content - The message content to parse
  * @returns An object containing the extracted thinking content and the cleaned message content
  */
@@ -9,12 +10,11 @@ export function parseThinkingContent(content: string): {
 	thinking: string | null;
 	cleanContent: string;
 } {
-	const incompleteMatch = content.includes('<think>') && !content.includes('</think>');
+	const incompleteThinkMatch = content.includes('<think>') && !content.includes('</think>');
+	const incompleteThinkBracketMatch = content.includes('[THINK]') && !content.includes('[/THINK]');
 
-	if (incompleteMatch) {
-		// Remove the entire <think>...</think> part from clean content
+	if (incompleteThinkMatch) {
 		const cleanContent = content.split('</think>')?.[1]?.trim();
-		// Extract everything after <think> as thinking content
 		const thinkingContent = content.split('<think>')?.[1]?.trim();
 
 		return {
@@ -23,12 +23,40 @@
 		};
 	}
 
-	const completeMatch = content.includes('</think>');
+	if (incompleteThinkBracketMatch) {
+		const cleanContent = content.split('[/THINK]')?.[1]?.trim();
+		const thinkingContent = content.split('[THINK]')?.[1]?.trim();
 
-	if (completeMatch) {
 		return {
-			thinking: content.split('</think>')?.[0]?.trim(),
-			cleanContent: content.split('</think>')?.[1]?.trim()
+			cleanContent,
+			thinking: thinkingContent
+		};
+	}
+
+	const completeThinkMatch = content.match(/<think>([\s\S]*?)<\/think>/);
+	const completeThinkBracketMatch = content.match(/\[THINK\]([\s\S]*?)\[\/THINK\]/);
+
+	if (completeThinkMatch) {
+		const thinkingContent = completeThinkMatch[1]?.trim() ?? '';
+		const cleanContent = `${content.slice(0, completeThinkMatch.index ?? 0)}${content.slice(
+			(completeThinkMatch.index ?? 0) + completeThinkMatch[0].length
+		)}`.trim();
+
+		return {
+			thinking: thinkingContent,
+			cleanContent
+		};
+	}
+
+	if (completeThinkBracketMatch) {
+		const thinkingContent = completeThinkBracketMatch[1]?.trim() ?? '';
+		const cleanContent = `${content.slice(0, completeThinkBracketMatch.index ?? 0)}${content.slice(
+			(completeThinkBracketMatch.index ?? 0) + completeThinkBracketMatch[0].length
+		)}`.trim();
+
+		return {
+			thinking: thinkingContent,
+			cleanContent
 		};
 	}
 
@@ -39,26 +67,33 @@
 }
 
 /**
- * Checks if content contains an opening <think> tag (for streaming)
+ * Checks if content contains an opening thinking tag (for streaming)
+ * Supports both <think> and [THINK] formats
  * @param content - The message content to check
- * @returns True if the content contains an opening <think> tag
+ * @returns True if the content contains an opening thinking tag
  */
 export function hasThinkingStart(content: string): boolean {
-	return content.includes('<think>') || content.includes('<|channel|>analysis');
+	return (
+		content.includes('<think>') ||
+		content.includes('[THINK]') ||
+		content.includes('<|channel|>analysis')
+	);
 }
 
 /**
- * Checks if content contains a closing </think> tag (for streaming)
+ * Checks if content contains a closing thinking tag (for streaming)
+ * Supports both </think> and [/THINK] formats
  * @param content - The message content to check
- * @returns True if the content contains a closing </think> tag
+ * @returns True if the content contains a closing thinking tag
  */
 export function hasThinkingEnd(content: string): boolean {
-	return content.includes('</think>');
+	return content.includes('</think>') || content.includes('[/THINK]');
 }
 
 /**
  * Extracts partial thinking content during streaming
- * Used when we have <think> but not yet </think>
+ * Supports both <think> and [THINK] formats
+ * Used when we have opening tag but not yet closing tag
  * @param content - The message content to extract partial thinking from
  * @returns An object containing the extracted partial thinking content and the remaining content
  */
@@ -66,23 +101,41 @@ export function extractPartialThinking(content: string): {
 	thinking: string | null;
 	remainingContent: string;
 } {
-	const startIndex = content.indexOf('<think>');
-	if (startIndex === -1) {
+	const thinkStartIndex = content.indexOf('<think>');
+	const thinkEndIndex = content.indexOf('</think>');
+
+	const bracketStartIndex = content.indexOf('[THINK]');
+	const bracketEndIndex = content.indexOf('[/THINK]');
+
+	const useThinkFormat =
+		thinkStartIndex !== -1 && (bracketStartIndex === -1 || thinkStartIndex < bracketStartIndex);
+	const useBracketFormat =
+		bracketStartIndex !== -1 && (thinkStartIndex === -1 || bracketStartIndex < thinkStartIndex);
+
+	if (useThinkFormat) {
+		if (thinkEndIndex === -1) {
+			const thinkingStart = thinkStartIndex + '<think>'.length;
+
+			return {
+				thinking: content.substring(thinkingStart),
+				remainingContent: content.substring(0, thinkStartIndex)
+			};
+		}
+	} else if (useBracketFormat) {
+		if (bracketEndIndex === -1) {
+			const thinkingStart = bracketStartIndex + '[THINK]'.length;
+
+			return {
+				thinking: content.substring(thinkingStart),
+				remainingContent: content.substring(0, bracketStartIndex)
+			};
+		}
+	} else {
 		return { thinking: null, remainingContent: content };
 	}
 
-	const endIndex = content.indexOf('</think>');
-	if (endIndex === -1) {
-		// Still streaming thinking content
-		const thinkingStart = startIndex + '<think>'.length;
-		return {
-			thinking: content.substring(thinkingStart),
-			remainingContent: content.substring(0, startIndex)
-		};
-	}
-
-	// Complete thinking block found
 	const parsed = parseThinkingContent(content);
+
 	return {
 		thinking: parsed.thinking,
 		remainingContent: parsed.cleanContent
diff --git a/tools/server/webui/src/stories/ChatMessage.stories.svelte b/tools/server/webui/src/stories/ChatMessage.stories.svelte
index f9d7d5358..c6377e23c 100644
--- a/tools/server/webui/src/stories/ChatMessage.stories.svelte
+++
b/tools/server/webui/src/stories/ChatMessage.stories.svelte @@ -59,6 +59,60 @@ thinking: '', children: [] }); + + // Message with format thinking content + const thinkTagMessage: DatabaseMessage = { + id: '6', + convId: 'conv-1', + type: 'message', + timestamp: Date.now() - 1000 * 60 * 2, + role: 'assistant', + content: + "\nLet me analyze this step by step:\n\n1. The user is asking about thinking formats\n2. I need to demonstrate the <think> tag format\n3. This content should be displayed in the thinking section\n4. The main response should be separate\n\nThis is a good example of reasoning content.\n\n\nHere's my response after thinking through the problem. The thinking content above should be displayed separately from this main response content.", + parent: '1', + thinking: '', + children: [] + }; + + // Message with [THINK] format thinking content + const thinkBracketMessage: DatabaseMessage = { + id: '7', + convId: 'conv-1', + type: 'message', + timestamp: Date.now() - 1000 * 60 * 1, + role: 'assistant', + content: + '[THINK]\nThis is the DeepSeek-style thinking format:\n\n- Using square brackets instead of angle brackets\n- Should work identically to the <think> format\n- Content parsing should extract this reasoning\n- Display should be the same as <think> format\n\nBoth formats should be supported seamlessly.\n[/THINK]\n\nThis is the main response content that comes after the [THINK] block. The reasoning above should be parsed and displayed in the thinking section.', + parent: '1', + thinking: '', + children: [] + }; + + // Streaming message for format + let streamingThinkMessage = $state({ + id: '8', + convId: 'conv-1', + type: 'message', + timestamp: 0, // No timestamp = streaming + role: 'assistant', + content: '', + parent: '1', + thinking: '', + children: [] + }); + + // Streaming message for [THINK] format + let streamingBracketMessage = $state({ + id: '9', + convId: 'conv-1', + type: 'message', + timestamp: 0, // No timestamp = streaming + role: 'assistant', + content: '', + parent: '1', + thinking: '', + children: [] + }); setTimeout(resolve, 100)); }} /> + + + + + + { + // Phase 1: Stream reasoning content + const thinkingContent = + 'Let me work through this problem systematically:\n\n1. First, I need to understand what the user is asking\n2. Then I should consider different approaches\n3. I need to evaluate the pros and cons\n4. Finally, I should provide a clear recommendation\n\nThis step-by-step approach will ensure accuracy.'; + + let currentContent = '\n'; + streamingThinkMessage.content = currentContent; + + for (let i = 0; i < thinkingContent.length; i++) { + currentContent += thinkingContent[i]; + streamingThinkMessage.content = currentContent; + await new Promise((resolve) => setTimeout(resolve, 5)); + } + + // Close the thinking block + currentContent += '\n\n\n'; + streamingThinkMessage.content = currentContent; + await new Promise((resolve) => setTimeout(resolve, 200)); + + // Phase 2: Stream main response content + const responseContent = + "Based on my analysis above, here's the solution:\n\n**Key Points:**\n- The approach should be systematic\n- We need to consider all factors\n- Implementation should be step-by-step\n\nThis ensures the best possible outcome."; + + for (let i = 0; i < responseContent.length; i++) { + currentContent += responseContent[i]; + streamingThinkMessage.content = currentContent; + await new Promise((resolve) => setTimeout(resolve, 10)); + } + + streamingThinkMessage.timestamp = Date.now(); + }} +> +
+ +
+
+ + { + // Phase 1: Stream [THINK] reasoning content + const thinkingContent = + 'Using the DeepSeek format now:\n\n- This demonstrates the [THINK] bracket format\n- Should parse identically to <think> tags\n- The UI should display this in the thinking section\n- Main content should be separate\n\nBoth formats provide the same functionality.'; + + let currentContent = '[THINK]\n'; + streamingBracketMessage.content = currentContent; + + for (let i = 0; i < thinkingContent.length; i++) { + currentContent += thinkingContent[i]; + streamingBracketMessage.content = currentContent; + await new Promise((resolve) => setTimeout(resolve, 5)); + } + + // Close the thinking block + currentContent += '\n[/THINK]\n\n'; + streamingBracketMessage.content = currentContent; + await new Promise((resolve) => setTimeout(resolve, 200)); + + // Phase 2: Stream main response content + const responseContent = + "Here's my response after using the [THINK] format:\n\n**Observations:**\n- Both <think> and [THINK] formats work seamlessly\n- The parsing logic handles both cases\n- UI display is consistent across formats\n\nThis demonstrates the enhanced thinking content support."; + + for (let i = 0; i < responseContent.length; i++) { + currentContent += responseContent[i]; + streamingBracketMessage.content = currentContent; + await new Promise((resolve) => setTimeout(resolve, 10)); + } + + streamingBracketMessage.timestamp = Date.now(); + }} +> +
+ +
+
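
For reference, the dual-delimiter parsing that the thinking.ts change above implements can be summarized language-neutrally: find an opening marker, then either cut a complete reasoning block out of the message or, when the closing marker has not arrived yet, treat everything after the opener as still-streaming reasoning. A C++ sketch of that algorithm follows (illustrative only; the shipped implementation is the TypeScript shown above):

#include <string>

struct parsed_thinking {
    std::string thinking;       // extracted reasoning, empty if none
    std::string clean_content;  // message with the reasoning block removed
    bool        streaming;      // true when an opener was seen but no closer yet
};

static parsed_thinking parse_thinking(const std::string & content,
                                      const std::string & open,
                                      const std::string & close) {
    parsed_thinking out{ "", content, false };
    const size_t start = content.find(open);
    if (start == std::string::npos) {
        return out; // no reasoning block at all
    }
    const size_t end = content.find(close, start + open.size());
    if (end == std::string::npos) {
        // opener without closer: partial reasoning is still being streamed
        out.thinking      = content.substr(start + open.size());
        out.clean_content = content.substr(0, start);
        out.streaming     = true;
        return out;
    }
    out.thinking      = content.substr(start + open.size(), end - start - open.size());
    out.clean_content = content.substr(0, start) + content.substr(end + close.size());
    return out;
}

// Usage mirrors the TypeScript helper: try the angle-bracket form first, then the bracket form.
//   parse_thinking(msg, "<think>", "</think>");
//   parse_thinking(msg, "[THINK]", "[/THINK]");
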