Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	CODEOWNERS
#	ggml/CMakeLists.txt
#	ggml/src/ggml-cpu/CMakeLists.txt
#	ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
#	scripts/sync-ggml.last
#	tests/test-backend-ops.cpp
This commit is contained in:
Concedo 2025-09-30 22:28:53 +08:00
commit 20c802a198
11 changed files with 455 additions and 241 deletions

View file

@ -219,6 +219,53 @@ struct common_hf_file_res {
std::string mmprojFile;
};
static void write_etag(const std::string & path, const std::string & etag) {
const std::string etag_path = path + ".etag";
write_file(etag_path, etag);
LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
}
static std::string read_etag(const std::string & path) {
std::string none;
const std::string etag_path = path + ".etag";
if (std::filesystem::exists(etag_path)) {
std::ifstream etag_in(etag_path);
if (!etag_in) {
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
return none;
}
std::string etag;
std::getline(etag_in, etag);
return etag;
}
// no etag file, but maybe there is an old .json
// remove this code later
const std::string metadata_path = path + ".json";
if (std::filesystem::exists(metadata_path)) {
std::ifstream metadata_in(metadata_path);
try {
nlohmann::json metadata_json;
metadata_in >> metadata_json;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
metadata_json.dump().c_str());
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
std::string etag = metadata_json.at("etag");
write_etag(path, etag);
if (!std::filesystem::remove(metadata_path)) {
LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
}
return etag;
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
}
return none;
}
#ifdef LLAMA_USE_CURL
bool common_has_curl() {
@ -375,36 +422,15 @@ static bool common_download_head(CURL * curl,
static bool common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token) {
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
for (int i = 0; i < max_attempts; ++i) {
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
std::string etag;
std::string last_modified;
std::string etag;
// Check if the file already exists locally
const auto file_exists = std::filesystem::exists(path);
if (file_exists) {
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
metadata.dump().c_str());
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
}
}
// if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
etag = read_etag(path);
} else {
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
@ -442,11 +468,6 @@ static bool common_download_file_single_online(const std::string & url,
headers.etag.c_str());
should_download = true;
should_download_from_scratch = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
should_download_from_scratch = true;
}
}
@ -477,15 +498,9 @@ static bool common_download_file_single_online(const std::string & url,
}
}
}
// Write the updated JSON metadata file.
metadata.update({
{ "url", url },
{ "etag", headers.etag },
{ "lastModified", headers.last_modified }
});
write_file(metadata_path, metadata.dump(4));
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
if (head_request_ok) {
write_etag(path, headers.etag);
}
// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
@ -668,51 +683,6 @@ static void print_progress(size_t current, size_t total) { // TODO isatty
std::cout.flush();
}
struct common_file_metadata {
std::string etag;
std::string last_modified;
};
static std::optional<common_file_metadata> read_metadata(const std::string & path) {
if (!std::filesystem::exists(path)) {
return std::nullopt;
}
nlohmann::json metadata_json;
common_file_metadata metadata;
std::ifstream metadata_in(path);
try {
metadata_in >> metadata_json;
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, path.c_str(),
metadata_json.dump().c_str());
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
metadata.etag = metadata_json.at("etag");
}
if (metadata_json.contains("lastModified") && metadata_json.at("lastModified").is_string()) {
metadata.last_modified = metadata_json.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, path.c_str(), e.what());
return std::nullopt;
}
return metadata;
}
static void write_metadata(const std::string & path,
const std::string & url,
const common_file_metadata & metadata) {
nlohmann::json metadata_json = {
{ "url", url },
{ "etag", metadata.etag },
{ "lastModified", metadata.last_modified }
};
write_file(path, metadata_json.dump(4));
LOG_DBG("%s: file metadata saved: %s\n", __func__, path.c_str());
}
static bool common_pull_file(httplib::Client & cli,
const std::string & resolve_path,
const std::string & path_tmp,
@ -779,8 +749,6 @@ static bool common_pull_file(httplib::Client & cli,
static bool common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token) {
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
@ -792,12 +760,11 @@ static bool common_download_file_single_online(const std::string & url,
}
cli.set_default_headers(default_headers);
common_file_metadata last;
const bool file_exists = std::filesystem::exists(path);
std::string last_etag;
if (file_exists) {
if (auto opt = read_metadata(metadata_path)) {
last = *opt;
}
last_etag = read_etag(path);
} else {
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
@ -813,14 +780,9 @@ static bool common_download_file_single_online(const std::string & url,
}
}
common_file_metadata current;
if (head_ok) {
if (head->has_header("ETag")) {
current.etag = head->get_header_value("ETag");
}
if (head->has_header("Last-Modified")) {
current.last_modified = head->get_header_value("Last-Modified");
}
std::string etag;
if (head_ok && head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}
size_t total_size = 0;
@ -838,16 +800,10 @@ static bool common_download_file_single_online(const std::string & url,
}
bool should_download_from_scratch = false;
if (head_ok) {
if (!last.etag.empty() && last.etag != current.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
last.etag.c_str(), current.etag.c_str());
should_download_from_scratch = true;
} else if (!last.last_modified.empty() && last.last_modified != current.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
last.last_modified.c_str(), current.last_modified.c_str());
should_download_from_scratch = true;
}
if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
last_etag.c_str(), etag.c_str());
should_download_from_scratch = true;
}
if (file_exists) {
@ -875,9 +831,8 @@ static bool common_download_file_single_online(const std::string & url,
}
// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
__func__, show_masked_url(parts).c_str(), path_temporary.c_str(),
current.etag.c_str(), current.last_modified.c_str());
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
__func__, show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
if (!was_pull_successful) {
if (i + 1 < max_attempts) {
@ -887,7 +842,6 @@ static bool common_download_file_single_online(const std::string & url,
} else {
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
}
continue;
}
@ -895,7 +849,9 @@ static bool common_download_file_single_online(const std::string & url,
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
}
write_metadata(metadata_path, url, current);
if (!etag.empty()) {
write_etag(path, etag);
}
break;
}

