diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 29819e48d..060578f0b 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -893,23 +893,6 @@ static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { builder.consume_reasoning_with_xml_tool_calls(form, "", ""); } -static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "") != std::string::npos); + // Handle thinking tags appropriately based on inputs.enable_thinking - if (string_ends_with(data.prompt, "\n")) { + if (supports_reasoning && string_ends_with(data.prompt, "\n")) { if (!inputs.enable_thinking) { data.prompt += ""; } else { @@ -1552,19 +1554,21 @@ static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_ } data.preserved_tokens = { - "", - "", "", "", }; + if (supports_reasoning) { + data.preserved_tokens.insert(data.preserved_tokens.end(), {"", ""}); + } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto include_grammar = true; auto parser = build_chat_peg_constructed_parser([&](auto & p) { auto reasoning = p.eps(); - if (inputs.enable_thinking && extract_reasoning) { + if (supports_reasoning && inputs.enable_thinking && extract_reasoning) { auto reasoning_content = p.reasoning(p.until("")) + ("" | p.end()); if (data.thinking_forced_open) { reasoning = reasoning_content; @@ -1902,38 +1906,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t return data; } -static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) { - common_chat_params data; - data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != 
COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML; - - data.preserved_tokens = { - "", - "", - "", - "", - }; - - // build grammar for tool call - static const xml_tool_call_format form { - /* form.scope_start = */ "\n", - /* form.tool_start = */ "\n", - /* form.key_start = */ "\n", - /* form.val_end = */ "\n\n", - /* form.tool_end = */ "\n", - /* form.scope_end = */ "", - }; - build_grammar_xml_tool_call(data, params.tools, form); - - return data; -} - static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -3161,13 +3133,7 @@ static common_chat_params common_chat_templates_apply_jinja( src.find("") != std::string::npos) { - return common_chat_params_init_nemotron_v3(tmpl, params); - } - return common_chat_params_init_qwen3_coder_xml(tmpl, params); + return common_chat_params_init_qwen3_coder(tmpl, params); } // Xiaomi MiMo format detection (must come before Hermes 2 Pro) diff --git a/common/chat.h b/common/chat.h index 1bf43f726..6f0b9409e 100644 --- a/common/chat.h +++ b/common/chat.h @@ -128,7 +128,6 @@ enum common_chat_format { COMMON_CHAT_FORMAT_GLM_4_5, COMMON_CHAT_FORMAT_MINIMAX_M2, COMMON_CHAT_FORMAT_KIMI_K2, - COMMON_CHAT_FORMAT_QWEN3_CODER_XML, COMMON_CHAT_FORMAT_APRIEL_1_5, COMMON_CHAT_FORMAT_XIAOMI_MIMO, COMMON_CHAT_FORMAT_SOLAR_OPEN, diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index b6e989939..45fb3e42d 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -163,15 +163,9 @@ #elif defined(__riscv) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K -#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K -#define ggml_vec_dot_tq2_0_q8_K_generic 
ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K #define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K -#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K #define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K -#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K -#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K -#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index ae0ebb3ca..bf9f4df11 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -1954,3 +1954,773 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +static const uint8_t sign_gather_indices_arr[64] = { + 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, + 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 +}; + +static const uint8_t sign_bit_masks_arr[64] = { + 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, + 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128 +}; + +static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + // --- Pre-load Constants --- + uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint16mf2_t v_gather_qh = 
__riscv_vle16_v_u16mf2(gather_qh_arr, 8); + uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8); + + // Constants for sign extraction + vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64); + vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 4; ++ib) { + // Combine low + high bits + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8); + qs += 8; + uint16_t qh_val; + memcpy(&qh_val, qh, 2); + qh += 2; + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2); + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8); + + // Mask: We want bits 11-12. 
0x1800 = 0001 1000 0000 0000 + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8); + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8); + + // Multiply by 8 to get byte offset, instead of element offset + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8); + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8); + + // Lookup Grid using Byte Offsets + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8); + + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + // Load signs and generate sign mask + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8); + signs_ptr += 8; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); + + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16)); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + scales += 2; + + sum_block += s0 * (2 * (sc0 & 0xF) + 1); + sum_block += 
s1 * (2 * (sc0 >> 4) + 1); + sum_block += s2 * (2 * (sc1 & 0xF) + 1); + sum_block += s3 * (2 * (sc1 >> 4) + 1); + } + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + // Pre-load Constants + vuint8m2_t v_ids = __riscv_vid_v_u8m2(32); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32); + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32); + uint16_t shift_qh_arr[4] = {11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + float sum_block = 0.0f; + + for (int ib = 0; ib < 8; ++ib) { + + // Load Low Bits [4 bytes] + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4); + qs += 4; + + // Load 1 byte. It contains bits for 4 mini-blocks. 
+ uint8_t qh_val = *qh++; + + // Combine Low + High bits of 10bit indices + vuint8mf4_t v_qh_raw = __riscv_vmv_v_x_u8mf4(qh_val, 4); + vuint16mf2_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qh_raw, 4); + vuint16mf2_t v_qh_mf2 = __riscv_vsll_vv_u16mf2(v_qh_u16, v_shift_qh, 4); + v_qh_mf2 = __riscv_vand_vx_u16mf2(v_qh_mf2, 0x1800, 4); + vuint16mf2_t v_qs_u16_mf2 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 4); + vuint16mf2_t v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16_mf2, 3, 4); + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_mf2, 4); + + // Lookup Grid + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(__riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 4))); + + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 4); + signs_ptr += 4; + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32); + + // generating sign mask + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // apply signs + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative,v_q8, v_q8, 0, 32); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 32); + + // Reduction + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + // Reduce 0-15 (First Half) + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m4_i16m2(v_dot, 0), v_zero, 16)); + + // Reduce 16-31 (Second Half) + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m4_i16m2(v_dot, 1), v_zero, 16)); + + // Apply sub Scales + uint8_t sc = *scales++; + + sum_block += s0 * (2 * (sc & 0xF) + 1); + sum_block += s1 * (2 * (sc >> 4) + 1); + } + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +void 
ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 128:
+            ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        case 256:
+            ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+// IQ3_S x Q8_K dot product for RVV with VLEN = 256 bits (selected by the dispatcher below).
+//
+// For each super-block, four iterations each handle 64 weights:
+//   - 16 qs bytes plus 16 qh bits form byte offsets into iq3s_grid (4 unsigned bytes per entry),
+//   - 8 sign bytes (one bit per weight) conditionally negate the matching q8 values,
+//   - one scales byte supplies two 4-bit sub-scales, applied as (2 * sc + 1) to the
+//     low and high 32 products respectively.
+// The accumulated result is written to *s. nrc, bs, bx and by are unused.
+static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq3_s * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+    const uint64_t * grid64 = (const uint64_t *)iq3s_grid;
+
+    // --- Pre-load Constants ---
+    const uint16_t qh_bit_shifts_arr[16] = {
+        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+    };
+    vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
+    vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);
+    vuint16m1_t v_qh_shifts = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16);
+
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_CPU_FP16_TO_FP32(x[i].d);
+        const float combined_scale = d * y[i].d;
+
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint8_t * GGML_RESTRICT scales = x[i].scales;
+        const uint8_t * GGML_RESTRICT signs = x[i].signs;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+        float sum_block = 0.0f;
+
+        // Loop: Process 64 weights (16 mini-blocks of 4) per iteration
+        for (int ib = 0; ib < 4; ++ib) {
+
+            vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16);
+            qs += 16;
+
+            // memcpy avoids an unaligned 16-bit load from the byte array.
+            uint16_t qh_val;
+            memcpy(&qh_val, qh, 2);
+            qh += 2;
+
+            vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16);
+            // Extract bits: (qh >> i) & 1
+            v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16);
+            v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16);
+
+            // Offsets are in bytes: the 8-bit qs index is multiplied by 4 and the
+            // qh bit (9th index bit) therefore lands at bit 10.
+            vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16);
+            v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16);
+            v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16);
+            vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16);
+
+            // Grid value is 4xuint8
+            vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16);
+            vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed);
+            vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8);
+            signs += 8;
+
+            // Generate sign mask
+            vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
+            vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);
+            vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
+            vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);
+
+            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
+            q8 += 64;
+
+            // Apply Signs (the grid is unsigned, so the sign is folded into q8 instead)
+            vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
+            vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64);
+
+            // Reduction
+            vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0);
+            vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1);
+            vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
+
+            int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32));
+            int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32));
+
+            // Apply sub-scales
+            uint8_t sc_byte = *scales++;
+            int sc_lo = (sc_byte & 0xF) * 2 + 1;
+            int sc_hi = (sc_byte >> 4) * 2 + 1;
+
+            sum_block += s_lo * sc_lo + s_hi * sc_hi;
+        }
+        sumf += sum_block * combined_scale;
+    }
+    // IQ3_S carries no global constant factor (the 0.125f belongs to the IQ2_S kernels);
+    // the result must agree with ggml_vec_dot_iq3_s_q8_K_generic, which every other
+    // vector length falls back to.
+    *s = sumf;
+}
+
+// Selects the IQ3_S kernel from the runtime vector length (__riscv_vlenb() * 8 bits):
+// only VLEN = 256 has a vectorized path, everything else uses the generic implementation.
+void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 256:
+            ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_tq1_0 * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+    float sumf = 0.0f;
+    uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
+
+    for (int i = 0; i < nb; i++) {
+        // First loop.
+ vint32m4_t suml1; + { + const int vl = 32; + vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tq, 3, vl), 8, vl); + vuint16m2_t tq1 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 3, vl), 3, vl), 8, vl); + vuint16m2_t tq2 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 9, vl), 3, vl), 8, vl); + vuint16m2_t tq3 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 27, vl), 3, vl), 8, vl); + vuint16m2_t tq4 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 81, vl), 3, vl), 8, vl); + + vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 0, vl), vl); + vint16m2_t q81 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 32, vl), vl); + vint16m2_t q82 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 64, vl), vl); + vint16m2_t q83 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 96, vl), vl); + vint16m2_t q84 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 128, vl), vl); + + vint16m2_t sum0 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + vint16m2_t sum1 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq1, 1, vl)), q81, vl); + vint16m2_t sum2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq2, 1, vl)), q82, vl); + vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl); + vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl); + + vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl); + vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl); + suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl); + } + + // Second loop. 
+ vint32m2_t suml2; + { + const int vl = 16; + vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3 * 1, vl), 8, vl); + vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl); + vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl); + vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl); + vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 160, vl), vl); + vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 176, vl), vl); + vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 192, vl), vl); + vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 208, vl), vl); + vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 224, vl), vl); + + vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl); + vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl); + vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); + vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); + + vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl); + vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl); + suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl); + } + + // Third loop. 
+ vint32m2_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + vuint8mf2_t tq = __riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4)); + + vuint8mf2_t p = __riscv_vle8_v_u8mf2(pow, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vv_u8mf2(tq, p, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl); + + vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl); + } + + vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16); + sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16); + sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16); + + vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} + +void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = 
n / QK_K; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x[0].qs); j += 32) { + const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32]; + const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32]; + const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32]; + const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32]; + const uint8_t* px = &x[i].qs[j]; + + size_t vlmax_16m2 = __riscv_vsetvl_e16m2(32); + vint16m2_t vacc16 = __riscv_vmv_v_x_i16m2(0, vlmax_16m2); + + size_t vl = __riscv_vsetvl_e8m1(32); + + vuint8m1_t vx_u8 = __riscv_vle8_v_u8m1(px, vl); + + vint8m1_t vy0 = __riscv_vle8_v_i8m1(py0 , vl); + vint8m1_t vy1 = __riscv_vle8_v_i8m1(py1, vl); + vint8m1_t vy2 = __riscv_vle8_v_i8m1(py2, vl); + vint8m1_t vy3 = __riscv_vle8_v_i8m1(py3, vl); + + // l=0 (bits 1:0) + vuint8m1_t t0 = __riscv_vand_vx_u8m1(vx_u8, 0x03, vl); + vint8m1_t vq0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t0), 1, vl); + + // l=1 (bits 3:2) + vuint8m1_t t1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 2, vl), 0x03, vl); + vint8m1_t vq1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t1), 1, vl); + + // l=2 (bits 5:4) + vuint8m1_t t2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 4, vl), 0x03, vl); + vint8m1_t vq2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t2), 1, vl); + + // l=3 (bits 7:6) + vuint8m1_t t3 = __riscv_vsrl_vx_u8m1(vx_u8, 6, vl); // No final AND needed as vsrl shifts in zeros + vint8m1_t vq3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t3), 1, vl); + + // 4. 
Multiply and accumulate + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq0, vy0, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq1, vy1, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq2, vy2, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq3, vy3, vl); + + vlmax_16m2 = __riscv_vsetvl_e16m2(32); + vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t vred32 = __riscv_vwredsum_vs_i16m2_i32m1(vacc16, vzero32, vlmax_16m2); + + sumi += __riscv_vmv_x_s_i32m1_i32(vred32); + } + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + sumf += (float)sumi * d; + } + + *s = sumf; +} + +void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf2_t qh = __riscv_vle16_v_u16mf2(x[i].qh, 8); + + // Calculate ls. + vuint16mf2_t temp = __riscv_vsrl_vx_u16mf2(qh, 12, 8); + temp = __riscv_vand_vx_u16mf2(temp, 7, 8); + vint32m1_t ls = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vwmulu_vx_u32m1(temp, 2, 8)); + ls = __riscv_vadd_vx_i32m1(ls, 1, 8); + + // Calculate delta. 
+ vbool32_t mask = __riscv_vmseq_vx_u16mf2_b32(__riscv_vand_vx_u16mf2(qh, 0x8000, 8), 0, 8); + vint32m1_t delta_neg = __riscv_vmv_v_x_i32m1(-1, 8); + vint32m1_t delta_pos = __riscv_vmv_v_x_i32m1(1, 8); + vint32m1_t delta = __riscv_vmerge_vvm_i32m1(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8m1_t qs = __riscv_vle8_v_u8m1(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m2_t qh_shift = __riscv_vreinterpret_v_u64m2_u16m2(__riscv_vmv_v_x_u64m2(shift, 8)); + vuint16m2_t qh_gather_index = __riscv_vreinterpret_v_i16m2_u16m2( + __riscv_vdiv_vx_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(32)), 4, 32)); + vuint16m2_t qh_ext = __riscv_vlmul_ext_v_u16m1_u16m2(__riscv_vlmul_ext_v_u16mf2_u16m1(qh)); + vuint16m2_t qh_index = __riscv_vrgather_vv_u16m2(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m2(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m2(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m2(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m2(qh_index, __riscv_vzext_vf2_u16m2(qs, 32), 32); + vuint16m2_t index = __riscv_vsll_vx_u16m2(qh_index, 3, 32); + + // Final lsums. 
+ int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-4 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m2_u16m1(index, 0); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 16)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 128); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 1), one_scalar, 32)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 2), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 3), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 5-8 + { + vuint16m1_t grid_index1 = __riscv_vget_v_u16m2_u16m1(index, 1); + vint8m4_t grid1 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index1, 16)); + vint8m4_t q81 = __riscv_vle8_v_i8m4(&y[i].qs[128], 128); + vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(grid1, q81, 128); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 0), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 1), one_scalar, 32)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 2), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 3), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32m1_t lsums = __riscv_vle32_v_i32m1(&lsums_s[0], 8); + + // Calculate the bsums. 
+ vint16m1_t bsums_0 = __riscv_vle16_v_i16m1(y[i].bsums, 16); + const vuint32m1_t bsums_i32 = __riscv_vreinterpret_v_u16m1_u32m1(__riscv_vreinterpret_v_i16m1_u16m1(bsums_0)); + const vint16mf2_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 0, 8)); + const vint16mf2_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 16, 8)); + const vint32m1_t bsums = __riscv_vwadd_vv_i32m1(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32m1_t sumi_v = __riscv_vmul_vv_i32m1(ls, lsums, 8); + vint32m1_t sumi1_v = __riscv_vmul_vv_i32m1(__riscv_vmul_vv_i32m1(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const 
int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16); + + // We process 4 sub-blocks together. + for (int ib = 0; ib < QK_K/128; ib++) { + // Load qh for 4 sub-blocks. + const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8); + const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8); + const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8); + const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1( + __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16); + qh += 8; + + // Prepare grid indices. + const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16); + const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8)); + vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 16), 0x700, 16), 16); + index = __riscv_vsll_vx_u16m1(index, 3, 16); + qs += 16; + + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, index, 16))); + + // Prepare the deltas. + const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( + __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16); + const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16); + const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16); + const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( + __riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16)); + + // Load q8 for sub-blocks. 
+ const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); + q8 += 128; + + // Calculate the lsums. + const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 128); + const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 128); + + // Prepare the scales. + const int16_t ls_0_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_0_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_1_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1_1 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_2_0 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_2_1 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_3_0 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_3_1 = 2*((sc[1] >> 9) & 0x7) + 1; + sc += 2; + + // Accumulate in acc1 and acc2 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1,
__riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 16)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 16)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 1e8b50321..86e91b84d 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1140,7 +1140,6 @@ struct ggml_cuda_graph_node_properties { }; static_assert(std::is_trivial::value, "ggml_cuda_graph_node_properties must be trivial"); -static bool cugraph_warned_rec = false; struct ggml_cuda_graph { #ifdef USE_CUDA_GRAPH ~ggml_cuda_graph() { @@ -1156,8 +1155,7 @@ struct ggml_cuda_graph { size_t num_nodes = 0; std::vector nodes; bool disable_due_to_gpu_arch = false; - bool disable_due_to_too_many_updates = false; - int number_consecutive_updates = 0; + bool warmup_complete = false; std::vector props; // these are extra tensors (inputs) that participate in the ggml graph but are not nodes @@ -1166,25 +1164,9 @@ struct ggml_cuda_graph { // ref: https://github.com/ggml-org/llama.cpp/pull/19165 
std::vector extra; - void record_update(bool use_graph, bool update_required) { - if (use_graph && update_required) { - number_consecutive_updates++; - } else { - number_consecutive_updates = 0; - } - if (number_consecutive_updates >= 4) { - if(!cugraph_warned_rec) - { - cugraph_warned_rec = true; - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); - } - disable_due_to_too_many_updates = true; - } - } - bool is_enabled() const { static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates); + return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env); } #endif }; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c611a7766..0c78a141f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3014,10 +3014,6 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx const void * graph_key = ggml_cuda_graph_get_key(cgraph); ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); - if (graph->instance == nullptr) { - res = true; - } - // Check if the graph size has changed if (graph->props.size() != (size_t)cgraph->n_nodes) { res = true; @@ -3971,14 +3967,35 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, #ifdef USE_CUDA_GRAPH graph_key = ggml_cuda_graph_get_key(cgraph); - use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); + ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); if (graph->is_enabled()) { - cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); - use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); + const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph); + if (graph_compatible) { + const bool properties_changed = 
ggml_cuda_graph_update_required(cuda_ctx, cgraph); - graph->record_update(use_cuda_graph, cuda_graph_update_required); + if (!graph->warmup_complete) { + // Warmup: need at least 2 calls with no property change on the 2nd call + if (!properties_changed) { + graph->warmup_complete = true; + GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__); + use_cuda_graph = true; + cuda_graph_update_required = true; + } + // else: properties changed or first call - execute directly (use_cuda_graph stays false) + } else { + // Post-warmup: normal CUDA graph operation + if (properties_changed) { + // Properties changed - reset warmup, execute directly until stable again + graph->warmup_complete = false; + GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__); + } else { + use_cuda_graph = true; + cuda_graph_update_required = graph->instance == nullptr; + } + } + } } #endif // USE_CUDA_GRAPH diff --git a/include/llama.h b/include/llama.h index 2667f685c..f273a3b11 100644 --- a/include/llama.h +++ b/include/llama.h @@ -392,6 +392,7 @@ extern "C" { bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored bool pure; // quantize all tensors to the default type bool keep_split; // quantize to the same number of shards + bool dry_run; // calculate and show the final quantization size without performing quantization void * imatrix; // pointer to importance matrix data void * kv_overrides; // pointer to vector containing overrides void * tensor_types; // pointer to vector containing tensor types diff --git a/src/llama-impl.cpp b/src/llama-impl.cpp index 6a97faf07..2d6f9e7c5 100644 --- a/src/llama-impl.cpp +++ b/src/llama-impl.cpp @@ -109,9 +109,9 @@ std::string llama_format_tensor_shape(const std::vector & ne) { std::string llama_format_tensor_shape(const struct ggml_tensor * t) { char buf[256]; - snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + snprintf(buf, sizeof(buf), "%6" PRId64, t->ne[0]); for (int i = 1; i < 
GGML_MAX_DIMS; i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %6" PRId64, t->ne[i]); } return buf; } diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 680d8852c..9a176b290 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -479,6 +479,17 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * return new_size; } +static bool tensor_type_requires_imatrix(const ggml_tensor * t, const ggml_type dst_type, const llama_ftype ftype) { + return ( + dst_type == GGML_TYPE_IQ2_XXS || dst_type == GGML_TYPE_IQ2_XS || + dst_type == GGML_TYPE_IQ3_XXS || dst_type == GGML_TYPE_IQ1_S || + dst_type == GGML_TYPE_IQ2_S || dst_type == GGML_TYPE_IQ1_M || + ( // Q2_K_S is the worst k-quant type - only allow it without imatrix for token embeddings + dst_type == GGML_TYPE_Q2_K && ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(t->name, "token_embd.weight") != 0 + ) + ); +} + static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type default_type; llama_ftype ftype = params->ftype; @@ -735,24 +746,36 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }; const auto tn = LLM_TN(model.arch); - new_ofstream(0); + + // no output file for --dry-run + if (!params->dry_run) { + new_ofstream(0); + } + + // flag for `--dry-run`, to let the user know if imatrix will be required for a real + // quantization, as a courtesy + bool will_require_imatrix = false; + for (const auto * it : tensors) { const auto & weight = *it; ggml_tensor * tensor = weight.tensor; - if (weight.idx != cur_split && params->keep_split) { + if (!params->dry_run && (weight.idx != cur_split && params->keep_split)) { close_ofstream(); new_ofstream(weight.idx); } const std::string name = ggml_get_name(tensor); + const size_t tensor_size = ggml_nbytes(tensor); - 
if (!ml.use_mmap) { - if (read_data.size() < ggml_nbytes(tensor)) { - read_data.resize(ggml_nbytes(tensor)); + if (!params->dry_run) { + if (!ml.use_mmap) { + if (read_data.size() < tensor_size) { + read_data.resize(tensor_size); + } + tensor->data = read_data.data(); } - tensor->data = read_data.data(); + ml.load_data_for(tensor); } - ml.load_data_for(tensor); LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", ++idx, ml.n_tensors, @@ -903,129 +926,155 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: quantize = tensor->type != new_type; } - if (!quantize) { - new_type = tensor->type; - new_data = tensor->data; - new_size = ggml_nbytes(tensor); - LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0); - } else { - const int64_t nelements = ggml_nelements(tensor); - - const float * imatrix = nullptr; - if (imatrix_data) { - auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); - if (it == imatrix_data->end()) { - LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); - } else { - if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { - imatrix = it->second.data(); - } else { - LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, - int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); - - // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix - // this is a significant error and it may be good idea to abort the process if this happens, - // since many people will miss the error and not realize that most of the model is being quantized without an imatrix - // tok_embd should be ignored in this case, since it always causes this warning - if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { - throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", - int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), 
tensor->name)); - } - } + // we have now decided on the target type for this tensor + if (params->dry_run) { + // the --dry-run option calculates the final quantization size without quantizing + if (quantize) { + new_size = ggml_nrows(tensor) * ggml_row_size(new_type, tensor->ne[0]); + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB (%s)\n", + tensor_size/1024.0/1024.0, + new_size/1024.0/1024.0, + ggml_type_name(new_type)); + if (!will_require_imatrix && tensor_type_requires_imatrix(tensor, new_type, params->ftype)) { + will_require_imatrix = true; } - } - if ((new_type == GGML_TYPE_IQ2_XXS || - new_type == GGML_TYPE_IQ2_XS || - new_type == GGML_TYPE_IQ2_S || - new_type == GGML_TYPE_IQ1_S || - (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || - (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) { - LLAMA_LOG_ERROR("\n\n============================================================\n"); - LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); - LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); - LLAMA_LOG_ERROR("============================================================\n\n"); - throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name)); - } - - float * f32_data; - - if (tensor->type == GGML_TYPE_F32) { - f32_data = (float *) tensor->data; - } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { - throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); } else { - llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); - f32_data = (float *) f32_conv_buf.data(); + new_size = tensor_size; + LLAMA_LOG_INFO("size = %8.3f MiB\n", new_size/1024.0/1024.0); } + total_size_org += tensor_size; +
total_size_new += new_size; + continue; + } else { + // no --dry-run, perform quantization + if (!quantize) { + new_type = tensor->type; + new_data = tensor->data; + new_size = tensor_size; + LLAMA_LOG_INFO("size = %8.3f MiB\n", tensor_size/1024.0/1024.0); + } else { + const int64_t nelements = ggml_nelements(tensor); - LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); - fflush(stdout); + const float * imatrix = nullptr; + if (imatrix_data) { + auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); + if (it == imatrix_data->end()) { + LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); + } else { + if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { + imatrix = it->second.data(); + } else { + LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); - if (work.size() < (size_t)nelements * 4) { - work.resize(nelements * 4); // upper bound on size - } - new_data = work.data(); - - const int64_t n_per_row = tensor->ne[0]; - const int64_t nrows = tensor->ne[1]; - - static const int64_t min_chunk_size = 32 * 512; - const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); - - const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; - const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; - const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; - - // quantize each expert separately since they have different importance matrices - new_size = 0; - for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { - const float * f32_data_03 = f32_data + i03 * nelements_matrix; - void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; - const float * imatrix_03 = imatrix ? 
imatrix + i03 * n_per_row : nullptr; - - new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); - - // TODO: temporary sanity check that the F16 -> MXFP4 is lossless -#if 0 - if (new_type == GGML_TYPE_MXFP4) { - auto * x = f32_data_03; - - //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row); - std::vector deq(nrows*n_per_row); - const ggml_type_traits * qtype = ggml_get_type_traits(new_type); - qtype->to_float(new_data_03, deq.data(), deq.size()); - - double err = 0.0f; - for (int i = 0; i < (int) deq.size(); ++i) { - err += fabsf(deq[i] - x[i]); - //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) { - if (deq[i] != x[i]) { - LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]); + // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix + // this is a significant error and it may be good idea to abort the process if this happens, + // since many people will miss the error and not realize that most of the model is being quantized without an imatrix + // tok_embd should be ignored in this case, since it always causes this warning + if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { + throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); + } } } - //LLAMA_LOG_INFO("err = %f\n", err); - GGML_ASSERT(err == 0.00000); } + if (!imatrix && tensor_type_requires_imatrix(tensor, new_type, params->ftype)) { + LLAMA_LOG_ERROR("\n\n============================================================\n"); + LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); + LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); + LLAMA_LOG_ERROR("============================================================\n\n"); + throw std::runtime_error(format("Missing importance 
matrix for tensor %s in a very low-bit quantization", tensor->name)); + } + + float * f32_data; + + if (tensor->type == GGML_TYPE_F32) { + f32_data = (float *) tensor->data; + } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { + throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); + } else { + llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); + f32_data = (float *) f32_conv_buf.data(); + } + + LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); + fflush(stdout); + + if (work.size() < (size_t)nelements * 4) { + work.resize(nelements * 4); // upper bound on size + } + new_data = work.data(); + + const int64_t n_per_row = tensor->ne[0]; + const int64_t nrows = tensor->ne[1]; + + static const int64_t min_chunk_size = 32 * 512; + const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); + + const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; + const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; + const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; + + // quantize each expert separately since they have different importance matrices + new_size = 0; + for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { + const float * f32_data_03 = f32_data + i03 * nelements_matrix; + void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; + const float * imatrix_03 = imatrix ? 
imatrix + i03 * n_per_row : nullptr; + + new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); + + // TODO: temporary sanity check that the F16 -> MXFP4 is lossless +#if 0 + if (new_type == GGML_TYPE_MXFP4) { + auto * x = f32_data_03; + + //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row); + std::vector deq(nrows*n_per_row); + const ggml_type_traits * qtype = ggml_get_type_traits(new_type); + qtype->to_float(new_data_03, deq.data(), deq.size()); + + double err = 0.0f; + for (int i = 0; i < (int) deq.size(); ++i) { + err += fabsf(deq[i] - x[i]); + //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) { + if (deq[i] != x[i]) { + LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]); + } + } + //LLAMA_LOG_INFO("err = %f\n", err); + GGML_ASSERT(err == 0.00000); + } #endif + } + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", tensor_size/1024.0/1024.0, new_size/1024.0/1024.0); } - LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); - } - total_size_org += ggml_nbytes(tensor); - total_size_new += new_size; + total_size_org += tensor_size; + total_size_new += new_size; - // update the gguf meta data as we go - gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); - GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); - gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + // update the gguf meta data as we go + gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); + GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); + gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); - // write tensor data + padding - fout.write((const char *) new_data, new_size); - zeros(fout, 
GGML_PAD(new_size, align) - new_size); + // write tensor data + padding + fout.write((const char *) new_data, new_size); + zeros(fout, GGML_PAD(new_size, align) - new_size); + } // no --dry-run + } // iterate over tensors + + if (!params->dry_run) { + close_ofstream(); } - close_ofstream(); - LLAMA_LOG_INFO("%s: model size = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0); - LLAMA_LOG_INFO("%s: quant size = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0); + LLAMA_LOG_INFO("%s: model size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_org/1024.0/1024.0, total_size_org*8.0/ml.n_elements); + LLAMA_LOG_INFO("%s: quant size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_new/1024.0/1024.0, total_size_new*8.0/ml.n_elements); + + if (!params->imatrix && params->dry_run && will_require_imatrix) { + LLAMA_LOG_WARN("%s: WARNING: dry run completed successfully, but actually completing this quantization will require an imatrix!\n", + __func__ + ); + } if (qs.n_fallback > 0) { LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n", @@ -1048,6 +1097,7 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.only_copy =*/ false, /*.pure =*/ false, /*.keep_split =*/ false, + /*.dry_run =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, /*.tensor_type =*/ nullptr, diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index d3755bcc8..9609ea32e 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -121,7 +121,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp static void usage(const char * executable) { printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable); printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--tensor-type-file]\n"); - printf(" [--prune-layers] [--keep-split] [--override-kv]\n"); + printf(" [--prune-layers] 
[--keep-split] [--override-kv] [--dry-run]\n"); printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); printf(" --allow-requantize\n"); printf(" allow requantizing tensors that have already been quantized\n"); @@ -157,7 +157,10 @@ static void usage(const char * executable) { printf(" generate quantized model in the same shards as input\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" override model metadata by key in the quantized model. may be specified multiple times.\n"); - printf(" WARNING: this is an advanced option, use with care.\n\n"); + printf(" WARNING: this is an advanced option, use with care.\n"); + printf(" --dry-run\n"); + printf(" calculate and show the final quantization size without performing quantization\n"); + printf(" example: llama-quantize --dry-run model-f32.gguf Q4_K\n\n"); printf("note: --include-weights and --exclude-weights cannot be used together\n\n"); printf("-----------------------------------------------------------------------------\n"); printf(" allowed quantization types\n"); @@ -533,6 +536,8 @@ int main(int argc, char ** argv) { if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) { usage(argv[0]); } + } else if (strcmp(argv[arg_idx], "--dry-run") == 0) { + params.dry_run = true; } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) { params.allow_requantize = true; } else if (strcmp(argv[arg_idx], "--pure") == 0) { @@ -631,22 +636,26 @@ int main(int argc, char ** argv) { std::string ftype_str; std::string suffix = ".gguf"; if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) { - std::string fpath; - const size_t pos = fname_inp.find_last_of("/\\"); - if (pos != std::string::npos) { - fpath = fname_inp.substr(0, pos + 1); - } + // argv[arg_idx] is the ftype directly: + if (!params.dry_run) { + std::string fpath; + const size_t pos = fname_inp.find_last_of("/\\"); + if (pos != std::string::npos) { + fpath = fname_inp.substr(0, pos + 1); + } - // export as [inp 
path]/ggml-model-[ftype]. Only add extension if there is no splitting - fname_out = fpath + "ggml-model-" + ftype_str; - if (!params.keep_split) { - fname_out += suffix; + // export as [inp path]/ggml-model-[ftype]. Only add extension if there is no splitting + fname_out = fpath + "ggml-model-" + ftype_str; + if (!params.keep_split) { + fname_out += suffix; + } } arg_idx++; if (ftype_str == "COPY") { params.only_copy = true; } } else { + // argv[arg_idx] is not a valid ftype, so treat it as output path: fname_out = argv[arg_idx]; if (params.keep_split && fname_out.find(suffix) != std::string::npos) { fname_out = fname_out.substr(0, fname_out.length() - suffix.length()); @@ -678,25 +687,33 @@ int main(int argc, char ** argv) { } } - if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || - params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || - params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || - params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || - params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) && imatrix_data.empty()) { + if (!params.dry_run && + ( + params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || + params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || + params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M + ) && imatrix_data.empty()) { fprintf(stderr, "\n==========================================================================================================\n"); fprintf(stderr, "Please do not use IQ1_S, IQ1_M, IQ2_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n"); fprintf(stderr, "==========================================================================================================\n\n\n"); return 1; } - if (std::error_code ec; std::filesystem::equivalent(fname_inp, fname_out, ec)) { - fprintf(stderr, "%s: error: input and output files are the same: '%s'\n", __func__, fname_inp.c_str()); - return 1; + if 
(!params.dry_run) { + if (std::error_code ec; std::filesystem::equivalent(fname_inp, fname_out, ec)) { + fprintf(stderr, "%s: error: input and output files are the same: '%s'\n", __func__, fname_inp.c_str()); + return 1; + } } print_build_info(); - fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str()); + if (params.dry_run) { + fprintf(stderr, "%s: calculating quantization size for '%s' as %s", __func__, fname_inp.c_str(), ftype_str.c_str()); + } else { + fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str()); + } + if (params.nthread > 0) { fprintf(stderr, " using %d threads", params.nthread); } diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index c69481e79..2ead90dfb 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte b/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte index e011fa6ec..ebffae121 100644 --- a/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte +++ b/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte @@ -251,9 +251,6 @@ return options.find((option) => option.id === activeId); } - if (options.length === 1) { - return options[0]; - } // No selection - return undefined to show "Select model" return undefined; } diff --git a/tools/server/webui/src/lib/stores/models.svelte.ts b/tools/server/webui/src/lib/stores/models.svelte.ts index 4cb616722..c4cc3d386 100644 --- a/tools/server/webui/src/lib/stores/models.svelte.ts +++ b/tools/server/webui/src/lib/stores/models.svelte.ts @@ -306,6 +306,16 @@ class ModelsStore { const response = await ModelsService.listRouter(); this.routerModels = response.data; await this.fetchModalitiesForLoadedModels(); + + const o = this.models.filter((option) => { + const 
modelProps = this.getModelProps(option.model); + + return modelProps?.webui !== false; + }); + + if (o.length === 1 && this.isModelLoaded(o[0].model)) { + this.selectModelById(o[0].id); + } } catch (error) { console.warn('Failed to fetch router models:', error); this.routerModels = []; diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 9d24594f9..d929d2119 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -1,5 +1,7 @@ #include "httplib.h" namespace httplib { +// httplib::any — type-erased value container (C++11 compatible) +// On C++17+ builds, thin wrappers around std::any are provided. /* * Implementation that will be part of the .cc file if split into .h + .cc. @@ -630,6 +632,56 @@ size_t to_utf8(int code, char *buff) { return 0; } +} // namespace detail + +namespace ws { +namespace impl { + +bool is_valid_utf8(const std::string &s) { + size_t i = 0; + auto n = s.size(); + while (i < n) { + auto c = static_cast(s[i]); + size_t len; + uint32_t cp; + if (c < 0x80) { + i++; + continue; + } else if ((c & 0xE0) == 0xC0) { + len = 2; + cp = c & 0x1F; + } else if ((c & 0xF0) == 0xE0) { + len = 3; + cp = c & 0x0F; + } else if ((c & 0xF8) == 0xF0) { + len = 4; + cp = c & 0x07; + } else { + return false; + } + if (i + len > n) { return false; } + for (size_t j = 1; j < len; j++) { + auto b = static_cast(s[i + j]); + if ((b & 0xC0) != 0x80) { return false; } + cp = (cp << 6) | (b & 0x3F); + } + // Overlong encoding check + if (len == 2 && cp < 0x80) { return false; } + if (len == 3 && cp < 0x800) { return false; } + if (len == 4 && cp < 0x10000) { return false; } + // Surrogate halves (U+D800..U+DFFF) and beyond U+10FFFF are invalid + if (cp >= 0xD800 && cp <= 0xDFFF) { return false; } + if (cp > 0x10FFFF) { return false; } + i += len; + } + return true; +} + +} // namespace impl +} // namespace ws + +namespace detail { + // NOTE: This code came up with the following stackoverflow post: // 
https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c std::string base64_encode(const std::string &in) { @@ -660,6 +712,281 @@ std::string base64_encode(const std::string &in) { return out; } +std::string sha1(const std::string &input) { + // RFC 3174 SHA-1 implementation + auto left_rotate = [](uint32_t x, uint32_t n) -> uint32_t { + return (x << n) | (x >> (32 - n)); + }; + + uint32_t h0 = 0x67452301; + uint32_t h1 = 0xEFCDAB89; + uint32_t h2 = 0x98BADCFE; + uint32_t h3 = 0x10325476; + uint32_t h4 = 0xC3D2E1F0; + + // Pre-processing: adding padding bits + std::string msg = input; + uint64_t original_bit_len = static_cast(msg.size()) * 8; + msg.push_back(static_cast(0x80)); + while (msg.size() % 64 != 56) { + msg.push_back(0); + } + + // Append original length in bits as 64-bit big-endian + for (int i = 56; i >= 0; i -= 8) { + msg.push_back(static_cast((original_bit_len >> i) & 0xFF)); + } + + // Process each 512-bit chunk + for (size_t offset = 0; offset < msg.size(); offset += 64) { + uint32_t w[80]; + + for (size_t i = 0; i < 16; i++) { + w[i] = + (static_cast(static_cast(msg[offset + i * 4])) + << 24) | + (static_cast(static_cast(msg[offset + i * 4 + 1])) + << 16) | + (static_cast(static_cast(msg[offset + i * 4 + 2])) + << 8) | + (static_cast( + static_cast(msg[offset + i * 4 + 3]))); + } + + for (int i = 16; i < 80; i++) { + w[i] = left_rotate(w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16], 1); + } + + uint32_t a = h0, b = h1, c = h2, d = h3, e = h4; + + for (int i = 0; i < 80; i++) { + uint32_t f, k; + if (i < 20) { + f = (b & c) | ((~b) & d); + k = 0x5A827999; + } else if (i < 40) { + f = b ^ c ^ d; + k = 0x6ED9EBA1; + } else if (i < 60) { + f = (b & c) | (b & d) | (c & d); + k = 0x8F1BBCDC; + } else { + f = b ^ c ^ d; + k = 0xCA62C1D6; + } + + uint32_t temp = left_rotate(a, 5) + f + e + k + w[i]; + e = d; + d = c; + c = left_rotate(b, 30); + b = a; + a = temp; + } + + h0 += a; + h1 += b; + h2 += c; + h3 += d; + h4 += e; + } + + // Produce the 
final hash as a 20-byte binary string + std::string hash(20, '\0'); + for (size_t i = 0; i < 4; i++) { + hash[i] = static_cast((h0 >> (24 - i * 8)) & 0xFF); + hash[4 + i] = static_cast((h1 >> (24 - i * 8)) & 0xFF); + hash[8 + i] = static_cast((h2 >> (24 - i * 8)) & 0xFF); + hash[12 + i] = static_cast((h3 >> (24 - i * 8)) & 0xFF); + hash[16 + i] = static_cast((h4 >> (24 - i * 8)) & 0xFF); + } + return hash; +} + +std::string websocket_accept_key(const std::string &client_key) { + const std::string magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + return base64_encode(sha1(client_key + magic)); +} + +bool is_websocket_upgrade(const Request &req) { + if (req.method != "GET") { return false; } + + // Check Upgrade: websocket (case-insensitive) + auto upgrade_it = req.headers.find("Upgrade"); + if (upgrade_it == req.headers.end()) { return false; } + auto upgrade_val = upgrade_it->second; + std::transform(upgrade_val.begin(), upgrade_val.end(), upgrade_val.begin(), + ::tolower); + if (upgrade_val != "websocket") { return false; } + + // Check Connection header contains "Upgrade" + auto connection_it = req.headers.find("Connection"); + if (connection_it == req.headers.end()) { return false; } + auto connection_val = connection_it->second; + std::transform(connection_val.begin(), connection_val.end(), + connection_val.begin(), ::tolower); + if (connection_val.find("upgrade") == std::string::npos) { return false; } + + // Check Sec-WebSocket-Key is a valid base64-encoded 16-byte value (24 chars) + // RFC 6455 Section 4.2.1 + auto ws_key = req.get_header_value("Sec-WebSocket-Key"); + if (ws_key.size() != 24 || ws_key[22] != '=' || ws_key[23] != '=') { + return false; + } + static const std::string b64chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + for (size_t i = 0; i < 22; i++) { + if (b64chars.find(ws_key[i]) == std::string::npos) { return false; } + } + + // Check Sec-WebSocket-Version: 13 + auto version = 
req.get_header_value("Sec-WebSocket-Version"); + if (version != "13") { return false; } + + return true; +} + +bool write_websocket_frame(Stream &strm, ws::Opcode opcode, + const char *data, size_t len, bool fin, + bool mask) { + // First byte: FIN + opcode + uint8_t header[2]; + header[0] = static_cast((fin ? 0x80 : 0x00) | + (static_cast(opcode) & 0x0F)); + + // Second byte: MASK + payload length + if (len < 126) { + header[1] = static_cast(len); + if (mask) { header[1] |= 0x80; } + if (strm.write(reinterpret_cast(header), 2) < 0) { return false; } + } else if (len <= 0xFFFF) { + header[1] = 126; + if (mask) { header[1] |= 0x80; } + if (strm.write(reinterpret_cast(header), 2) < 0) { return false; } + uint8_t ext[2]; + ext[0] = static_cast((len >> 8) & 0xFF); + ext[1] = static_cast(len & 0xFF); + if (strm.write(reinterpret_cast(ext), 2) < 0) { return false; } + } else { + header[1] = 127; + if (mask) { header[1] |= 0x80; } + if (strm.write(reinterpret_cast(header), 2) < 0) { return false; } + uint8_t ext[8]; + for (int i = 7; i >= 0; i--) { + ext[7 - i] = static_cast((len >> (i * 8)) & 0xFF); + } + if (strm.write(reinterpret_cast(ext), 8) < 0) { return false; } + } + + if (mask) { + // Generate random mask key + thread_local std::mt19937 rng(std::random_device{}()); + uint8_t mask_key[4]; + auto r = rng(); + std::memcpy(mask_key, &r, 4); + if (strm.write(reinterpret_cast(mask_key), 4) < 0) { return false; } + + // Write masked payload in chunks + const size_t chunk_size = 4096; + std::vector buf((std::min)(len, chunk_size)); + for (size_t offset = 0; offset < len; offset += chunk_size) { + size_t n = (std::min)(chunk_size, len - offset); + for (size_t i = 0; i < n; i++) { + buf[i] = + data[offset + i] ^ static_cast(mask_key[(offset + i) % 4]); + } + if (strm.write(buf.data(), n) < 0) { return false; } + } + } else { + if (len > 0) { + if (strm.write(data, len) < 0) { return false; } + } + } + + return true; +} + +} // namespace detail + +namespace ws { +namespace 
impl { + +bool read_websocket_frame(Stream &strm, Opcode &opcode, + std::string &payload, bool &fin, + bool expect_masked, size_t max_len) { + // Read first 2 bytes + uint8_t header[2]; + if (strm.read(reinterpret_cast(header), 2) != 2) { return false; } + + fin = (header[0] & 0x80) != 0; + + // RSV1, RSV2, RSV3 must be 0 when no extension is negotiated + if (header[0] & 0x70) { return false; } + + opcode = static_cast(header[0] & 0x0F); + bool masked = (header[1] & 0x80) != 0; + uint64_t payload_len = header[1] & 0x7F; + + // RFC 6455 Section 5.5: control frames MUST NOT be fragmented and + // MUST have a payload length of 125 bytes or less + bool is_control = (static_cast(opcode) & 0x08) != 0; + if (is_control) { + if (!fin) { return false; } + if (payload_len > 125) { return false; } + } + + if (masked != expect_masked) { return false; } + + // Extended payload length + if (payload_len == 126) { + uint8_t ext[2]; + if (strm.read(reinterpret_cast(ext), 2) != 2) { return false; } + payload_len = (static_cast(ext[0]) << 8) | ext[1]; + } else if (payload_len == 127) { + uint8_t ext[8]; + if (strm.read(reinterpret_cast(ext), 8) != 8) { return false; } + // RFC 6455 Section 5.2: the most significant bit MUST be 0 + if (ext[0] & 0x80) { return false; } + payload_len = 0; + for (int i = 0; i < 8; i++) { + payload_len = (payload_len << 8) | ext[i]; + } + } + + if (payload_len > max_len) { return false; } + + // Read mask key if present + uint8_t mask_key[4] = {0}; + if (masked) { + if (strm.read(reinterpret_cast(mask_key), 4) != 4) { return false; } + } + + // Read payload + payload.resize(static_cast(payload_len)); + if (payload_len > 0) { + size_t total_read = 0; + while (total_read < payload_len) { + auto n = strm.read(&payload[total_read], + static_cast(payload_len - total_read)); + if (n <= 0) { return false; } + total_read += static_cast(n); + } + } + + // Unmask if needed + if (masked) { + for (size_t i = 0; i < payload.size(); i++) { + payload[i] ^= 
static_cast(mask_key[i % 4]); + } + } + + return true; +} + +} // namespace impl +} // namespace ws + +namespace detail { + bool is_valid_path(const std::string &path) { size_t level = 0; size_t i = 0; @@ -1339,6 +1666,7 @@ public: void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; time_t duration() const override; + void set_read_timeout(time_t sec, time_t usec = 0) override; private: socket_t sock_; @@ -2653,6 +2981,50 @@ bool read_headers(Stream &strm, Headers &headers) { return true; } +bool read_websocket_upgrade_response(Stream &strm, + const std::string &expected_accept, + std::string &selected_subprotocol) { + // Read status line + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + if (!line_reader.getline()) { return false; } + + // Check for "HTTP/1.1 101" + auto line = std::string(line_reader.ptr(), line_reader.size()); + if (line.find("HTTP/1.1 101") == std::string::npos) { return false; } + + // Parse headers using existing read_headers + Headers headers; + if (!read_headers(strm, headers)) { return false; } + + // Verify Upgrade: websocket (case-insensitive) + auto upgrade_it = headers.find("Upgrade"); + if (upgrade_it == headers.end()) { return false; } + auto upgrade_val = upgrade_it->second; + std::transform(upgrade_val.begin(), upgrade_val.end(), upgrade_val.begin(), + ::tolower); + if (upgrade_val != "websocket") { return false; } + + // Verify Connection header contains "Upgrade" (case-insensitive) + auto connection_it = headers.find("Connection"); + if (connection_it == headers.end()) { return false; } + auto connection_val = connection_it->second; + std::transform(connection_val.begin(), connection_val.end(), + connection_val.begin(), ::tolower); + if (connection_val.find("upgrade") == std::string::npos) { return false; } + + // Verify Sec-WebSocket-Accept header value + auto it = headers.find("Sec-WebSocket-Accept"); + if (it == 
headers.end() || it->second != expected_accept) { return false; } + + // Extract negotiated subprotocol + auto proto_it = headers.find("Sec-WebSocket-Protocol"); + if (proto_it != headers.end()) { selected_subprotocol = proto_it->second; } + + return true; +} + enum class ReadContentResult { Success, // Successfully read the content PayloadTooLarge, // The content exceeds the specified payload limit @@ -3772,6 +4144,73 @@ serialize_multipart_formdata(const UploadFormDataItems &items, return body; } +size_t get_multipart_content_length(const UploadFormDataItems &items, + const std::string &boundary) { + size_t total = 0; + for (const auto &item : items) { + total += serialize_multipart_formdata_item_begin(item, boundary).size(); + total += item.content.size(); + total += serialize_multipart_formdata_item_end().size(); + } + total += serialize_multipart_formdata_finish(boundary).size(); + return total; +} + +struct MultipartSegment { + const char *data; + size_t size; +}; + +// NOTE: items must outlive the returned ContentProvider +// (safe for synchronous use inside Post/Put/Patch) +ContentProvider +make_multipart_content_provider(const UploadFormDataItems &items, + const std::string &boundary) { + // Own the per-item header strings and the finish string + std::vector owned; + owned.reserve(items.size() + 1); + for (const auto &item : items) + owned.push_back(serialize_multipart_formdata_item_begin(item, boundary)); + owned.push_back(serialize_multipart_formdata_finish(boundary)); + + // Flat segment list: [header, content, "\r\n"] * N + [finish] + std::vector segs; + segs.reserve(items.size() * 3 + 1); + static const char crlf[] = "\r\n"; + for (size_t i = 0; i < items.size(); i++) { + segs.push_back({owned[i].data(), owned[i].size()}); + segs.push_back({items[i].content.data(), items[i].content.size()}); + segs.push_back({crlf, 2}); + } + segs.push_back({owned.back().data(), owned.back().size()}); + + struct MultipartState { + std::vector owned; + std::vector 
segs; + }; + auto state = std::make_shared(); + state->owned = std::move(owned); + // `segs` holds raw pointers into owned strings; std::string move preserves + // the data pointer, so these pointers remain valid after the move above. + state->segs = std::move(segs); + + return [state](size_t offset, size_t length, DataSink &sink) -> bool { + size_t pos = 0; + for (const auto &seg : state->segs) { + // Loop invariant: pos <= offset (proven by advancing pos only when + // offset - pos >= seg.size, i.e., the segment doesn't contain offset) + if (seg.size > 0 && offset - pos < seg.size) { + size_t seg_offset = offset - pos; + size_t available = seg.size - seg_offset; + size_t to_write = (std::min)(available, length); + return sink.write(seg.data + seg_offset, to_write); + } + pos += seg.size; + } + return true; // past end (shouldn't be reached when content_length is exact) + }; +} + void coalesce_ranges(Ranges &ranges, size_t content_length) { if (ranges.size() <= 1) return; @@ -4020,15 +4459,6 @@ bool expect_content(const Request &req) { return false; } -bool has_crlf(const std::string &s) { - auto p = s.c_str(); - while (*p) { - if (*p == '\r' || *p == '\n') { return true; } - p++; - } - return false; -} - #ifdef _WIN32 class WSInit { public: @@ -4148,6 +4578,52 @@ bool is_field_content(const std::string &s) { bool is_field_value(const std::string &s) { return is_field_content(s); } } // namespace fields + +bool perform_websocket_handshake(Stream &strm, const std::string &host, + int port, const std::string &path, + const Headers &headers, + std::string &selected_subprotocol) { + // Validate path and host + if (!fields::is_field_value(path) || !fields::is_field_value(host)) { + return false; + } + + // Validate user-provided headers + for (const auto &h : headers) { + if (!fields::is_field_name(h.first) || !fields::is_field_value(h.second)) { + return false; + } + } + + // Generate random Sec-WebSocket-Key + thread_local std::mt19937 rng(std::random_device{}()); + 
std::string key_bytes(16, '\0'); + for (size_t i = 0; i < 16; i += 4) { + auto r = rng(); + std::memcpy(&key_bytes[i], &r, (std::min)(size_t(4), size_t(16 - i))); + } + auto client_key = base64_encode(key_bytes); + + // Build upgrade request + std::string req_str = "GET " + path + " HTTP/1.1\r\n"; + req_str += "Host: " + host + ":" + std::to_string(port) + "\r\n"; + req_str += "Upgrade: websocket\r\n"; + req_str += "Connection: Upgrade\r\n"; + req_str += "Sec-WebSocket-Key: " + client_key + "\r\n"; + req_str += "Sec-WebSocket-Version: 13\r\n"; + for (const auto &h : headers) { + req_str += h.first + ": " + h.second + "\r\n"; + } + req_str += "\r\n"; + + if (strm.write(req_str.data(), req_str.size()) < 0) { return false; } + + // Verify 101 response and Sec-WebSocket-Accept header + auto expected_accept = websocket_accept_key(client_key); + return read_websocket_upgrade_response(strm, expected_accept, + selected_subprotocol); +} + } // namespace detail /* @@ -4176,6 +4652,7 @@ public: void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; time_t duration() const override; + void set_read_timeout(time_t sec, time_t usec = 0) override; private: socket_t sock_; @@ -4268,6 +4745,39 @@ std::string SHA_512(const std::string &s) { #endif return hash_to_hex(hash); } +#elif defined(CPPHTTPLIB_WOLFSSL_SUPPORT) +namespace { +template +std::string hash_to_hex(const unsigned char (&hash)[N]) { + std::stringstream ss; + for (size_t i = 0; i < N; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + return ss.str(); +} +} // namespace + +std::string MD5(const std::string &s) { + unsigned char hash[WC_MD5_DIGEST_SIZE]; + wc_Md5Hash(reinterpret_cast(s.c_str()), + static_cast(s.size()), hash); + return hash_to_hex(hash); +} + +std::string SHA_256(const std::string &s) { + unsigned char hash[WC_SHA256_DIGEST_SIZE]; + wc_Sha256Hash(reinterpret_cast(s.c_str()), + static_cast(s.size()), hash); + 
return hash_to_hex(hash); +} + +std::string SHA_512(const std::string &s) { + unsigned char hash[WC_SHA512_DIGEST_SIZE]; + wc_Sha512Hash(reinterpret_cast(s.c_str()), + static_cast(s.size()), hash); + return hash_to_hex(hash); +} #endif bool is_ip_address(const std::string &host) { @@ -4510,6 +5020,53 @@ bool verify_cert_with_windows_schannel( } #endif // _WIN32 +bool setup_client_tls_session(const std::string &host, tls::ctx_t &ctx, + tls::session_t &session, socket_t sock, + bool server_certificate_verification, + const std::string &ca_cert_file_path, + tls::ca_store_t ca_cert_store, + time_t timeout_sec, time_t timeout_usec) { + using namespace tls; + + ctx = create_client_context(); + if (!ctx) { return false; } + + if (server_certificate_verification) { + if (!ca_cert_file_path.empty()) { + load_ca_file(ctx, ca_cert_file_path.c_str()); + } + if (ca_cert_store) { set_ca_store(ctx, ca_cert_store); } + load_system_certs(ctx); + } + + bool is_ip = is_ip_address(host); + +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT + if (is_ip && server_certificate_verification) { + set_verify_client(ctx, false); + } else { + set_verify_client(ctx, server_certificate_verification); + } +#endif + + session = create_session(ctx, sock); + if (!session) { return false; } + + // RFC 6066: SNI must not be set for IP addresses + if (!is_ip) { set_sni(session, host.c_str()); } + if (server_certificate_verification) { set_hostname(session, host.c_str()); } + + if (!connect_nonblocking(session, sock, timeout_sec, timeout_usec, nullptr)) { + return false; + } + + if (server_certificate_verification) { + if (get_verify_result(session) != 0) { return false; } + } + + return true; +} + } // namespace detail #endif // CPPHTTPLIB_SSL_ENABLED @@ -5327,22 +5884,37 @@ ssize_t detail::BodyReader::read(char *buf, size_t len) { } // ThreadPool implementation -ThreadPool::ThreadPool(size_t n, size_t mqr) - : shutdown_(false), max_queued_requests_(mqr) { - threads_.reserve(n); - while (n) { - 
threads_.emplace_back(worker(*this)); - n--; +ThreadPool::ThreadPool(size_t n, size_t max_n, size_t mqr) + : base_thread_count_(n), max_queued_requests_(mqr), idle_thread_count_(0), + shutdown_(false) { +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (max_n != 0 && max_n < n) { + std::string msg = "max_threads must be >= base_threads"; + throw std::invalid_argument(msg); + } +#endif + max_thread_count_ = max_n == 0 ? n : max_n; + threads_.reserve(base_thread_count_); + for (size_t i = 0; i < base_thread_count_; i++) { + threads_.emplace_back(std::thread([this]() { worker(false); })); } } bool ThreadPool::enqueue(std::function fn) { { std::unique_lock lock(mutex_); + if (shutdown_) { return false; } if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { return false; } jobs_.push_back(std::move(fn)); + + // Spawn a dynamic thread if no idle threads and under max + if (idle_thread_count_ == 0 && + threads_.size() + dynamic_threads_.size() < max_thread_count_) { + cleanup_finished_threads(); + dynamic_threads_.emplace_back(std::thread([this]() { worker(true); })); + } } cond_.notify_one(); @@ -5350,7 +5922,6 @@ bool ThreadPool::enqueue(std::function fn) { } void ThreadPool::shutdown() { - // Stop all worker threads... { std::unique_lock lock(mutex_); shutdown_ = true; @@ -5358,31 +5929,84 @@ void ThreadPool::shutdown() { cond_.notify_all(); - // Join... for (auto &t : threads_) { - t.join(); + if (t.joinable()) { t.join(); } + } + + // Move dynamic_threads_ to a local list under the lock to avoid racing + // with worker threads that call move_to_finished() concurrently. 
+ std::list remaining_dynamic; + { + std::unique_lock lock(mutex_); + remaining_dynamic = std::move(dynamic_threads_); + } + for (auto &t : remaining_dynamic) { + if (t.joinable()) { t.join(); } + } + + std::unique_lock lock(mutex_); + cleanup_finished_threads(); +} + +void ThreadPool::move_to_finished(std::thread::id id) { + // Must be called with mutex_ held + for (auto it = dynamic_threads_.begin(); it != dynamic_threads_.end(); ++it) { + if (it->get_id() == id) { + finished_threads_.push_back(std::move(*it)); + dynamic_threads_.erase(it); + return; + } } } -ThreadPool::worker::worker(ThreadPool &pool) : pool_(pool) {} +void ThreadPool::cleanup_finished_threads() { + // Must be called with mutex_ held + for (auto &t : finished_threads_) { + if (t.joinable()) { t.join(); } + } + finished_threads_.clear(); +} -void ThreadPool::worker::operator()() { +void ThreadPool::worker(bool is_dynamic) { for (;;) { std::function fn; { - std::unique_lock lock(pool_.mutex_); + std::unique_lock lock(mutex_); + idle_thread_count_++; - pool_.cond_.wait(lock, - [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + if (is_dynamic) { + auto has_work = cond_.wait_for( + lock, std::chrono::seconds(CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT), + [&] { return !jobs_.empty() || shutdown_; }); + if (!has_work) { + // Timed out with no work - exit this dynamic thread + idle_thread_count_--; + move_to_finished(std::this_thread::get_id()); + break; + } + } else { + cond_.wait(lock, [&] { return !jobs_.empty() || shutdown_; }); + } - if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + idle_thread_count_--; - fn = pool_.jobs_.front(); - pool_.jobs_.pop_front(); + if (shutdown_ && jobs_.empty()) { break; } + + fn = std::move(jobs_.front()); + jobs_.pop_front(); } assert(true == static_cast(fn)); fn(); + + // Dynamic thread: exit if queue is empty after task completion + if (is_dynamic) { + std::unique_lock lock(mutex_); + if (jobs_.empty()) { + move_to_finished(std::this_thread::get_id()); + 
break; + } + } } #if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ @@ -5540,6 +6164,11 @@ time_t SocketStream::duration() const { .count(); } +void SocketStream::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + // Buffer stream implementation bool BufferStream::is_readable() const { return true; } @@ -5865,6 +6494,11 @@ time_t SSLSocketStream::duration() const { .count(); } +void SSLSocketStream::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + } // namespace detail #endif // CPPHTTPLIB_SSL_ENABLED @@ -5874,8 +6508,10 @@ time_t SSLSocketStream::duration() const { // HTTP server implementation Server::Server() - : new_task_queue( - [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) { + : new_task_queue([] { + return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT, + CPPHTTPLIB_THREAD_POOL_MAX_COUNT); + }) { #ifndef _WIN32 signal(SIGPIPE, SIG_IGN); #endif @@ -5950,6 +6586,21 @@ Server &Server::Options(const std::string &pattern, Handler handler) { return *this; } +Server &Server::WebSocket(const std::string &pattern, + WebSocketHandler handler) { + websocket_handlers_.push_back( + {make_matcher(pattern), std::move(handler), nullptr}); + return *this; +} + +Server &Server::WebSocket(const std::string &pattern, + WebSocketHandler handler, + SubProtocolSelector sub_protocol_selector) { + websocket_handlers_.push_back({make_matcher(pattern), std::move(handler), + std::move(sub_protocol_selector)}); + return *this; +} + bool Server::set_base_dir(const std::string &dir, const std::string &mount_point) { return set_mount_point(mount_point, dir); @@ -7072,7 +7723,8 @@ Server::process_request(Stream &strm, const std::string &remote_addr, int remote_port, const std::string &local_addr, int local_port, bool close_connection, bool &connection_closed, - const std::function &setup_request) { + const std::function &setup_request, + bool 
*websocket_upgraded) { std::array buf{}; detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); @@ -7175,6 +7827,77 @@ Server::process_request(Stream &strm, const std::string &remote_addr, return !detail::is_socket_alive(sock); }; + // WebSocket upgrade + // Check pre_routing_handler_ before upgrading so that authentication + // and other middleware can reject the request with an HTTP response + // (e.g., 401) before the protocol switches. + if (detail::is_websocket_upgrade(req)) { + if (pre_routing_handler_ && + pre_routing_handler_(req, res) == HandlerResponse::Handled) { + if (res.status == -1) { res.status = StatusCode::OK_200; } + return write_response(strm, close_connection, req, res); + } + // Find matching WebSocket handler + for (const auto &entry : websocket_handlers_) { + if (entry.matcher->match(req)) { + // Compute accept key + auto client_key = req.get_header_value("Sec-WebSocket-Key"); + auto accept_key = detail::websocket_accept_key(client_key); + + // Negotiate subprotocol + std::string selected_subprotocol; + if (entry.sub_protocol_selector) { + auto protocol_header = req.get_header_value("Sec-WebSocket-Protocol"); + if (!protocol_header.empty()) { + std::vector protocols; + std::istringstream iss(protocol_header); + std::string token; + while (std::getline(iss, token, ',')) { + // Trim whitespace + auto start = token.find_first_not_of(' '); + auto end = token.find_last_not_of(' '); + if (start != std::string::npos) { + protocols.push_back(token.substr(start, end - start + 1)); + } + } + selected_subprotocol = entry.sub_protocol_selector(protocols); + } + } + + // Send 101 Switching Protocols + std::string handshake_response = "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: " + + accept_key + "\r\n"; + if (!selected_subprotocol.empty()) { + if (!detail::fields::is_field_value(selected_subprotocol)) { + return false; + } + handshake_response += + 
"Sec-WebSocket-Protocol: " + selected_subprotocol + "\r\n"; + } + handshake_response += "\r\n"; + if (strm.write(handshake_response.data(), handshake_response.size()) < + 0) { + return false; + } + + connection_closed = true; + if (websocket_upgraded) { *websocket_upgraded = true; } + + { + // Use WebSocket-specific read timeout instead of HTTP timeout + strm.set_read_timeout(CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND, 0); + ws::WebSocket ws(strm, req, true); + entry.handler(req, ws); + } + return true; + } + } + // No matching handler - fall through to 404 + } + // Routing auto routed = false; #ifdef CPPHTTPLIB_NO_EXCEPTIONS @@ -7271,6 +7994,7 @@ bool Server::process_and_close_socket(socket_t sock) { int local_port = 0; detail::get_local_ip_and_port(sock, local_addr, local_port); + bool websocket_upgraded = false; auto ret = detail::process_server_socket( svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, @@ -7278,7 +8002,7 @@ bool Server::process_and_close_socket(socket_t sock) { [&](Stream &strm, bool close_connection, bool &connection_closed) { return process_request(strm, remote_addr, remote_port, local_addr, local_port, close_connection, connection_closed, - nullptr); + nullptr, &websocket_upgraded); }); detail::shutdown_socket(sock); @@ -9019,8 +9743,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Post(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Post(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Post(const std::string &path, const Headers 
&headers, @@ -9033,8 +9759,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Post(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Post(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -9212,8 +9940,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Put(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Put(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -9226,8 +9956,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Put(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Put(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -9407,8 +10139,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const 
auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Patch(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Patch(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -9421,8 +10155,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Patch(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Patch(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -10579,9 +11315,9 @@ bool SSLServer::process_and_close_socket(socket_t sock) { // Use scope_exit to ensure cleanup on all paths (including exceptions) bool handshake_done = false; bool ret = false; + bool websocket_upgraded = false; auto cleanup = detail::scope_exit([&] { - // Shutdown gracefully if handshake succeeded and processing was successful - if (handshake_done) { shutdown(session, ret); } + if (handshake_done) { shutdown(session, !websocket_upgraded && ret); } free_session(session); detail::shutdown_socket(sock); detail::close_socket(sock); @@ -10621,9 +11357,10 @@ bool SSLServer::process_and_close_socket(socket_t sock) { read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, [&](Stream &strm, bool close_connection, bool &connection_closed) { - return 
process_request(strm, remote_addr, remote_port, local_addr, - local_port, close_connection, connection_closed, - [&](Request &req) { req.ssl = session; }); + return process_request( + strm, remote_addr, remote_port, local_addr, local_port, + close_connection, connection_closed, + [&](Request &req) { req.ssl = session; }, &websocket_upgraded); }); return ret; @@ -10929,11 +11666,11 @@ bool SSLClient::initialize_ssl(Socket &socket, Error &error) { bool is_ip = detail::is_ip_address(host_); -#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT - // MbedTLS needs explicit verification mode (OpenSSL uses SSL_VERIFY_NONE - // by default and performs all verification post-handshake). +#if defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || defined(CPPHTTPLIB_WOLFSSL_SUPPORT) + // MbedTLS/wolfSSL need explicit verification mode (OpenSSL uses + // SSL_VERIFY_NONE by default and performs all verification post-handshake). // For IP addresses with verification enabled, use OPTIONAL mode since - // MbedTLS requires hostname for VERIFY_REQUIRED. + // these backends require hostname for strict verification. 
if (is_ip && server_certificate_verification_) { set_verify_client(ctx_, false); } else { @@ -11154,6 +11891,107 @@ VerifyCallback &get_mbedtls_verify_callback() { return callback; } +// Check if a string is an IPv4 address +bool is_ipv4_address(const std::string &str) { + int dots = 0; + for (char c : str) { + if (c == '.') { + dots++; + } else if (!isdigit(static_cast(c))) { + return false; + } + } + return dots == 3; +} + +// Parse IPv4 address string to bytes +bool parse_ipv4(const std::string &str, unsigned char *out) { + int parts[4]; + if (sscanf(str.c_str(), "%d.%d.%d.%d", &parts[0], &parts[1], &parts[2], + &parts[3]) != 4) { + return false; + } + for (int i = 0; i < 4; i++) { + if (parts[i] < 0 || parts[i] > 255) return false; + out[i] = static_cast(parts[i]); + } + return true; +} + +#ifdef _WIN32 +// Enumerate Windows system certificates and call callback with DER data +template +bool enumerate_windows_system_certs(Callback cb) { + bool loaded = false; + static const wchar_t *store_names[] = {L"ROOT", L"CA"}; + for (auto store_name : store_names) { + HCERTSTORE hStore = CertOpenSystemStoreW(0, store_name); + if (hStore) { + PCCERT_CONTEXT pContext = nullptr; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + if (cb(pContext->pbCertEncoded, pContext->cbCertEncoded)) { + loaded = true; + } + } + CertCloseStore(hStore, 0); + } + } + return loaded; +} +#endif + +#if defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) +// Enumerate macOS Keychain certificates and call callback with DER data +template +bool enumerate_macos_keychain_certs(Callback cb) { + bool loaded = false; + CFArrayRef certs = nullptr; + OSStatus status = SecTrustCopyAnchorCertificates(&certs); + if (status == errSecSuccess && certs) { + CFIndex count = CFArrayGetCount(certs); + for (CFIndex i = 0; i < count; i++) { + SecCertificateRef cert = + (SecCertificateRef)CFArrayGetValueAtIndex(certs, i); + CFDataRef data = 
SecCertificateCopyData(cert); + if (data) { + if (cb(CFDataGetBytePtr(data), + static_cast(CFDataGetLength(data)))) { + loaded = true; + } + CFRelease(data); + } + } + CFRelease(certs); + } + return loaded; +} +#endif + +#if !defined(_WIN32) && !(defined(__APPLE__) && \ + defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)) +// Common CA certificate file paths on Linux/Unix +const char **system_ca_paths() { + static const char *paths[] = { + "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu + "/etc/pki/tls/certs/ca-bundle.crt", // RHEL/CentOS + "/etc/ssl/ca-bundle.pem", // OpenSUSE + "/etc/pki/tls/cacert.pem", // OpenELEC + "/etc/ssl/cert.pem", // Alpine, FreeBSD + nullptr}; + return paths; +} + +// Common CA certificate directory paths on Linux/Unix +const char **system_ca_dirs() { + static const char *dirs[] = {"/etc/ssl/certs", // Debian/Ubuntu + "/etc/pki/tls/certs", // RHEL/CentOS + "/usr/share/ca-certificates", // Other + nullptr}; + return dirs; +} +#endif + } // namespace impl bool set_client_ca_file(ctx_t ctx, const char *ca_file, @@ -12730,33 +13568,6 @@ int mbedtls_sni_callback(void *p_ctx, mbedtls_ssl_context *ssl, int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt, int cert_depth, uint32_t *flags); -// Check if a string is an IPv4 address -bool is_ipv4_address(const std::string &str) { - int dots = 0; - for (char c : str) { - if (c == '.') { - dots++; - } else if (!isdigit(static_cast(c))) { - return false; - } - } - return dots == 3; -} - -// Parse IPv4 address string to bytes -bool parse_ipv4(const std::string &str, unsigned char *out) { - int parts[4]; - if (sscanf(str.c_str(), "%d.%d.%d.%d", &parts[0], &parts[1], &parts[2], - &parts[3]) != 4) { - return false; - } - for (int i = 0; i < 4; i++) { - if (parts[i] < 0 || parts[i] > 255) return false; - out[i] = static_cast(parts[i]); - } - return true; -} - // MbedTLS verify callback wrapper int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt, int cert_depth, uint32_t *flags) { 
@@ -12971,68 +13782,26 @@ bool load_system_certs(ctx_t ctx) { bool loaded = false; #ifdef _WIN32 - // Load from Windows certificate store (ROOT and CA) - static const wchar_t *store_names[] = {L"ROOT", L"CA"}; - for (auto store_name : store_names) { - HCERTSTORE hStore = CertOpenSystemStoreW(0, store_name); - if (hStore) { - PCCERT_CONTEXT pContext = nullptr; - while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != - nullptr) { - int ret = mbedtls_x509_crt_parse_der( - &mctx->ca_chain, pContext->pbCertEncoded, pContext->cbCertEncoded); - if (ret == 0) { loaded = true; } - } - CertCloseStore(hStore, 0); - } - } + loaded = impl::enumerate_windows_system_certs( + [&](const unsigned char *data, size_t len) { + return mbedtls_x509_crt_parse_der(&mctx->ca_chain, data, len) == 0; + }); #elif defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) - // Load from macOS Keychain - CFArrayRef certs = nullptr; - OSStatus status = SecTrustCopyAnchorCertificates(&certs); - if (status == errSecSuccess && certs) { - CFIndex count = CFArrayGetCount(certs); - for (CFIndex i = 0; i < count; i++) { - SecCertificateRef cert = - (SecCertificateRef)CFArrayGetValueAtIndex(certs, i); - CFDataRef data = SecCertificateCopyData(cert); - if (data) { - int ret = mbedtls_x509_crt_parse_der( - &mctx->ca_chain, CFDataGetBytePtr(data), - static_cast(CFDataGetLength(data))); - if (ret == 0) { loaded = true; } - CFRelease(data); - } - } - CFRelease(certs); - } + loaded = impl::enumerate_macos_keychain_certs( + [&](const unsigned char *data, size_t len) { + return mbedtls_x509_crt_parse_der(&mctx->ca_chain, data, len) == 0; + }); #else - // Try common CA certificate locations on Linux/Unix - static const char *ca_paths[] = { - "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu - "/etc/pki/tls/certs/ca-bundle.crt", // RHEL/CentOS - "/etc/ssl/ca-bundle.pem", // OpenSUSE - "/etc/pki/tls/cacert.pem", // OpenELEC - "/etc/ssl/cert.pem", // Alpine, FreeBSD - nullptr}; - - 
for (const char **path = ca_paths; *path; ++path) { - int ret = mbedtls_x509_crt_parse_file(&mctx->ca_chain, *path); - if (ret >= 0) { + for (auto path = impl::system_ca_paths(); *path; ++path) { + if (mbedtls_x509_crt_parse_file(&mctx->ca_chain, *path) >= 0) { loaded = true; break; } } - // Also try the CA directory if (!loaded) { - static const char *ca_dirs[] = {"/etc/ssl/certs", // Debian/Ubuntu - "/etc/pki/tls/certs", // RHEL/CentOS - "/usr/share/ca-certificates", nullptr}; - - for (const char **dir = ca_dirs; *dir; ++dir) { - int ret = mbedtls_x509_crt_parse_path(&mctx->ca_chain, *dir); - if (ret >= 0) { + for (auto dir = impl::system_ca_dirs(); *dir; ++dir) { + if (mbedtls_x509_crt_parse_path(&mctx->ca_chain, *dir) >= 0) { loaded = true; break; } @@ -13083,6 +13852,18 @@ bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, return false; } + // Verify that the certificate and private key match +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key, + mbedtls_ctr_drbg_random, &mctx->ctr_drbg); +#else + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key); if (ret != 0) { impl::mbedtls_last_error() = ret; @@ -13116,6 +13897,18 @@ bool set_client_cert_file(ctx_t ctx, const char *cert_path, return false; } + // Verify that the certificate and private key match +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key, + mbedtls_ctr_drbg_random, &mctx->ctr_drbg); +#else + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key); if (ret != 0) { impl::mbedtls_last_error() = ret; @@ -13877,4 +14670,1477 @@ std::string 
verify_error_string(long error_code) { #endif // CPPHTTPLIB_MBEDTLS_SUPPORT +/* + * Group 10: TLS abstraction layer - wolfSSL backend + */ + +/* + * wolfSSL Backend Implementation + */ + +#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT +namespace tls { + +namespace impl { + +// wolfSSL session wrapper +struct WolfSSLSession { + WOLFSSL *ssl = nullptr; + socket_t sock = INVALID_SOCKET; + std::string hostname; // For client: set via set_sni + std::string sni_hostname; // For server: received from client via SNI callback + + WolfSSLSession() = default; + + ~WolfSSLSession() { + if (ssl) { wolfSSL_free(ssl); } + } + + WolfSSLSession(const WolfSSLSession &) = delete; + WolfSSLSession &operator=(const WolfSSLSession &) = delete; +}; + +// Thread-local error code accessor for wolfSSL +uint64_t &wolfssl_last_error() { + static thread_local uint64_t err = 0; + return err; +} + +// Helper to map wolfSSL error to ErrorCode. +// ssl_error is the value from wolfSSL_get_error(). +// raw_ret is the raw return value from the wolfSSL call (for low-level error). +ErrorCode map_wolfssl_error(WOLFSSL *ssl, int ssl_error, + int &out_errno) { + switch (ssl_error) { + case SSL_ERROR_NONE: return ErrorCode::Success; + case SSL_ERROR_WANT_READ: return ErrorCode::WantRead; + case SSL_ERROR_WANT_WRITE: return ErrorCode::WantWrite; + case SSL_ERROR_ZERO_RETURN: return ErrorCode::PeerClosed; + case SSL_ERROR_SYSCALL: out_errno = errno; return ErrorCode::SyscallError; + default: + if (ssl) { + // wolfSSL stores the low-level error code as a negative value. + // DOMAIN_NAME_MISMATCH (-322) indicates hostname verification failure. + int low_err = ssl_error; // wolfSSL_get_error returns the low-level code + if (low_err == DOMAIN_NAME_MISMATCH) { + return ErrorCode::HostnameMismatch; + } + // Check verify result to distinguish cert verification from generic SSL + // errors. 
+ long vr = wolfSSL_get_verify_result(ssl); + if (vr != 0) { return ErrorCode::CertVerifyFailed; } + } + return ErrorCode::Fatal; + } +} + +// WolfSSLContext constructor/destructor implementations +WolfSSLContext::WolfSSLContext() { wolfSSL_Init(); } + +WolfSSLContext::~WolfSSLContext() { + if (ctx) { wolfSSL_CTX_free(ctx); } +} + +// Thread-local storage for SNI captured during handshake +std::string &wolfssl_pending_sni() { + static thread_local std::string sni; + return sni; +} + +// SNI callback for wolfSSL server to capture client's SNI hostname +int wolfssl_sni_callback(WOLFSSL *ssl, int *ret, void *exArg) { + (void)ret; + (void)exArg; + + void *name_data = nullptr; + unsigned short name_len = + wolfSSL_SNI_GetRequest(ssl, WOLFSSL_SNI_HOST_NAME, &name_data); + + if (name_data && name_len > 0) { + wolfssl_pending_sni().assign(static_cast(name_data), + name_len); + } else { + wolfssl_pending_sni().clear(); + } + return 0; // Continue regardless +} + +// wolfSSL verify callback wrapper +int wolfssl_verify_callback(int preverify_ok, + WOLFSSL_X509_STORE_CTX *x509_ctx) { + auto &callback = get_verify_callback(); + if (!callback) { return preverify_ok; } + + WOLFSSL_X509 *cert = wolfSSL_X509_STORE_CTX_get_current_cert(x509_ctx); + int depth = wolfSSL_X509_STORE_CTX_get_error_depth(x509_ctx); + int err = wolfSSL_X509_STORE_CTX_get_error(x509_ctx); + + // Get the WOLFSSL object from the X509_STORE_CTX + WOLFSSL *ssl = static_cast(wolfSSL_X509_STORE_CTX_get_ex_data( + x509_ctx, wolfSSL_get_ex_data_X509_STORE_CTX_idx())); + + VerifyContext verify_ctx; + verify_ctx.session = static_cast(ssl); + verify_ctx.cert = static_cast(cert); + verify_ctx.depth = depth; + verify_ctx.preverify_ok = (preverify_ok != 0); + verify_ctx.error_code = static_cast(err); + + if (err != 0) { + verify_ctx.error_string = wolfSSL_X509_verify_cert_error_string(err); + } else { + verify_ctx.error_string = nullptr; + } + + bool accepted = callback(verify_ctx); + return accepted ? 
1 : 0; +} + +void set_wolfssl_password_cb(WOLFSSL_CTX *ctx, const char *password) { + wolfSSL_CTX_set_default_passwd_cb_userdata(ctx, const_cast(password)); + wolfSSL_CTX_set_default_passwd_cb( + ctx, [](char *buf, int size, int /*rwflag*/, void *userdata) -> int { + auto *pwd = static_cast(userdata); + if (!pwd) return 0; + auto len = static_cast(strlen(pwd)); + if (len > size) len = size; + memcpy(buf, pwd, static_cast(len)); + return len; + }); +} + +} // namespace impl + +ctx_t create_client_context() { + auto ctx = new (std::nothrow) impl::WolfSSLContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = false; + + WOLFSSL_METHOD *method = wolfTLSv1_2_client_method(); + if (!method) { + delete ctx; + return nullptr; + } + + ctx->ctx = wolfSSL_CTX_new(method); + if (!ctx->ctx) { + delete ctx; + return nullptr; + } + + // Default: verify peer certificate + wolfSSL_CTX_set_verify(ctx->ctx, SSL_VERIFY_PEER, nullptr); + + return static_cast(ctx); +} + +ctx_t create_server_context() { + auto ctx = new (std::nothrow) impl::WolfSSLContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = true; + + WOLFSSL_METHOD *method = wolfTLSv1_2_server_method(); + if (!method) { + delete ctx; + return nullptr; + } + + ctx->ctx = wolfSSL_CTX_new(method); + if (!ctx->ctx) { + delete ctx; + return nullptr; + } + + // Default: don't verify client + wolfSSL_CTX_set_verify(ctx->ctx, SSL_VERIFY_NONE, nullptr); + + // Enable SNI on server + wolfSSL_CTX_SNI_SetOptions(ctx->ctx, WOLFSSL_SNI_HOST_NAME, + WOLFSSL_SNI_CONTINUE_ON_MISMATCH); + wolfSSL_CTX_set_servername_callback(ctx->ctx, impl::wolfssl_sni_callback); + + return static_cast(ctx); +} + +void free_context(ctx_t ctx) { + if (ctx) { delete static_cast(ctx); } +} + +bool set_min_version(ctx_t ctx, Version version) { + if (!ctx) { return false; } + auto wctx = static_cast(ctx); + + int min_ver = WOLFSSL_TLSV1_2; + if (version >= Version::TLS1_3) { min_ver = WOLFSSL_TLSV1_3; } + + return wolfSSL_CTX_SetMinVersion(wctx->ctx, 
min_ver) == WOLFSSL_SUCCESS; +} + +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len) { + if (!ctx || !pem) { return false; } + auto wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_buffer( + wctx->ctx, reinterpret_cast(pem), + static_cast(len), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + wctx->ca_pem_data_.append(pem, len); + return true; +} + +bool load_ca_file(ctx_t ctx, const char *file_path) { + if (!ctx || !file_path) { return false; } + auto wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_locations(wctx->ctx, file_path, nullptr); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + return true; +} + +bool load_ca_dir(ctx_t ctx, const char *dir_path) { + if (!ctx || !dir_path) { return false; } + auto wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_locations(wctx->ctx, nullptr, dir_path); + // wolfSSL may fail if the directory doesn't contain properly hashed certs. + // Unlike OpenSSL which lazily loads certs from directories, wolfSSL scans + // immediately. Return true even on failure since the CA file may have + // already been loaded, matching OpenSSL's lenient behavior. 
+ (void)ret; + return true; +} + +bool load_system_certs(ctx_t ctx) { + if (!ctx) { return false; } + auto wctx = static_cast(ctx); + bool loaded = false; + +#ifdef _WIN32 + loaded = impl::enumerate_windows_system_certs( + [&](const unsigned char *data, size_t len) { + return wolfSSL_CTX_load_verify_buffer(wctx->ctx, data, + static_cast(len), + SSL_FILETYPE_ASN1) == SSL_SUCCESS; + }); +#elif defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) + loaded = impl::enumerate_macos_keychain_certs( + [&](const unsigned char *data, size_t len) { + return wolfSSL_CTX_load_verify_buffer(wctx->ctx, data, + static_cast(len), + SSL_FILETYPE_ASN1) == SSL_SUCCESS; + }); +#else + for (auto path = impl::system_ca_paths(); *path; ++path) { + if (wolfSSL_CTX_load_verify_locations(wctx->ctx, *path, nullptr) == + SSL_SUCCESS) { + loaded = true; + break; + } + } + + if (!loaded) { + for (auto dir = impl::system_ca_dirs(); *dir; ++dir) { + if (wolfSSL_CTX_load_verify_locations(wctx->ctx, nullptr, *dir) == + SSL_SUCCESS) { + loaded = true; + break; + } + } + } +#endif + + return loaded; +} + +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + if (!ctx || !cert || !key) { return false; } + auto wctx = static_cast(ctx); + + // Load certificate + int ret = wolfSSL_CTX_use_certificate_buffer( + wctx->ctx, reinterpret_cast(cert), + static_cast(strlen(cert)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Set password callback if password is provided + if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); } + + // Load private key + ret = wolfSSL_CTX_use_PrivateKey_buffer( + wctx->ctx, reinterpret_cast(key), + static_cast(strlen(key)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Verify that the certificate 
and private key match + return wolfSSL_CTX_check_private_key(wctx->ctx) == SSL_SUCCESS; +} + +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + if (!ctx || !cert_path || !key_path) { return false; } + auto wctx = static_cast(ctx); + + // Load certificate file + int ret = + wolfSSL_CTX_use_certificate_file(wctx->ctx, cert_path, SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Set password callback if password is provided + if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); } + + // Load private key file + ret = wolfSSL_CTX_use_PrivateKey_file(wctx->ctx, key_path, SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Verify that the certificate and private key match + return wolfSSL_CTX_check_private_key(wctx->ctx) == SSL_SUCCESS; +} + +void set_verify_client(ctx_t ctx, bool require) { + if (!ctx) { return; } + auto wctx = static_cast(ctx); + wctx->verify_client = require; + if (require) { + wolfSSL_CTX_set_verify( + wctx->ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + wctx->has_verify_callback ? 
impl::wolfssl_verify_callback : nullptr); + } else { + if (wctx->has_verify_callback) { + wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_PEER, + impl::wolfssl_verify_callback); + } else { + wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_NONE, nullptr); + } + } +} + +session_t create_session(ctx_t ctx, socket_t sock) { + if (!ctx || sock == INVALID_SOCKET) { return nullptr; } + auto wctx = static_cast(ctx); + + auto session = new (std::nothrow) impl::WolfSSLSession(); + if (!session) { return nullptr; } + + session->sock = sock; + session->ssl = wolfSSL_new(wctx->ctx); + if (!session->ssl) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + delete session; + return nullptr; + } + + wolfSSL_set_fd(session->ssl, static_cast(sock)); + + return static_cast(session); +} + +void free_session(session_t session) { + if (session) { delete static_cast(session); } +} + +bool set_sni(session_t session, const char *hostname) { + if (!session || !hostname) { return false; } + auto wsession = static_cast(session); + + int ret = wolfSSL_UseSNI(wsession->ssl, WOLFSSL_SNI_HOST_NAME, hostname, + static_cast(strlen(hostname))); + if (ret != WOLFSSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Also set hostname for verification + wolfSSL_check_domain_name(wsession->ssl, hostname); + + wsession->hostname = hostname; + return true; +} + +bool set_hostname(session_t session, const char *hostname) { + // In wolfSSL, set_hostname also sets up hostname verification + return set_sni(session, hostname); +} + +TlsError connect(session_t session) { + TlsError err; + if (!session) { + err.code = ErrorCode::Fatal; + return err; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_connect(wsession->ssl); + + if (ret == SSL_SUCCESS) { + err.code = ErrorCode::Success; + } else { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, 
ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + } + + return err; +} + +TlsError accept(session_t session) { + TlsError err; + if (!session) { + err.code = ErrorCode::Fatal; + return err; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_accept(wsession->ssl); + + if (ret == SSL_SUCCESS) { + err.code = ErrorCode::Success; + // Capture SNI from thread-local storage after successful handshake + wsession->sni_hostname = std::move(impl::wolfssl_pending_sni()); + impl::wolfssl_pending_sni().clear(); + } else { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + } + + return err; +} + +bool connect_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto wsession = static_cast(session); + + // Set socket to non-blocking mode + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + int ret; + while ((ret = wolfSSL_connect(wsession->ssl)) != SSL_SUCCESS) { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error == SSL_ERROR_WANT_READ) { + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } else if (ssl_error == SSL_ERROR_WANT_WRITE) { + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } + + // Error or timeout + if (err) { + err->code = + impl::map_wolfssl_error(wsession->ssl, ssl_error, err->sys_errno); + err->backend_code = static_cast(ssl_error); + } + impl::wolfssl_last_error() = static_cast(ssl_error); + return false; + } + + if (err) { err->code = ErrorCode::Success; } + return true; +} + +bool 
accept_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto wsession = static_cast(session); + + // Set socket to non-blocking mode + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + int ret; + while ((ret = wolfSSL_accept(wsession->ssl)) != SSL_SUCCESS) { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error == SSL_ERROR_WANT_READ) { + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } else if (ssl_error == SSL_ERROR_WANT_WRITE) { + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } + + // Error or timeout + if (err) { + err->code = + impl::map_wolfssl_error(wsession->ssl, ssl_error, err->sys_errno); + err->backend_code = static_cast(ssl_error); + } + impl::wolfssl_last_error() = static_cast(ssl_error); + return false; + } + + if (err) { err->code = ErrorCode::Success; } + + // Capture SNI from thread-local storage after successful handshake + wsession->sni_hostname = std::move(impl::wolfssl_pending_sni()); + impl::wolfssl_pending_sni().clear(); + + return true; +} + +ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_read(wsession->ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return 0; + } + + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + return -1; +} + +ssize_t write(session_t session, const void 
*buf, size_t len, + TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_write(wsession->ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + // wolfSSL_write returns 0 when the peer has sent a close_notify. + // Treat this as an error (return -1) so callers don't spin in a + // write loop adding zero to the offset. + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return -1; + } + + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + return -1; +} + +int pending(const_session_t session) { + if (!session) { return 0; } + auto wsession = + static_cast(const_cast(session)); + return wolfSSL_pending(wsession->ssl); +} + +void shutdown(session_t session, bool graceful) { + if (!session) { return; } + auto wsession = static_cast(session); + + if (graceful) { + int ret; + int attempts = 0; + while ((ret = wolfSSL_shutdown(wsession->ssl)) != SSL_SUCCESS && + attempts < 3) { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error != SSL_ERROR_WANT_READ && + ssl_error != SSL_ERROR_WANT_WRITE) { + break; + } + attempts++; + } + } else { + wolfSSL_shutdown(wsession->ssl); + } +} + +bool is_peer_closed(session_t session, socket_t sock) { + if (!session || sock == INVALID_SOCKET) { return true; } + auto wsession = static_cast(session); + + // Check if there's already decrypted data available + if (wolfSSL_pending(wsession->ssl) > 0) { return false; } + + // Set socket to non-blocking to avoid blocking on read + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + // Peek 1 byte to check connection status without consuming data + unsigned char 
buf; + int ret = wolfSSL_peek(wsession->ssl, &buf, 1); + + // If we got data or WANT_READ (would block), connection is alive + if (ret > 0) { return false; } + + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error == SSL_ERROR_WANT_READ) { return false; } + + return ssl_error == SSL_ERROR_ZERO_RETURN || ssl_error == SSL_ERROR_SYSCALL || + ret == 0; +} + +cert_t get_peer_cert(const_session_t session) { + if (!session) { return nullptr; } + auto wsession = + static_cast(const_cast(session)); + + WOLFSSL_X509 *cert = wolfSSL_get_peer_certificate(wsession->ssl); + return static_cast(cert); +} + +void free_cert(cert_t cert) { + if (cert) { wolfSSL_X509_free(static_cast(cert)); } +} + +bool verify_hostname(cert_t cert, const char *hostname) { + if (!cert || !hostname) { return false; } + auto x509 = static_cast(cert); + std::string host_str(hostname); + + // Check if hostname is an IP address + bool is_ip = impl::is_ipv4_address(host_str); + unsigned char ip_bytes[4]; + if (is_ip) { impl::parse_ipv4(host_str, ip_bytes); } + + // Check Subject Alternative Names + auto *san_names = static_cast( + wolfSSL_X509_get_ext_d2i(x509, NID_subject_alt_name, nullptr, nullptr)); + + if (san_names) { + int san_count = wolfSSL_sk_num(san_names); + for (int i = 0; i < san_count; i++) { + auto *names = + static_cast(wolfSSL_sk_value(san_names, i)); + if (!names) continue; + + if (!is_ip && names->type == WOLFSSL_GEN_DNS) { + // DNS name + unsigned char *dns_name = nullptr; + int dns_len = wolfSSL_ASN1_STRING_to_UTF8(&dns_name, names->d.dNSName); + if (dns_name && dns_len > 0) { + std::string san_name(reinterpret_cast(dns_name), + static_cast(dns_len)); + XFREE(dns_name, nullptr, DYNAMIC_TYPE_OPENSSL); + if (detail::match_hostname(san_name, host_str)) { + wolfSSL_sk_free(san_names); + return true; + } + } + } else if (is_ip && names->type == WOLFSSL_GEN_IPADD) { + // IP address + unsigned char *ip_data = wolfSSL_ASN1_STRING_data(names->d.iPAddress); + int ip_len = 
wolfSSL_ASN1_STRING_length(names->d.iPAddress); + if (ip_data && ip_len == 4 && memcmp(ip_data, ip_bytes, 4) == 0) { + wolfSSL_sk_free(san_names); + return true; + } + } + } + wolfSSL_sk_free(san_names); + } + + // Fallback: Check Common Name (CN) in subject + WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509); + if (subject) { + char cn[256] = {}; + int cn_len = wolfSSL_X509_NAME_get_text_by_NID(subject, NID_commonName, cn, + sizeof(cn)); + if (cn_len > 0) { + std::string cn_str(cn, static_cast(cn_len)); + if (detail::match_hostname(cn_str, host_str)) { return true; } + } + } + + return false; +} + +uint64_t hostname_mismatch_code() { + return static_cast(DOMAIN_NAME_MISMATCH); +} + +long get_verify_result(const_session_t session) { + if (!session) { return -1; } + auto wsession = + static_cast(const_cast(session)); + long result = wolfSSL_get_verify_result(wsession->ssl); + return result; +} + +std::string get_cert_subject_cn(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509); + if (!subject) return ""; + + char cn[256] = {}; + int cn_len = wolfSSL_X509_NAME_get_text_by_NID(subject, NID_commonName, cn, + sizeof(cn)); + if (cn_len <= 0) return ""; + return std::string(cn, static_cast(cn_len)); +} + +std::string get_cert_issuer_name(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + WOLFSSL_X509_NAME *issuer = wolfSSL_X509_get_issuer_name(x509); + if (!issuer) return ""; + + char *name_str = wolfSSL_X509_NAME_oneline(issuer, nullptr, 0); + if (!name_str) return ""; + + std::string result(name_str); + XFREE(name_str, nullptr, DYNAMIC_TYPE_OPENSSL); + return result; +} + +bool get_cert_sans(cert_t cert, std::vector &sans) { + sans.clear(); + if (!cert) return false; + auto x509 = static_cast(cert); + + auto *san_names = static_cast( + wolfSSL_X509_get_ext_d2i(x509, NID_subject_alt_name, nullptr, nullptr)); + if (!san_names) return true; // 
No SANs is not an error + + int count = wolfSSL_sk_num(san_names); + for (int i = 0; i < count; i++) { + auto *name = + static_cast(wolfSSL_sk_value(san_names, i)); + if (!name) continue; + + SanEntry entry; + switch (name->type) { + case WOLFSSL_GEN_DNS: { + entry.type = SanType::DNS; + unsigned char *dns_name = nullptr; + int dns_len = wolfSSL_ASN1_STRING_to_UTF8(&dns_name, name->d.dNSName); + if (dns_name && dns_len > 0) { + entry.value = std::string(reinterpret_cast(dns_name), + static_cast(dns_len)); + XFREE(dns_name, nullptr, DYNAMIC_TYPE_OPENSSL); + } + break; + } + case WOLFSSL_GEN_IPADD: { + entry.type = SanType::IP; + unsigned char *ip_data = wolfSSL_ASN1_STRING_data(name->d.iPAddress); + int ip_len = wolfSSL_ASN1_STRING_length(name->d.iPAddress); + if (ip_data && ip_len == 4) { + char buf[16]; + snprintf(buf, sizeof(buf), "%d.%d.%d.%d", ip_data[0], ip_data[1], + ip_data[2], ip_data[3]); + entry.value = buf; + } else if (ip_data && ip_len == 16) { + char buf[64]; + snprintf(buf, sizeof(buf), + "%02x%02x:%02x%02x:%02x%02x:%02x%02x:" + "%02x%02x:%02x%02x:%02x%02x:%02x%02x", + ip_data[0], ip_data[1], ip_data[2], ip_data[3], ip_data[4], + ip_data[5], ip_data[6], ip_data[7], ip_data[8], ip_data[9], + ip_data[10], ip_data[11], ip_data[12], ip_data[13], + ip_data[14], ip_data[15]); + entry.value = buf; + } + break; + } + case WOLFSSL_GEN_EMAIL: + entry.type = SanType::EMAIL; + { + unsigned char *email = nullptr; + int email_len = wolfSSL_ASN1_STRING_to_UTF8(&email, name->d.rfc822Name); + if (email && email_len > 0) { + entry.value = std::string(reinterpret_cast(email), + static_cast(email_len)); + XFREE(email, nullptr, DYNAMIC_TYPE_OPENSSL); + } + } + break; + case WOLFSSL_GEN_URI: + entry.type = SanType::URI; + { + unsigned char *uri = nullptr; + int uri_len = wolfSSL_ASN1_STRING_to_UTF8( + &uri, name->d.uniformResourceIdentifier); + if (uri && uri_len > 0) { + entry.value = std::string(reinterpret_cast(uri), + static_cast(uri_len)); + XFREE(uri, nullptr, 
DYNAMIC_TYPE_OPENSSL); + } + } + break; + default: entry.type = SanType::OTHER; break; + } + + if (!entry.value.empty()) { sans.push_back(std::move(entry)); } + } + wolfSSL_sk_free(san_names); + return true; +} + +bool get_cert_validity(cert_t cert, time_t ¬_before, + time_t ¬_after) { + if (!cert) return false; + auto x509 = static_cast(cert); + + const WOLFSSL_ASN1_TIME *nb = wolfSSL_X509_get_notBefore(x509); + const WOLFSSL_ASN1_TIME *na = wolfSSL_X509_get_notAfter(x509); + + if (!nb || !na) return false; + + // wolfSSL_ASN1_TIME_to_tm is available + struct tm tm_nb = {}, tm_na = {}; + if (wolfSSL_ASN1_TIME_to_tm(nb, &tm_nb) != WOLFSSL_SUCCESS) return false; + if (wolfSSL_ASN1_TIME_to_tm(na, &tm_na) != WOLFSSL_SUCCESS) return false; + +#ifdef _WIN32 + not_before = _mkgmtime(&tm_nb); + not_after = _mkgmtime(&tm_na); +#else + not_before = timegm(&tm_nb); + not_after = timegm(&tm_na); +#endif + return true; +} + +std::string get_cert_serial(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + WOLFSSL_ASN1_INTEGER *serial_asn1 = wolfSSL_X509_get_serialNumber(x509); + if (!serial_asn1) return ""; + + // Get the serial number data + int len = serial_asn1->length; + unsigned char *data = serial_asn1->data; + if (!data || len <= 0) return ""; + + std::string result; + result.reserve(static_cast(len) * 2); + for (int i = 0; i < len; i++) { + char hex[3]; + snprintf(hex, sizeof(hex), "%02X", data[i]); + result += hex; + } + return result; +} + +bool get_cert_der(cert_t cert, std::vector &der) { + if (!cert) return false; + auto x509 = static_cast(cert); + + int der_len = 0; + const unsigned char *der_data = wolfSSL_X509_get_der(x509, &der_len); + if (!der_data || der_len <= 0) return false; + + der.assign(der_data, der_data + der_len); + return true; +} + +const char *get_sni(const_session_t session) { + if (!session) return nullptr; + auto wsession = static_cast(session); + + // For server: return SNI received from client during handshake + if 
(!wsession->sni_hostname.empty()) { + return wsession->sni_hostname.c_str(); + } + + // For client: return the hostname set via set_sni + if (!wsession->hostname.empty()) { return wsession->hostname.c_str(); } + + return nullptr; +} + +uint64_t peek_error() { + return static_cast(wolfSSL_ERR_peek_last_error()); +} + +uint64_t get_error() { + uint64_t err = impl::wolfssl_last_error(); + impl::wolfssl_last_error() = 0; + return err; +} + +std::string error_string(uint64_t code) { + char buf[256]; + wolfSSL_ERR_error_string(static_cast(code), buf); + return std::string(buf); +} + +ca_store_t create_ca_store(const char *pem, size_t len) { + if (!pem || len == 0) { return nullptr; } + // Validate by attempting to load into a temporary ctx + WOLFSSL_CTX *tmp_ctx = wolfSSL_CTX_new(wolfTLSv1_2_client_method()); + if (!tmp_ctx) { return nullptr; } + int ret = wolfSSL_CTX_load_verify_buffer( + tmp_ctx, reinterpret_cast(pem), + static_cast(len), SSL_FILETYPE_PEM); + wolfSSL_CTX_free(tmp_ctx); + if (ret != SSL_SUCCESS) { return nullptr; } + return static_cast( + new impl::WolfSSLCAStore{std::string(pem, len)}); +} + +void free_ca_store(ca_store_t store) { + delete static_cast(store); +} + +bool set_ca_store(ctx_t ctx, ca_store_t store) { + if (!ctx || !store) { return false; } + auto *wctx = static_cast(ctx); + auto *ca = static_cast(store); + int ret = wolfSSL_CTX_load_verify_buffer( + wctx->ctx, reinterpret_cast(ca->pem_data.data()), + static_cast(ca->pem_data.size()), SSL_FILETYPE_PEM); + if (ret == SSL_SUCCESS) { wctx->ca_pem_data_ += ca->pem_data; } + return ret == SSL_SUCCESS; +} + +size_t get_ca_certs(ctx_t ctx, std::vector &certs) { + certs.clear(); + if (!ctx) { return 0; } + auto *wctx = static_cast(ctx); + if (wctx->ca_pem_data_.empty()) { return 0; } + + const std::string &pem = wctx->ca_pem_data_; + const std::string begin_marker = "-----BEGIN CERTIFICATE-----"; + const std::string end_marker = "-----END CERTIFICATE-----"; + size_t pos = 0; + while ((pos = 
pem.find(begin_marker, pos)) != std::string::npos) { + size_t end_pos = pem.find(end_marker, pos); + if (end_pos == std::string::npos) { break; } + end_pos += end_marker.size(); + std::string cert_pem = pem.substr(pos, end_pos - pos); + WOLFSSL_X509 *x509 = wolfSSL_X509_load_certificate_buffer( + reinterpret_cast(cert_pem.data()), + static_cast(cert_pem.size()), WOLFSSL_FILETYPE_PEM); + if (x509) { certs.push_back(static_cast(x509)); } + pos = end_pos; + } + return certs.size(); +} + +std::vector get_ca_names(ctx_t ctx) { + std::vector names; + if (!ctx) { return names; } + auto *wctx = static_cast(ctx); + if (wctx->ca_pem_data_.empty()) { return names; } + + const std::string &pem = wctx->ca_pem_data_; + const std::string begin_marker = "-----BEGIN CERTIFICATE-----"; + const std::string end_marker = "-----END CERTIFICATE-----"; + size_t pos = 0; + while ((pos = pem.find(begin_marker, pos)) != std::string::npos) { + size_t end_pos = pem.find(end_marker, pos); + if (end_pos == std::string::npos) { break; } + end_pos += end_marker.size(); + std::string cert_pem = pem.substr(pos, end_pos - pos); + WOLFSSL_X509 *x509 = wolfSSL_X509_load_certificate_buffer( + reinterpret_cast(cert_pem.data()), + static_cast(cert_pem.size()), WOLFSSL_FILETYPE_PEM); + if (x509) { + WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509); + if (subject) { + char *name_str = wolfSSL_X509_NAME_oneline(subject, nullptr, 0); + if (name_str) { + names.push_back(name_str); + XFREE(name_str, nullptr, DYNAMIC_TYPE_OPENSSL); + } + } + wolfSSL_X509_free(x509); + } + pos = end_pos; + } + return names; +} + +bool update_server_cert(ctx_t ctx, const char *cert_pem, + const char *key_pem, const char *password) { + if (!ctx || !cert_pem || !key_pem) { return false; } + auto *wctx = static_cast(ctx); + + // Load new certificate + int ret = wolfSSL_CTX_use_certificate_buffer( + wctx->ctx, reinterpret_cast(cert_pem), + static_cast(strlen(cert_pem)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + 
impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Set password if provided + if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); } + + // Load new private key + ret = wolfSSL_CTX_use_PrivateKey_buffer( + wctx->ctx, reinterpret_cast(key_pem), + static_cast(strlen(key_pem)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + return true; +} + +bool update_server_client_ca(ctx_t ctx, const char *ca_pem) { + if (!ctx || !ca_pem) { return false; } + auto *wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_buffer( + wctx->ctx, reinterpret_cast(ca_pem), + static_cast(strlen(ca_pem)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + return true; +} + +bool set_verify_callback(ctx_t ctx, VerifyCallback callback) { + if (!ctx) { return false; } + auto *wctx = static_cast(ctx); + + impl::get_verify_callback() = std::move(callback); + wctx->has_verify_callback = static_cast(impl::get_verify_callback()); + + if (wctx->has_verify_callback) { + wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_PEER, + impl::wolfssl_verify_callback); + } else { + wolfSSL_CTX_set_verify( + wctx->ctx, + wctx->verify_client + ? (SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT) + : SSL_VERIFY_NONE, + nullptr); + } + return true; +} + +long get_verify_error(const_session_t session) { + if (!session) { return -1; } + auto *wsession = + static_cast(const_cast(session)); + return wolfSSL_get_verify_result(wsession->ssl); +} + +std::string verify_error_string(long error_code) { + if (error_code == 0) { return ""; } + const char *str = + wolfSSL_X509_verify_cert_error_string(static_cast(error_code)); + return str ? 
std::string(str) : std::string(); +} + +} // namespace tls + +#endif // CPPHTTPLIB_WOLFSSL_SUPPORT + +// WebSocket implementation +namespace ws { + +bool WebSocket::send_frame(Opcode op, const char *data, size_t len, + bool fin) { + std::lock_guard lock(write_mutex_); + if (closed_) { return false; } + return detail::write_websocket_frame(strm_, op, data, len, fin, !is_server_); +} + +ReadResult WebSocket::read(std::string &msg) { + while (!closed_) { + Opcode opcode; + std::string payload; + bool fin; + + if (!impl::read_websocket_frame(strm_, opcode, payload, fin, is_server_, + CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH)) { + closed_ = true; + return Fail; + } + + switch (opcode) { + case Opcode::Ping: { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame(strm_, Opcode::Pong, payload.data(), + payload.size(), true, !is_server_); + continue; + } + case Opcode::Pong: continue; + case Opcode::Close: { + if (!closed_.exchange(true)) { + // Echo close frame back + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame(strm_, Opcode::Close, payload.data(), + payload.size(), true, !is_server_); + } + return Fail; + } + case Opcode::Text: + case Opcode::Binary: { + auto result = opcode == Opcode::Text ? 
Text : Binary; + msg = std::move(payload); + + // Handle fragmentation + if (!fin) { + while (true) { + Opcode cont_opcode; + std::string cont_payload; + bool cont_fin; + if (!impl::read_websocket_frame( + strm_, cont_opcode, cont_payload, cont_fin, is_server_, + CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH)) { + closed_ = true; + return Fail; + } + if (cont_opcode == Opcode::Ping) { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame( + strm_, Opcode::Pong, cont_payload.data(), cont_payload.size(), + true, !is_server_); + continue; + } + if (cont_opcode == Opcode::Pong) { continue; } + if (cont_opcode == Opcode::Close) { + if (!closed_.exchange(true)) { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame( + strm_, Opcode::Close, cont_payload.data(), + cont_payload.size(), true, !is_server_); + } + return Fail; + } + // RFC 6455: continuation frames must use opcode 0x0 + if (cont_opcode != Opcode::Continuation) { + closed_ = true; + return Fail; + } + msg += cont_payload; + if (msg.size() > CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH) { + closed_ = true; + return Fail; + } + if (cont_fin) { break; } + } + } + // RFC 6455 Section 5.6: text frames must contain valid UTF-8 + if (result == Text && !impl::is_valid_utf8(msg)) { + close(CloseStatus::InvalidPayload, "invalid UTF-8"); + return Fail; + } + return result; + } + default: closed_ = true; return Fail; + } + } + return Fail; +} + +bool WebSocket::send(const std::string &data) { + return send_frame(Opcode::Text, data.data(), data.size()); +} + +bool WebSocket::send(const char *data, size_t len) { + return send_frame(Opcode::Binary, data, len); +} + +void WebSocket::close(CloseStatus status, const std::string &reason) { + if (closed_.exchange(true)) { return; } + ping_cv_.notify_all(); + std::string payload; + auto code = static_cast(status); + payload.push_back(static_cast((code >> 8) & 0xFF)); + payload.push_back(static_cast(code & 0xFF)); + // RFC 6455 Section 5.5: control frame 
payload must not exceed 125 bytes + // Close frame has 2-byte status code, so reason is limited to 123 bytes + payload += reason.substr(0, 123); + { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame(strm_, Opcode::Close, payload.data(), + payload.size(), true, !is_server_); + } + + // RFC 6455 Section 7.1.1: after sending a Close frame, wait for the peer's + // Close response before closing the TCP connection. Use a short timeout to + // avoid hanging if the peer doesn't respond. + strm_.set_read_timeout(CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND, 0); + Opcode op; + std::string resp; + bool fin; + while (impl::read_websocket_frame(strm_, op, resp, fin, is_server_, 125)) { + if (op == Opcode::Close) { break; } + } +} + +WebSocket::~WebSocket() { + { + std::lock_guard lock(ping_mutex_); + closed_ = true; + } + ping_cv_.notify_all(); + if (ping_thread_.joinable()) { ping_thread_.join(); } +} + +void WebSocket::start_heartbeat() { + ping_thread_ = std::thread([this]() { + std::unique_lock lock(ping_mutex_); + while (!closed_) { + ping_cv_.wait_for(lock, std::chrono::seconds( + CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND)); + if (closed_) { break; } + lock.unlock(); + if (!send_frame(Opcode::Ping, nullptr, 0)) { + closed_ = true; + break; + } + lock.lock(); + } + }); +} + +const Request &WebSocket::request() const { return req_; } + +bool WebSocket::is_open() const { return !closed_; } + +// WebSocketClient implementation +WebSocketClient::WebSocketClient( + const std::string &scheme_host_port_path, const Headers &headers) + : headers_(headers) { + const static std::regex re( + R"(([a-z]+):\/\/(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?(\/.*))"); + + std::smatch m; + if (std::regex_match(scheme_host_port_path, m, re)) { + auto scheme = m[1].str(); + +#ifdef CPPHTTPLIB_SSL_ENABLED + if (scheme != "ws" && scheme != "wss") { +#else + if (scheme != "ws") { +#endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "'" + scheme + "' scheme is not 
supported."; + throw std::invalid_argument(msg); +#endif + return; + } + + auto is_ssl = scheme == "wss"; + + host_ = m[2].str(); + if (host_.empty()) { host_ = m[3].str(); } + + auto port_str = m[4].str(); + port_ = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + + path_ = m[5].str(); + +#ifdef CPPHTTPLIB_SSL_ENABLED + is_ssl_ = is_ssl; +#else + if (is_ssl) { return; } +#endif + + is_valid_ = true; + } +} + +WebSocketClient::~WebSocketClient() { shutdown_and_close(); } + +bool WebSocketClient::is_valid() const { return is_valid_; } + +void WebSocketClient::shutdown_and_close() { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl_) { + if (tls_session_) { + tls::shutdown(tls_session_, true); + tls::free_session(tls_session_); + tls_session_ = nullptr; + } + if (tls_ctx_) { + tls::free_context(tls_ctx_); + tls_ctx_ = nullptr; + } + } +#endif + if (ws_ && ws_->is_open()) { ws_->close(); } + ws_.reset(); + if (sock_ != INVALID_SOCKET) { + detail::shutdown_socket(sock_); + detail::close_socket(sock_); + sock_ = INVALID_SOCKET; + } +} + +bool WebSocketClient::create_stream(std::unique_ptr &strm) { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl_) { + if (!detail::setup_client_tls_session( + host_, tls_ctx_, tls_session_, sock_, + server_certificate_verification_, ca_cert_file_path_, + ca_cert_store_, read_timeout_sec_, read_timeout_usec_)) { + return false; + } + + strm = std::unique_ptr(new detail::SSLSocketStream( + sock_, tls_session_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_)); + return true; + } +#endif + strm = std::unique_ptr( + new detail::SocketStream(sock_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_)); + return true; +} + +bool WebSocketClient::connect() { + if (!is_valid_) { return false; } + shutdown_and_close(); + + Error error; + sock_ = detail::create_client_socket( + host_, std::string(), port_, AF_UNSPEC, false, false, nullptr, 5, 0, + read_timeout_sec_, 
read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, std::string(), error); + + if (sock_ == INVALID_SOCKET) { return false; } + + std::unique_ptr strm; + if (!create_stream(strm)) { + shutdown_and_close(); + return false; + } + + std::string selected_subprotocol; + if (!detail::perform_websocket_handshake(*strm, host_, port_, path_, headers_, + selected_subprotocol)) { + shutdown_and_close(); + return false; + } + subprotocol_ = std::move(selected_subprotocol); + + Request req; + req.method = "GET"; + req.path = path_; + ws_ = std::unique_ptr(new WebSocket(std::move(strm), req, false)); + return true; +} + +ReadResult WebSocketClient::read(std::string &msg) { + if (!ws_) { return Fail; } + return ws_->read(msg); +} + +bool WebSocketClient::send(const std::string &data) { + if (!ws_) { return false; } + return ws_->send(data); +} + +bool WebSocketClient::send(const char *data, size_t len) { + if (!ws_) { return false; } + return ws_->send(data, len); +} + +void WebSocketClient::close(CloseStatus status, + const std::string &reason) { + if (ws_) { ws_->close(status, reason); } +} + +bool WebSocketClient::is_open() const { return ws_ && ws_->is_open(); } + +const std::string &WebSocketClient::subprotocol() const { + return subprotocol_; +} + +void WebSocketClient::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +void WebSocketClient::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + +#ifdef CPPHTTPLIB_SSL_ENABLED + +void WebSocketClient::set_ca_cert_path(const std::string &path) { + ca_cert_file_path_ = path; +} + +void WebSocketClient::set_ca_cert_store(tls::ca_store_t store) { + ca_cert_store_ = store; +} + +void +WebSocketClient::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +#endif // CPPHTTPLIB_SSL_ENABLED + +} // namespace ws + } // namespace httplib diff --git 
a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index f7563283e..f2e3b6925 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.32.0" -#define CPPHTTPLIB_VERSION_NUM "0x002000" +#define CPPHTTPLIB_VERSION "0.33.1" +#define CPPHTTPLIB_VERSION_NUM "0x002101" /* * Platform compatibility check @@ -185,6 +185,14 @@ : 0)) #endif +#ifndef CPPHTTPLIB_THREAD_POOL_MAX_COUNT +#define CPPHTTPLIB_THREAD_POOL_MAX_COUNT (CPPHTTPLIB_THREAD_POOL_COUNT * 4) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT +#define CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT 3 // seconds +#endif + #ifndef CPPHTTPLIB_RECV_FLAGS #define CPPHTTPLIB_RECV_FLAGS 0 #endif @@ -201,6 +209,22 @@ #define CPPHTTPLIB_MAX_LINE_LENGTH 32768 #endif +#ifndef CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH +#define CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH 16777216 +#endif + +#ifndef CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND +#define CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND +#define CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND 30 +#endif + /* * Headers */ @@ -310,6 +334,7 @@ using socket_t = int; #include #include #include +#include #include #include #include @@ -328,6 +353,9 @@ using socket_t = int; #include #include #include +#if __cplusplus >= 201703L +#include +#endif #if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \ defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) @@ -415,10 +443,46 @@ using socket_t = int; #endif // CPPHTTPLIB_MBEDTLS_SUPPORT +#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT +#include + +#include + +// Fallback definitions for older wolfSSL versions (e.g., 5.6.6) +#ifndef WOLFSSL_GEN_EMAIL +#define WOLFSSL_GEN_EMAIL 1 +#endif +#ifndef WOLFSSL_GEN_DNS +#define WOLFSSL_GEN_DNS 2 +#endif +#ifndef 
WOLFSSL_GEN_URI +#define WOLFSSL_GEN_URI 6 +#endif +#ifndef WOLFSSL_GEN_IPADD +#define WOLFSSL_GEN_IPADD 7 +#endif + +#include +#include +#include +#include +#include +#ifdef _WIN32 +#include +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#endif // _WIN32 +#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) +#if TARGET_OS_MAC +#include +#endif +#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#endif // CPPHTTPLIB_WOLFSSL_SUPPORT + // Define CPPHTTPLIB_SSL_ENABLED if any SSL backend is available -// This simplifies conditional compilation when adding new backends (e.g., -// wolfSSL) -#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || defined(CPPHTTPLIB_MBEDTLS_SUPPORT) +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \ + defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || defined(CPPHTTPLIB_WOLFSSL_SUPPORT) #define CPPHTTPLIB_SSL_ENABLED #endif @@ -440,6 +504,10 @@ using socket_t = int; */ namespace httplib { +namespace ws { +class WebSocket; +} // namespace ws + namespace detail { /* @@ -711,6 +779,143 @@ using Match = std::smatch; using DownloadProgress = std::function; using UploadProgress = std::function; + +#if __cplusplus >= 201703L + +using any = std::any; +using bad_any_cast = std::bad_any_cast; + +template T any_cast(const any &a) { return std::any_cast(a); } +template T any_cast(any &a) { return std::any_cast(a); } +template T any_cast(any &&a) { + return std::any_cast(std::move(a)); +} +template const T *any_cast(const any *a) noexcept { + return std::any_cast(a); +} +template T *any_cast(any *a) noexcept { + return std::any_cast(a); +} + +#else // C++11/14 implementation + +class bad_any_cast : public std::bad_cast { +public: + const char *what() const noexcept override { return "bad any_cast"; } +}; + +namespace detail { + +using any_type_id = const void *; + +// Returns a unique per-type ID without RTTI. +// The static address is stable across TUs because function templates are +// implicitly inline and the ODR merges their statics into one. 
+template any_type_id any_typeid() noexcept { + static const char id = 0; + return &id; +} + +struct any_storage { + virtual ~any_storage() = default; + virtual std::unique_ptr clone() const = 0; + virtual any_type_id type_id() const noexcept = 0; +}; + +template struct any_value final : any_storage { + T value; + template explicit any_value(U &&v) : value(std::forward(v)) {} + std::unique_ptr clone() const override { + return std::unique_ptr(new any_value(value)); + } + any_type_id type_id() const noexcept override { return any_typeid(); } +}; + +} // namespace detail + +class any { + std::unique_ptr storage_; + +public: + any() noexcept = default; + any(const any &o) : storage_(o.storage_ ? o.storage_->clone() : nullptr) {} + any(any &&) noexcept = default; + any &operator=(const any &o) { + storage_ = o.storage_ ? o.storage_->clone() : nullptr; + return *this; + } + any &operator=(any &&) noexcept = default; + + template < + typename T, typename D = typename std::decay::type, + typename std::enable_if::value, int>::type = 0> + any(T &&v) : storage_(new detail::any_value(std::forward(v))) {} + + template < + typename T, typename D = typename std::decay::type, + typename std::enable_if::value, int>::type = 0> + any &operator=(T &&v) { + storage_.reset(new detail::any_value(std::forward(v))); + return *this; + } + + bool has_value() const noexcept { return storage_ != nullptr; } + void reset() noexcept { storage_.reset(); } + + template friend T *any_cast(any *a) noexcept; + template friend const T *any_cast(const any *a) noexcept; +}; + +template T *any_cast(any *a) noexcept { + if (!a || !a->storage_) { return nullptr; } + if (a->storage_->type_id() != detail::any_typeid()) { return nullptr; } + return &static_cast *>(a->storage_.get())->value; +} + +template const T *any_cast(const any *a) noexcept { + if (!a || !a->storage_) { return nullptr; } + if (a->storage_->type_id() != detail::any_typeid()) { return nullptr; } + return &static_cast 
*>(a->storage_.get())->value; +} + +template T any_cast(const any &a) { + using U = + typename std::remove_cv::type>::type; + const U *p = any_cast(&a); +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (!p) { throw bad_any_cast{}; } +#else + if (!p) { std::abort(); } +#endif + return static_cast(*p); +} + +template T any_cast(any &a) { + using U = + typename std::remove_cv::type>::type; + U *p = any_cast(&a); +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (!p) { throw bad_any_cast{}; } +#else + if (!p) { std::abort(); } +#endif + return static_cast(*p); +} + +template T any_cast(any &&a) { + using U = + typename std::remove_cv::type>::type; + U *p = any_cast(&a); +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (!p) { throw bad_any_cast{}; } +#else + if (!p) { std::abort(); } +#endif + return static_cast(std::move(*p)); +} + +#endif // __cplusplus >= 201703L + struct Response; using ResponseHandler = std::function; @@ -805,6 +1010,34 @@ struct FormDataProvider { }; using FormDataProviderItems = std::vector; +inline FormDataProvider +make_file_provider(const std::string &name, const std::string &filepath, + const std::string &filename = std::string(), + const std::string &content_type = std::string()) { + FormDataProvider fdp; + fdp.name = name; + fdp.filename = filename.empty() ? 
filepath : filename; + fdp.content_type = content_type; + fdp.provider = [filepath](size_t offset, DataSink &sink) -> bool { + std::ifstream f(filepath, std::ios::binary); + if (!f) { return false; } + if (offset > 0) { + f.seekg(static_cast(offset)); + if (!f.good()) { + sink.done(); + return true; + } + } + char buf[8192]; + f.read(buf, sizeof(buf)); + auto n = static_cast(f.gcount()); + if (n > 0) { return sink.write(buf, n); } + sink.done(); // EOF + return true; + }; + return fdp; +} + using ContentReceiverWithProgress = std::function; @@ -1010,6 +1243,10 @@ struct Response { std::string body; std::string location; // Redirect location + // User-defined context — set by pre-routing/pre-request handlers and read + // by route handlers to pass arbitrary data (e.g. decoded auth tokens). + std::map user_data; + bool has_header(const std::string &key) const; std::string get_header_value(const std::string &key, const char *def = "", size_t id = 0) const; @@ -1124,6 +1361,11 @@ public: virtual time_t duration() const = 0; + virtual void set_read_timeout(time_t sec, time_t usec = 0) { + (void)sec; + (void)usec; + } + ssize_t write(const char *ptr); ssize_t write(const std::string &s); @@ -1146,7 +1388,7 @@ public: class ThreadPool final : public TaskQueue { public: - explicit ThreadPool(size_t n, size_t mqr = 0); + explicit ThreadPool(size_t n, size_t max_n = 0, size_t mqr = 0); ThreadPool(const ThreadPool &) = delete; ~ThreadPool() override = default; @@ -1154,20 +1396,22 @@ public: void shutdown() override; private: - struct worker { - explicit worker(ThreadPool &pool); + void worker(bool is_dynamic); + void move_to_finished(std::thread::id id); + void cleanup_finished_threads(); - void operator()(); - - ThreadPool &pool_; - }; - friend struct worker; - - std::vector threads_; - std::list> jobs_; + size_t base_thread_count_; + size_t max_thread_count_; + size_t max_queued_requests_; + size_t idle_thread_count_; bool shutdown_; - size_t max_queued_requests_ = 0; + + 
std::list> jobs_; + std::vector threads_; // base threads + std::list dynamic_threads_; // dynamic threads + std::vector + finished_threads_; // exited dynamic threads awaiting join std::condition_variable cond_; std::mutex mutex_; @@ -1294,6 +1538,11 @@ public: using Expect100ContinueHandler = std::function; + using WebSocketHandler = + std::function; + using SubProtocolSelector = + std::function &protocols)>; + Server(); virtual ~Server(); @@ -1311,6 +1560,10 @@ public: Server &Delete(const std::string &pattern, HandlerWithContentReader handler); Server &Options(const std::string &pattern, Handler handler); + Server &WebSocket(const std::string &pattern, WebSocketHandler handler); + Server &WebSocket(const std::string &pattern, WebSocketHandler handler, + SubProtocolSelector sub_protocol_selector); + bool set_base_dir(const std::string &dir, const std::string &mount_point = std::string()); bool set_mount_point(const std::string &mount_point, const std::string &dir, @@ -1386,7 +1639,8 @@ protected: int remote_port, const std::string &local_addr, int local_port, bool close_connection, bool &connection_closed, - const std::function &setup_request); + const std::function &setup_request, + bool *websocket_upgraded = nullptr); std::atomic svr_sock_{INVALID_SOCKET}; @@ -1488,6 +1742,14 @@ private: HandlersForContentReader delete_handlers_for_content_reader_; Handlers options_handlers_; + struct WebSocketHandlerEntry { + std::unique_ptr matcher; + WebSocketHandler handler; + SubProtocolSelector sub_protocol_selector; + }; + using WebSocketHandlers = std::vector; + WebSocketHandlers websocket_handlers_; + HandlerWithResponse error_handler_; ExceptionHandler exception_handler_; HandlerWithResponse pre_routing_handler_; @@ -2970,6 +3232,36 @@ struct MbedTlsContext { } // namespace tls #endif +#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT +namespace tls { +namespace impl { + +// wolfSSL context wrapper (holds WOLFSSL_CTX and related state). 
+// This struct is accessible via tls::impl for use in SSL context
+// setup callbacks (cast ctx_t to tls::impl::WolfSSLContext*).
+struct WolfSSLContext {
+  WOLFSSL_CTX *ctx = nullptr;
+  bool is_server = false;
+  bool verify_client = false;
+  bool has_verify_callback = false;
+  std::string ca_pem_data_; // accumulated PEM for get_ca_names/get_ca_certs
+
+  WolfSSLContext();
+  ~WolfSSLContext();
+
+  WolfSSLContext(const WolfSSLContext &) = delete;
+  WolfSSLContext &operator=(const WolfSSLContext &) = delete;
+};
+
+// CA store for wolfSSL: holds raw PEM bytes to allow reloading into any ctx
+struct WolfSSLCAStore {
+  std::string pem_data;
+};
+
+} // namespace impl
+} // namespace tls
+#endif
+
 #endif // CPPHTTPLIB_SSL_ENABLED
 
 namespace stream {
@@ -3335,6 +3627,157 @@ private:
 
 } // namespace sse
 