View file

@ -329,7 +329,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
} else
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
{
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
if (src0->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@ -400,7 +404,13 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
return nullptr;
// Prioritize CUDA graph compatibility over direct memory copy optimization.
// Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
if (src0->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<float, float>>;
} else {
return nullptr;
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<float, float>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {

View file

@ -2654,6 +2654,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@ -2682,7 +2684,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and

View file

@ -495,22 +495,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
if (ne00 == 4) {
if (ne00 < 32) {
nsg = 1;
nr0 = 32;
nr1 = 4;
suffix = "_c4";
} else if (ne00 % 4 == 0) {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
suffix = "_4";
suffix = "_short";
} else {
nsg = N_SG_F;
nr0 = N_R0_F;
nsg = std::min(4, (ne00 + 127) / 128);
nr0 = 2;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
smem = 32*sizeof(float)*nr0;
suffix = ne00 % 4 == 0 ? "_4" : "";
}
} break;
case GGML_TYPE_Q4_0:
@ -727,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
if (ne00 % 4 == 0) {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
suffix = "_4";
} else {
nsg = N_SG_F;
nr0 = N_R0_F;
nr1 = 1;
smem = 32*sizeof(float)*N_R0_F;
}
nsg = std::min(4, (ne00 + 127) / 128);
nr0 = 2;
nr1 = 1;
smem = 32*sizeof(float)*nr0;
suffix = ne00 % 4 == 0 ? "_4" : "";
} break;
case GGML_TYPE_Q4_0:
{

View file

@ -8,9 +8,6 @@
//
// TODO: for optimal performance, become function of the device and work size
#define N_R0_F 2
#define N_SG_F 4
#define N_R0_Q4_0 4
#define N_SG_Q4_0 2
@ -352,6 +349,7 @@ typedef struct {
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t nr0;
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mv;
@ -427,6 +425,7 @@ typedef struct {
int32_t ne0;
int32_t ne1;
uint64_t nb1;
int32_t nr0;
} ggml_metal_kargs_mul_mv_id;
// NORM

View file

@ -1565,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
} else {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
ggml_metal_kargs_mul_mv args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
@ -1582,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.nr0 =*/ nr0,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@ -1758,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
}
} else {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
ggml_metal_kargs_mul_mv_id args = {
/*.nei0 =*/ ne20,
/*.nei1 =*/ ne21,
@ -1778,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.nb1 =*/ nb1,
/*.nr0 =*/ nr0,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
if (ggml_is_quantized(op->src[0]->type)) {
GGML_ASSERT(ne00 >= nsg*nr0);
}

View file

@ -3531,7 +3531,25 @@ void kernel_mul_mv_t_t_impl(
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
template<typename T0, typename T1, short NR0>
template<typename T0, typename T1, typename args_t>
void kernel_mul_mv_t_t_disp(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
switch (args.nr0) {
//case 1: kernel_mul_mv_t_t_impl<T0, T1, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
case 2: kernel_mul_mv_t_t_impl<T0, T1, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
//case 3: kernel_mul_mv_t_t_impl<T0, T1, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
//case 4: kernel_mul_mv_t_t_impl<T0, T1, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
}
}
template<typename T0, typename T1>
kernel void kernel_mul_mv_t_t(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
@ -3541,17 +3559,17 @@ kernel void kernel_mul_mv_t_t(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_t_t_impl<T0, T1, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_t_t_disp<T0, T1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F>) mul_mv_t_t;
typedef decltype(kernel_mul_mv_t_t<half, half>) mul_mv_t_t;
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F>;
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float, N_R0_F>;
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half, N_R0_F>;
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float>;
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float>;
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float, N_R0_F>;
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F>;
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float>;
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat>;
#endif
template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
@ -3637,7 +3655,25 @@ void kernel_mul_mv_t_t_4_impl(
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
template<typename T0, typename T04, typename T1, typename T14, short NR0>
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
void kernel_mul_mv_t_t_4_disp(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
switch (args.nr0) {
//case 1: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
case 2: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
//case 3: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
//case 4: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
};
}
template<typename T0, typename T04, typename T1, typename T14>
kernel void kernel_mul_mv_t_t_4(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
@ -3647,23 +3683,21 @@ kernel void kernel_mul_mv_t_t_4(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
kernel_mul_mv_t_t_4_disp<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>) mul_mv_t_t_4;
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4>) mul_mv_t_t_4;
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F>;
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4, N_R0_F>;
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>;
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4>;
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4>;
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4, N_R0_F>;
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F>;
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4>;
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4>;
#endif
#define N_MV_T_T 4
template<typename T04, typename T14, typename args_t>
void kernel_mul_mv_c4_impl(
template<typename T0, typename T1, typename args_t>
void kernel_mul_mv_t_t_short_impl(
args_t args,
device const char * src0,
device const char * src1,
@ -3671,7 +3705,7 @@ void kernel_mul_mv_c4_impl(
uint3 tgpig,
ushort tiisg) {
const int r0 = tgpig.x*32 + tiisg;
const int rb = tgpig.y*N_MV_T_T;
const int r1 = tgpig.y;
const int im = tgpig.z;
if (r0 >= args.ne01) {
@ -3683,33 +3717,32 @@ void kernel_mul_mv_c4_impl(
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
device const T04 * x = (device const T04 *) (src0 + offset0);
device const T0 * x = (device const T0 *) (src0 + offset0);
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
if (r1 >= args.ne11) {
break;
}
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T1 * y = (device const T1 *) (src1 + offset1);
device const T14 * y = (device const T14 *) (src1 + offset1);
float res = 0.0f;
dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
for (int i = 0; i < args.ne00; ++i) {
res += (float) x[i] * (float) y[i];
}
dst_f32[(uint64_t)r1*args.ne0 + r0] = res;
}
template<typename T04, typename T14>
kernel void kernel_mul_mv_c4(
template<typename T0, typename T1>
kernel void kernel_mul_mv_t_t_short(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
kernel_mul_mv_t_t_short_impl<T0, T1, constant ggml_metal_kargs_mul_mv &>(
args,
src0,
src1,
@ -3718,14 +3751,14 @@ kernel void kernel_mul_mv_c4(
tiisg);
}
typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
typedef decltype(kernel_mul_mv_t_t_short<half, half>) mul_mv_t_t_short_t;
template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
template [[host_name("kernel_mul_mv_f16_f16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, half4>;
template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<float, float>;
template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, float>;
template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, bfloat4>;
template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, float>;
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
#endif
static float rope_yarn_ramp(const float low, const float high, const int i0) {
@ -8458,7 +8491,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_m
// matrix-vector multiplication
//
typedef void (kernel_mul_mv_impl_t)(
typedef void (kernel_mul_mv_disp_t)(
ggml_metal_kargs_mul_mv args,
device const char * src0,
device const char * src1,
@ -8466,7 +8499,7 @@ typedef void (kernel_mul_mv_impl_t)(
uint3 tgpig,
ushort tiisg);
typedef void (kernel_mul_mv2_impl_t)(
typedef void (kernel_mul_mv2_disp_t)(
ggml_metal_kargs_mul_mv args,
device const char * src0,
device const char * src1,
@ -8476,7 +8509,7 @@ typedef void (kernel_mul_mv2_impl_t)(
ushort tiisg,
ushort sgitg);
template<kernel_mul_mv_impl_t impl_fn>
template<kernel_mul_mv_disp_t disp_fn>
void mmv_fn(
ggml_metal_kargs_mul_mv args,
device const char * src0,
@ -8487,10 +8520,10 @@ void mmv_fn(
ushort tiitg,
ushort tiisg,
ushort sgitg) {
impl_fn(args, src0, src1, dst, tgpig, tiisg);
disp_fn(args, src0, src1, dst, tgpig, tiisg);
}
template<kernel_mul_mv2_impl_t impl_fn>
template<kernel_mul_mv2_disp_t disp_fn>
void mmv_fn(
ggml_metal_kargs_mul_mv args,
device const char * src0,
@ -8501,12 +8534,12 @@ void mmv_fn(
ushort tiitg,
ushort tiisg,
ushort sgitg) {
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
typedef decltype(mmv_fn<kernel_mul_mv_t_t_disp<half, half, ggml_metal_kargs_mul_mv>>) mul_mv_disp_fn_t;
template<mul_mv_impl_fn_t impl_fn>
template<mul_mv_disp_fn_t disp_fn>
kernel void kernel_mul_mv_id(
constant ggml_metal_kargs_mul_mv_id & args,
device const char * src0s,
@ -8553,11 +8586,12 @@ kernel void kernel_mul_mv_id(
/*.nb13 =*/ args.nb12, // ne12 == 1
/*.ne0 =*/ args.ne0,
/*.ne1 =*/ 1, // args.ne1,
/*.nr0 =*/ args.nr0,
/*.r2 =*/ 1,
/*.r3 =*/ 1,
};
impl_fn(
disp_fn(
args0,
/* src0 */ src0_cur,
/* src1 */ src1_cur,
@ -8569,19 +8603,19 @@ kernel void kernel_mul_mv_id(
sgitg);
}
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>) kernel_mul_mv_id_t;
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>) kernel_mul_mv_id_t;
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>) kernel_mul_mv_id_4_t;
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>) kernel_mul_mv_id_4_t;
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F>>>;
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<half, float>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F>>>;
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<bfloat, float>>>;
#endif
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>;
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F>>>;
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<half, half4, float, float4>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F>>>;
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<bfloat, bfloat4, float, float4>>>;
#endif
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;

View file

@ -11813,6 +11813,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
// TODO: skip computing output earlier for unused tokens
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
cb(y, "mamba2_y_add_d", il);
y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
// grouped RMS norm
@ -14767,6 +14768,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
ggml_build_forward_expand(gf, inpL);
auto * inp = build_inp_mem_hybrid();
@ -14798,7 +14800,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
// add residual
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "block_out", il);
cb(cur, "nemotron_h_block_out", il);
// input for next layer
inpL = cur;

Binary file not shown.

View file

@ -1,7 +1,8 @@
/**
* Parses thinking content from a message that may contain <think> tags
* Parses thinking content from a message that may contain <think> tags or [THINK] tags
* Returns an object with thinking content and cleaned message content
* Handles both complete <think>...</think> blocks and incomplete <think> blocks (streaming)
* Handles both complete blocks and incomplete blocks (streaming)
* Supports formats: <think>...</think> and [THINK]...[/THINK]
* @param content - The message content to parse
* @returns An object containing the extracted thinking content and the cleaned message content
*/
@ -9,12 +10,11 @@ export function parseThinkingContent(content: string): {
thinking: string | null;
cleanContent: string;
} {
const incompleteMatch = content.includes('<think>') && !content.includes('</think>');
const incompleteThinkMatch = content.includes('<think>') && !content.includes('</think>');
const incompleteThinkBracketMatch = content.includes('[THINK]') && !content.includes('[/THINK]');
if (incompleteMatch) {
// Remove the entire <think>... part from clean content
if (incompleteThinkMatch) {
const cleanContent = content.split('</think>')?.[1]?.trim();
// Extract everything after <think> as thinking content
const thinkingContent = content.split('<think>')?.[1]?.trim();
return {
@ -23,12 +23,40 @@ export function parseThinkingContent(content: string): {
};
}
const completeMatch = content.includes('</think>');
if (incompleteThinkBracketMatch) {
const cleanContent = content.split('[/THINK]')?.[1]?.trim();
const thinkingContent = content.split('[THINK]')?.[1]?.trim();
if (completeMatch) {
return {
thinking: content.split('</think>')?.[0]?.trim(),
cleanContent: content.split('</think>')?.[1]?.trim()
cleanContent,
thinking: thinkingContent
};
}
const completeThinkMatch = content.match(/<think>([\s\S]*?)<\/think>/);
const completeThinkBracketMatch = content.match(/\[THINK\]([\s\S]*?)\[\/THINK\]/);
if (completeThinkMatch) {
const thinkingContent = completeThinkMatch[1]?.trim() ?? '';
const cleanContent = `${content.slice(0, completeThinkMatch.index ?? 0)}${content.slice(
(completeThinkMatch.index ?? 0) + completeThinkMatch[0].length
)}`.trim();
return {
thinking: thinkingContent,
cleanContent
};
}
if (completeThinkBracketMatch) {
const thinkingContent = completeThinkBracketMatch[1]?.trim() ?? '';
const cleanContent = `${content.slice(0, completeThinkBracketMatch.index ?? 0)}${content.slice(
(completeThinkBracketMatch.index ?? 0) + completeThinkBracketMatch[0].length
)}`.trim();
return {
thinking: thinkingContent,
cleanContent
};
}
@ -39,26 +67,33 @@ export function parseThinkingContent(content: string): {
}
/**
* Checks if content contains an opening <think> tag (for streaming)
* Checks if content contains an opening thinking tag (for streaming)
* Supports both <think> and [THINK] formats
* @param content - The message content to check
* @returns True if the content contains an opening <think> tag
* @returns True if the content contains an opening thinking tag
*/
export function hasThinkingStart(content: string): boolean {
return content.includes('<think>') || content.includes('<|channel|>analysis');
return (
content.includes('<think>') ||
content.includes('[THINK]') ||
content.includes('<|channel|>analysis')
);
}
/**
* Checks if content contains a closing </think> tag (for streaming)
* Checks if content contains a closing thinking tag (for streaming)
* Supports both </think> and [/THINK] formats
* @param content - The message content to check
* @returns True if the content contains a closing </think> tag
* @returns True if the content contains a closing thinking tag
*/
export function hasThinkingEnd(content: string): boolean {
return content.includes('</think>');
return content.includes('</think>') || content.includes('[/THINK]');
}
/**
* Extracts partial thinking content during streaming
* Used when we have <think> but not yet </think>
* Supports both <think> and [THINK] formats
* Used when we have opening tag but not yet closing tag
* @param content - The message content to extract partial thinking from
* @returns An object containing the extracted partial thinking content and the remaining content
*/
@ -66,23 +101,41 @@ export function extractPartialThinking(content: string): {
thinking: string | null;
remainingContent: string;
} {
const startIndex = content.indexOf('<think>');
if (startIndex === -1) {
const thinkStartIndex = content.indexOf('<think>');
const thinkEndIndex = content.indexOf('</think>');
const bracketStartIndex = content.indexOf('[THINK]');
const bracketEndIndex = content.indexOf('[/THINK]');
const useThinkFormat =
thinkStartIndex !== -1 && (bracketStartIndex === -1 || thinkStartIndex < bracketStartIndex);
const useBracketFormat =
bracketStartIndex !== -1 && (thinkStartIndex === -1 || bracketStartIndex < thinkStartIndex);
if (useThinkFormat) {
if (thinkEndIndex === -1) {
const thinkingStart = thinkStartIndex + '<think>'.length;
return {
thinking: content.substring(thinkingStart),
remainingContent: content.substring(0, thinkStartIndex)
};
}
} else if (useBracketFormat) {
if (bracketEndIndex === -1) {
const thinkingStart = bracketStartIndex + '[THINK]'.length;
return {
thinking: content.substring(thinkingStart),
remainingContent: content.substring(0, bracketStartIndex)
};
}
} else {
return { thinking: null, remainingContent: content };
}
const endIndex = content.indexOf('</think>');
if (endIndex === -1) {
// Still streaming thinking content
const thinkingStart = startIndex + '<think>'.length;
return {
thinking: content.substring(thinkingStart),
remainingContent: content.substring(0, startIndex)
};
}
// Complete thinking block found
const parsed = parseThinkingContent(content);
return {
thinking: parsed.thinking,
remainingContent: parsed.cleanContent

View file

@ -59,6 +59,60 @@
thinking: '',
children: []
});
// Message with <think> format thinking content
const thinkTagMessage: DatabaseMessage = {
id: '6',
convId: 'conv-1',
type: 'message',
timestamp: Date.now() - 1000 * 60 * 2,
role: 'assistant',
content:
"<think>\nLet me analyze this step by step:\n\n1. The user is asking about thinking formats\n2. I need to demonstrate the &lt;think&gt; tag format\n3. This content should be displayed in the thinking section\n4. The main response should be separate\n\nThis is a good example of reasoning content.\n</think>\n\nHere's my response after thinking through the problem. The thinking content above should be displayed separately from this main response content.",
parent: '1',
thinking: '',
children: []
};
// Message with [THINK] format thinking content
const thinkBracketMessage: DatabaseMessage = {
id: '7',
convId: 'conv-1',
type: 'message',
timestamp: Date.now() - 1000 * 60 * 1,
role: 'assistant',
content:
'[THINK]\nThis is the DeepSeek-style thinking format:\n\n- Using square brackets instead of angle brackets\n- Should work identically to the &lt;think&gt; format\n- Content parsing should extract this reasoning\n- Display should be the same as &lt;think&gt; format\n\nBoth formats should be supported seamlessly.\n[/THINK]\n\nThis is the main response content that comes after the [THINK] block. The reasoning above should be parsed and displayed in the thinking section.',
parent: '1',
thinking: '',
children: []
};
// Streaming message for <think> format
let streamingThinkMessage = $state({
id: '8',
convId: 'conv-1',
type: 'message',
timestamp: 0, // No timestamp = streaming
role: 'assistant',
content: '',
parent: '1',
thinking: '',
children: []
});
// Streaming message for [THINK] format
let streamingBracketMessage = $state({
id: '9',
convId: 'conv-1',
type: 'message',
timestamp: 0, // No timestamp = streaming
role: 'assistant',
content: '',
parent: '1',
thinking: '',
children: []
});
</script>
<Story
@ -144,3 +198,115 @@
await new Promise(resolve => setTimeout(resolve, 100));
}}
/>
<Story
name="ThinkTagFormat"
args={{
class: 'max-w-[56rem] w-[calc(100vw-2rem)]',
message: thinkTagMessage
}}
/>
<Story
name="ThinkBracketFormat"
args={{
class: 'max-w-[56rem] w-[calc(100vw-2rem)]',
message: thinkBracketMessage
}}
/>
<Story
name="StreamingThinkTag"
args={{
message: streamingThinkMessage
}}
parameters={{
test: {
timeout: 30000
}
}}
asChild
play={async () => {
// Phase 1: Stream <think> reasoning content
const thinkingContent =
'Let me work through this problem systematically:\n\n1. First, I need to understand what the user is asking\n2. Then I should consider different approaches\n3. I need to evaluate the pros and cons\n4. Finally, I should provide a clear recommendation\n\nThis step-by-step approach will ensure accuracy.';
let currentContent = '<think>\n';
streamingThinkMessage.content = currentContent;
for (let i = 0; i < thinkingContent.length; i++) {
currentContent += thinkingContent[i];
streamingThinkMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 5));
}
// Close the thinking block
currentContent += '\n</think>\n\n';
streamingThinkMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 200));
// Phase 2: Stream main response content
const responseContent =
"Based on my analysis above, here's the solution:\n\n**Key Points:**\n- The approach should be systematic\n- We need to consider all factors\n- Implementation should be step-by-step\n\nThis ensures the best possible outcome.";
for (let i = 0; i < responseContent.length; i++) {
currentContent += responseContent[i];
streamingThinkMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 10));
}
streamingThinkMessage.timestamp = Date.now();
}}
>
<div class="w-[56rem]">
<ChatMessage message={streamingThinkMessage} />
</div>
</Story>
<Story
name="StreamingThinkBracket"
args={{
message: streamingBracketMessage
}}
parameters={{
test: {
timeout: 30000
}
}}
asChild
play={async () => {
// Phase 1: Stream [THINK] reasoning content
const thinkingContent =
'Using the DeepSeek format now:\n\n- This demonstrates the &#91;THINK&#93; bracket format\n- Should parse identically to &lt;think&gt; tags\n- The UI should display this in the thinking section\n- Main content should be separate\n\nBoth formats provide the same functionality.';
let currentContent = '[THINK]\n';
streamingBracketMessage.content = currentContent;
for (let i = 0; i < thinkingContent.length; i++) {
currentContent += thinkingContent[i];
streamingBracketMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 5));
}
// Close the thinking block
currentContent += '\n[/THINK]\n\n';
streamingBracketMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 200));
// Phase 2: Stream main response content
const responseContent =
"Here's my response after using the &#91;THINK&#93; format:\n\n**Observations:**\n- Both &lt;think&gt; and &#91;THINK&#93; formats work seamlessly\n- The parsing logic handles both cases\n- UI display is consistent across formats\n\nThis demonstrates the enhanced thinking content support.";
for (let i = 0; i < responseContent.length; i++) {
currentContent += responseContent[i];
streamingBracketMessage.content = currentContent;
await new Promise((resolve) => setTimeout(resolve, 10));
}
streamingBracketMessage.timestamp = Date.now();
}}
>
<div class="w-[56rem]">
<ChatMessage message={streamingBracketMessage} />
</div>
</Story>