+namespace ws {
+
+// WebSocket frame opcodes; the values match RFC 6455 section 5.2.
+enum class Opcode : uint8_t {
+  Continuation = 0x0,
+  Text = 0x1,
+  Binary = 0x2,
+  Close = 0x8,
+  Ping = 0x9,
+  Pong = 0xA,
+};
+
+// Close status codes; the values match RFC 6455 section 7.4.1.
+enum class CloseStatus : uint16_t {
+  Normal = 1000,
+  GoingAway = 1001,
+  ProtocolError = 1002,
+  UnsupportedData = 1003,
+  NoStatus = 1005,
+  Abnormal = 1006,
+  InvalidPayload = 1007,
+  PolicyViolation = 1008,
+  MessageTooBig = 1009,
+  MandatoryExtension = 1010,
+  InternalError = 1011,
+};
+
+// Result of a read(). Unscoped with Fail == 0, so a failed read converts to
+// false in a boolean context while Text and Binary convert to true.
+enum ReadResult : int { Fail = 0, Text = 1, Binary = 2 };
+
+// One WebSocket connection layered on a Stream. Non-copyable; constructible
+// only by httplib::Server and WebSocketClient (private constructors, friends).
+class WebSocket {
+public:
+  WebSocket(const WebSocket &) = delete;
+  WebSocket &operator=(const WebSocket &) = delete;
+  ~WebSocket();
+
+  ReadResult read(std::string &msg);
+  bool send(const std::string &data);
+  bool send(const char *data, size_t len);
+  void close(CloseStatus status = CloseStatus::Normal,
+             const std::string &reason = "");
+  const Request &request() const;
+  bool is_open() const;
+
+private:
+  friend class httplib::Server;
+  friend class WebSocketClient;
+
+  // Borrowing constructor: `strm` is held by reference and must outlive this.
+  WebSocket(Stream &strm, const Request &req, bool is_server)
+      : strm_(strm), req_(req), is_server_(is_server) {
+    start_heartbeat();
+  }
+
+  // Owning constructor. strm_ is declared before owned_strm_, so the reference
+  // is bound to the stream before the pointer is moved into owned_strm_.
+  WebSocket(std::unique_ptr<Stream> &&owned_strm, const Request &req,
+            bool is_server)
+      : strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req),
+        is_server_(is_server) {
+    start_heartbeat();
+  }
+
+  void start_heartbeat();
+  bool send_frame(Opcode op, const char *data, size_t len, bool fin = true);
+
+  Stream &strm_;
+  std::unique_ptr<Stream> owned_strm_; // null when the stream is borrowed
+  Request req_;
+  bool is_server_;
+  std::atomic<bool> closed_{false};
+  std::mutex write_mutex_;
+  // NOTE(review): presumably owned by start_heartbeat(); confirm against the
+  // out-of-line definitions.
+  std::thread ping_thread_;
+  std::mutex ping_mutex_;
+  std::condition_variable ping_cv_;
+};
+
+// Client-side WebSocket endpoint. Non-copyable.
+// NOTE(review): ws_ suggests this wraps a WebSocket created by connect();
+// confirm against the out-of-line definitions.
+class WebSocketClient {
+public:
+  explicit WebSocketClient(const std::string &scheme_host_port_path,
+                           const Headers &headers = {});
+
+  ~WebSocketClient();
+  WebSocketClient(const WebSocketClient &) = delete;
+  WebSocketClient &operator=(const WebSocketClient &) = delete;
+
+  bool is_valid() const;
+
+  bool connect();
+  ReadResult read(std::string &msg);
+  bool send(const std::string &data);
+  bool send(const char *data, size_t len);
+  void close(CloseStatus status = CloseStatus::Normal,
+             const std::string &reason = "");
+  bool is_open() const;
+  const std::string &subprotocol() const;
+  void set_read_timeout(time_t sec, time_t usec = 0);
+  void set_write_timeout(time_t sec, time_t usec = 0);
+
+#ifdef CPPHTTPLIB_SSL_ENABLED
+  void set_ca_cert_path(const std::string &path);
+  void set_ca_cert_store(tls::ca_store_t store);
+  void enable_server_certificate_verification(bool enabled);
+#endif
+
+private:
+  void shutdown_and_close();
+  bool create_stream(std::unique_ptr<Stream> &strm);
+
+  std::string host_;
+  int port_;
+  std::string path_;
+  Headers headers_;
+  std::string subprotocol_;
+  bool is_valid_ = false;
+  socket_t sock_ = INVALID_SOCKET;
+  std::unique_ptr<WebSocket> ws_;
+  time_t read_timeout_sec_ = CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND;
+  time_t read_timeout_usec_ = 0;
+  time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND;
+  time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND;
+
+#ifdef CPPHTTPLIB_SSL_ENABLED
+  bool is_ssl_ = false;
+  tls::ctx_t tls_ctx_ = nullptr;
+  tls::session_t tls_session_ = nullptr;
+  std::string ca_cert_file_path_;
+  tls::ca_store_t ca_cert_store_ = nullptr;
+  bool server_certificate_verification_ = true;
+#endif
+};
+
+namespace impl {
+
+bool is_valid_utf8(const std::string &s);
+
+bool read_websocket_frame(Stream &strm, Opcode &opcode, std::string &payload,
+                          bool &fin, bool expect_masked, size_t max_len);
+
+} // namespace impl
+
+} // namespace ws
+
 } // namespace httplib