diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index db1f0b23d..dd9b51a9e 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -41,9 +41,9 @@ static std::string build_repetition(const std::string & item_rule, int min_items return result; } -static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { - auto has_min = min_value != std::numeric_limits::min(); - auto has_max = max_value != std::numeric_limits::max(); +static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { + auto has_min = min_value != std::numeric_limits::min(); + auto has_max = max_value != std::numeric_limits::max(); auto digit_range = [&](char from, char to) { out << "["; @@ -159,7 +159,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & if (has_min) { if (min_value < 0) { out << "\"-\" ("; - _build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); + _build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); out << ") | [0] | [1-9] "; more_digits(0, decimals_left - 1); } else if (min_value == 0) { @@ -194,7 +194,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & } digit_range(c, c); out << " ("; - _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); + _build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); out << ")"; if (c < '9') { out << " | "; @@ -216,7 +216,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true); } else { out << "\"-\" ("; - _build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); + _build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); out << ")"; } return; @@ -925,17 +925,17 @@ public: int max_len = schema.contains("maxLength") ? schema["maxLength"].get() : std::numeric_limits::max(); return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space"); } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { - int min_value = std::numeric_limits::min(); - int max_value = std::numeric_limits::max(); + int64_t min_value = std::numeric_limits::min(); + int64_t max_value = std::numeric_limits::max(); if (schema.contains("minimum")) { - min_value = schema["minimum"].get(); + min_value = schema["minimum"].get(); } else if (schema.contains("exclusiveMinimum")) { - min_value = schema["exclusiveMinimum"].get() + 1; + min_value = schema["exclusiveMinimum"].get() + 1; } if (schema.contains("maximum")) { - max_value = schema["maximum"].get(); + max_value = schema["maximum"].get(); } else if (schema.contains("exclusiveMaximum")) { - max_value = schema["exclusiveMaximum"].get() - 1; + max_value = schema["exclusiveMaximum"].get() - 1; } std::stringstream out; out << "("; diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 72eff0027..e6dca3f62 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, - size_t n_threads, size_t n_devices, - ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem); + size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp index 54d3dece0..91fe1925e 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -485,8 +485,9 @@ template class tensor_ int32_t start = ith * task_per_thread; int32_t end = std::min((ith + 1) * task_per_thread, task_count); for (int32_t compute_idx = start; compute_idx < end; compute_idx++) { - int32_t gemm_idx = compute_idx / block_size_m; - int32_t m_idx = compute_idx % block_size_m * block_size_m; + int32_t gemm_idx = compute_idx / per_gemm_block_count_m; + int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m; + int32_t m_idx = block_idx_in_gemm * block_size_m; const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx]; int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx); diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 866cd2da5..758116342 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_CONV_TRANSPOSE_2D); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_UPSCALE); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 28ae2e176..4d5829748 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index bedac0463..15034f2db 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -653,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_SCALE: case GGML_OP_CONV_TRANSPOSE_1D: return true; + case GGML_OP_CONV_TRANSPOSE_2D: + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && + (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32; case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SQR: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index fa2d82cef..96f43d260 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -514,6 +514,19 @@ typedef struct { uint64_t nb1; } ggml_metal_kargs_conv_transpose_1d; +typedef struct { + int32_t IC; + int32_t IH; + int32_t IW; + int32_t KH; + int32_t KW; + int32_t OC; + int32_t s0; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_conv_transpose_2d; + typedef struct { uint64_t ofs0; uint64_t ofs1; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 4f9f6bda0..7a85edbdc 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx); } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx); + } break; case GGML_OP_UPSCALE: { n_fuse = ggml_metal_op_upscale(ctx, idx); @@ -3118,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + + const int32_t IC = op->src[1]->ne[2]; + const int32_t IH = op->src[1]->ne[1]; + const int32_t IW = op->src[1]->ne[0]; + + const int32_t KH = op->src[0]->ne[1]; + const int32_t KW = op->src[0]->ne[0]; + + const int32_t OW = op->ne[0]; + const int32_t OH = op->ne[1]; + const int32_t OC = op->ne[2]; + + ggml_metal_kargs_conv_transpose_2d args = { + /*.IC =*/ IC, + /*.IH =*/ IH, + /*.IW =*/ IW, + /*.KH =*/ KH, + /*.KW =*/ KW, + /*.OC =*/ OC, + /*.s0 =*/ s0, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + // Metal requires buffer size to be multiple of 16 bytes + const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16); + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1); + + return 1; +} + int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index f35273869..0d9cb8af7 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -71,6 +71,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 496610b15..2c2f01415 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4179,6 +4179,97 @@ kernel void kernel_conv_transpose_1d( uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); + +typedef void (conv_transpose_2d_t)( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const float * src0, + device const float * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template +kernel void kernel_conv_transpose_2d( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const T * src0, + device const float * src1, + device char * dst, + threadgroup float * shared_sum [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t out_x = tgpig[0]; + const int64_t out_y = tgpig[1]; + const int64_t out_c = tgpig[2]; + + const int64_t kw = tpitg[0]; + const int64_t kh = tpitg[1]; + + float v = 0.0f; + + for (int64_t in_c = 0; in_c < args.IC; in_c++) { + int64_t in_y = out_y - kh; + + if (in_y < 0 || in_y % args.s0) continue; + + in_y /= args.s0; + + if (in_y >= args.IH) continue; + + int64_t in_x = out_x - kw; + + if (in_x < 0 || in_x % args.s0) continue; + + in_x /= args.s0; + + if (in_x >= args.IW) continue; + + const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x; + const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw; + + v += (float)src0[kernel_idx] * src1[input_idx]; + } + + const uint tid = tpitg.y * ntg.x + tpitg.x; + shared_sum[tid] = v; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid == 0) { + float total = 0.0f; + const uint num_threads = ntg.x * ntg.y; + for (uint i = 0; i < num_threads; i++) { + total += shared_sum[i]; + } + + device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2); + dst_ptr[0] = total; + } +} + +template [[host_name("kernel_conv_transpose_2d_f32_f32")]] +kernel void kernel_conv_transpose_2d( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const float * src0, + device const float * src1, + device char * dst, + threadgroup float * shared_sum [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template [[host_name("kernel_conv_transpose_2d_f16_f32")]] +kernel void kernel_conv_transpose_2d( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const half * src0, + device const float * src1, + device char * dst, + threadgroup float * shared_sum [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + kernel void kernel_upscale_f32( constant ggml_metal_kargs_upscale & args, device const char * src0, diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl new file mode 100644 index 000000000..3917aa3fd --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl @@ -0,0 +1,162 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_MXFP4 32 +#define N_SIMDGROUP 2 +#define SIMDGROUP_WIDTH 64 + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemm_moe_mxfp4_f32( + __global uint4 * src0_q, + __global uchar * src0_e, + __read_only image1d_buffer_t src1, + __global ushort4 * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int tile_size +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + ushort4 router = src2[i20]; + ushort expert_id = router.x; + ushort i11 = router.y; + ushort i1 = router.z; + ushort tile_id = router.w; + + if (tile_id * tile_size + i01 >= ne01) { // handle edge case when ne01 is not multiple of tile_size + return; + } + + uint expert_offset = expert_id * ne00 * ne01 / 32; + uint tile_offset = expert_offset + tile_id * tile_size + i01; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) { + // load one block of q + uint4 regQ = src0_q[tile_offset + ib00 * ne01]; + // convert 8 fp4 to fp16 + half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0)); + + uint offset = i11 * ne00 / 4 + ib00 * 8; + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1)); + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2)); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3)); + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + uchar regE = src0_e[tile_offset + ib00 * ne01]; + sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + // if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + // if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + tile_id * tile_size + i1 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl new file mode 100644 index 000000000..b4b1e511f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl @@ -0,0 +1,156 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_MXFP4 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_mxfp4_f32( + __global uint4 * src0_q, + __global uchar * src0_e, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ = src0_q[expert_offset + ib00 * ne01 + i01]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0)); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1)); + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2)); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3)); + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6); + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7); + + uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset]; + sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1c0a4019c..19689257c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -598,6 +598,9 @@ struct vk_device_struct { vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; + vk_pipeline pipeline_ssm_scan_f32_d128; + vk_pipeline pipeline_ssm_scan_f32_d256; + vk_pipeline pipeline_ssm_conv_f32; vk_pipeline pipeline_opt_step_adamw_f32; vk_pipeline pipeline_opt_step_sgd_f32; vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; @@ -1103,6 +1106,19 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t C; uint32_t H; }; +struct vk_op_ssm_scan_push_constants { + uint32_t nb02, nb03, nb12, nb13; + uint32_t nb21, nb22, nb31; + uint32_t nb42, nb43, nb52, nb53; + uint32_t s_off; + uint32_t n_head, d_head, n_group, n_tok; +}; +struct vk_op_ssm_conv_push_constants { + uint32_t nb01, nb02; + uint32_t nb11; + uint32_t dst_nb0, dst_nb1, dst_nb2; + uint32_t nc, ncs, nr, n_t, n_s; +}; struct vk_op_conv2d_push_constants { uint32_t Cout; @@ -3607,6 +3623,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1); + ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -8128,6 +8149,21 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rwkv_wkv7_f32; } return nullptr; + case GGML_OP_SSM_SCAN: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + const uint32_t d_state = src0->ne[0]; + if (d_state == 128) { + return ctx->device->pipeline_ssm_scan_f32_d128; + } else if (d_state == 256) { + return ctx->device->pipeline_ssm_scan_f32_d256; + } + } + return nullptr; + case GGML_OP_SSM_CONV: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_ssm_conv_f32; + } + return nullptr; case GGML_OP_OPT_STEP_ADAMW: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_opt_step_adamw_f32; @@ -8622,6 +8658,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } break; + case GGML_OP_SSM_CONV: + { + const uint32_t nr = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + elements = { nr, n_t, n_s }; + } + break; default: elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; break; @@ -9068,6 +9112,117 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ); } +static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const ggml_tensor * src3 = dst->src[3]; + const ggml_tensor * src4 = dst->src[4]; + const ggml_tensor * src5 = dst->src[5]; + + GGML_ASSERT(dst->buffer != nullptr); + + const uint32_t head_dim = src0->ne[1]; + const uint32_t n_head = src1->ne[1]; + const uint32_t n_group = src4->ne[1]; + const uint32_t n_tok = src1->ne[2]; + const uint32_t n_seq = src1->ne[3]; + + bool is_mamba2 = (src3->nb[1] == sizeof(float)); + GGML_ASSERT(is_mamba2); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + const int64_t s_off = ggml_nelements(src1) * sizeof(float); + + const vk_op_ssm_scan_push_constants pc = { + (uint32_t)src0->nb[2], (uint32_t)src0->nb[3], + (uint32_t)src1->nb[2], (uint32_t)src1->nb[3], + (uint32_t)src2->nb[1], (uint32_t)src2->nb[2], + (uint32_t)src3->nb[1], + (uint32_t)src4->nb[2], (uint32_t)src4->nb[3], + (uint32_t)src5->nb[2], (uint32_t)src5->nb[3], + (uint32_t)s_off, + n_head, head_dim, n_group, n_tok + }; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC]; + for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { + src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; + } + + vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr }; + size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 }; + bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false }; + + if (ctx->device->uma) { + for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { + ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]); + srcs_uma[i] = d_srcs[i] != nullptr; + } + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); + dst_uma = d_D != nullptr; + } + + if (!dst_uma) { + d_D = dst_buf_ctx->dev_buffer; + dst_offset = vk_tensor_offset(dst) + dst->view_offs; + } + for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { + if (!srcs_uma[i]) { + d_srcs[i] = src_buf_ctxs[i]->dev_buffer; + src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs; + } + } + + size_t dst_size = ggml_nbytes(dst); + size_t src_sizes[GGML_MAX_SRC]; + for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { + src_sizes[i] = ggml_nbytes(dst->src[i]); + } + + std::array elements; + + const int splitH = 16; + const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH); + const uint32_t num_workgroups_y = n_seq; + elements = { num_workgroups_x, num_workgroups_y, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, pc, elements); +} + +static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, { + (uint32_t)src0->nb[1], (uint32_t)src0->nb[2], + (uint32_t)src1->nb[1], + (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2], + (uint32_t)src1->ne[0], + (uint32_t)src0->ne[0], + (uint32_t)src0->ne[1], + (uint32_t)dst->ne[1], + (uint32_t)dst->ne[2], + }, dryrun); +} + static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) { const ggml_tensor * x = dst->src[0]; const ggml_tensor * g = dst->src[1]; @@ -10900,6 +11055,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: + case GGML_OP_SSM_SCAN: + case GGML_OP_SSM_CONV: case GGML_OP_LEAKY_RELU: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_OPT_STEP_ADAMW: @@ -11317,6 +11474,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; + case GGML_OP_SSM_SCAN: + ggml_vk_ssm_scan(ctx, compute_ctx, node, dryrun); + + break; + + case GGML_OP_SSM_CONV: + ggml_vk_ssm_conv(ctx, compute_ctx, node, dryrun); + + break; + case GGML_OP_OPT_STEP_ADAMW: ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); @@ -11428,6 +11595,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: + case GGML_OP_SSM_SCAN: + case GGML_OP_SSM_CONV: case GGML_OP_LEAKY_RELU: case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: @@ -12909,6 +13078,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: return true; + case GGML_OP_SSM_SCAN: + { + for (int i = 0; i < 6; i++) { + if (op->src[i] && ggml_is_quantized(op->src[i]->type)) { + return false; + } + } + if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) { + return false; + } + if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) { + return false; + } + + const uint32_t d_state = op->src[0]->ne[0]; + const uint32_t head_dim = op->src[0]->ne[1]; + + bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float)); + if (!is_mamba2) { + return false; + } + + if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) { + return false; + } + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + + const uint32_t SPLIT_H = 16; + + size_t stateC_size = SPLIT_H * d_state * sizeof(float); + + if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) { + return false; + } + + return true; + } + case GGML_OP_SSM_CONV: + return true; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: @@ -13253,14 +13463,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * struct ggml_context * ggml_ctx = ggml_init(iparams); - std::array src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; - std::array src_size = {0, 0, 0, 0, 0, 0}; - std::array src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; - const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"}; + std::array src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + std::array src_size = {}; + std::array src_buffer = {}; + const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"}; struct ggml_tensor * tensor_clone = nullptr; - for (int i = 0; i < 6; i++) { + for (int i = 0; i < GGML_MAX_SRC; i++) { ggml_tensor * srci = tensor->src[i]; if (fused_rms_norm_mul) { rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1; @@ -13567,6 +13777,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * src_clone[2]); } else if (tensor->op == GGML_OP_ADD_ID) { tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } else if (tensor->op == GGML_OP_SSM_SCAN) { + tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], + src_clone[3], src_clone[4], src_clone[5], src_clone[6]); + } else if (tensor->op == GGML_OP_SSM_CONV) { + tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]); } else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; @@ -13588,7 +13803,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * memcpy(comp_result, tensor_clone->data, comp_size); memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); - for (int i = 0; i < 6; i++) { + for (int i = 0; i < GGML_MAX_SRC; i++) { if (src_buffer[i] != nullptr) { free(src_buffer[i]); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp new file mode 100644 index 000000000..d62696bcf --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp @@ -0,0 +1,44 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#include "types.glsl" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer Src0 { float src0[]; }; +layout(binding = 1) readonly buffer Src1 { float src1[]; }; +layout(binding = 2) buffer Dst { float dst[]; }; + +layout(push_constant) uniform PushConstants { + uint nb01; uint nb02; + uint nb11; + uint dst_nb0; uint dst_nb1; uint dst_nb2; + uint nc; uint ncs; uint nr; uint n_t; uint n_s; +}; + +void main() { + const uint global_thread_id = gl_GlobalInvocationID.x; + const uint i2 = gl_WorkGroupID.y; + const uint i3 = gl_WorkGroupID.z; + + if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) { + return; + } + + const uint i1 = global_thread_id; + const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4); + const uint src1_base = i1 * (nb11 / 4); + const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; + + float sum = 0.0; + [[unroll]] for (uint i0 = 0; i0 < nc; i0++) { + const uint src0_idx = src0_base + i0; + const uint src1_idx = src1_base + i0; + sum += src0[src0_idx] * src1[src1_idx]; + } + + dst[dst_idx] = sum; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp new file mode 100644 index 000000000..12bd17457 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp @@ -0,0 +1,125 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#include "types.glsl" + +layout(constant_id = 0) const uint D_STATE = 128; +layout(constant_id = 1) const uint SUBGROUP_SIZE = 32; +layout(constant_id = 2) const uint SPLIT_H = 16; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer Src0 { float s0[]; }; +layout(binding = 1) readonly buffer Src1 { float x[]; }; +layout(binding = 2) readonly buffer Src2 { float dt[]; }; +layout(binding = 3) readonly buffer Src3 { float A[]; }; +layout(binding = 4) readonly buffer Src4 { float B[]; }; +layout(binding = 5) readonly buffer Src5 { float C[]; }; +layout(binding = 6) readonly buffer Src6 { int ids[]; }; +layout(binding = 7) buffer Dst { float d[]; }; + +layout(push_constant) uniform PushConstants { + uint nb02; uint nb03; uint nb12; uint nb13; + uint nb21; uint nb22; uint nb31; + uint nb42; uint nb43; uint nb52; uint nb53; + uint s_off; + uint n_head; + uint d_head; + uint n_group; + uint n_tok; +}; + +float softplus(float x) { + if (x <= 20.0) { + return log(1.0 + exp(x)); + } else { + return x; + } +} + +shared float stateC[SPLIT_H * D_STATE]; + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head; + const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4; + const uint seq_idx = gl_WorkGroupID.y; + + const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4; + const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; + const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4; + const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4; + const uint A_base_idx = (head_idx * nb31) / 4; + const uint B_base_idx = (seq_idx * nb43 + group_off) / 4; + const uint C_base_idx = (seq_idx * nb53 + group_off) / 4; + const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H; + const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; + + const uint stride_x = nb12 / 4; + const uint stride_dt = nb21 / 4; + const uint stride_B = nb42 / 4; + const uint stride_C = nb52 / 4; + const uint stride_y = n_head * d_head; + + float state[SPLIT_H]; + [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { + state[j] = s0[s0_base_idx + j * D_STATE + tid]; + } + + for (uint i = 0; i < n_tok; i++) { + const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]); + + const float dA = exp(dt_soft_plus * A[A_base_idx]); + + const float B_val = B[B_base_idx + i * stride_B + tid]; + const float C_val = C[C_base_idx + i * stride_C + tid]; + + [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { + const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus; + + state[j] = (state[j] * dA) + (B_val * x_dt); + + stateC[j * D_STATE + tid] = state[j] * C_val; + } + + barrier(); + for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) { + [[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) { + const uint k = (tid % (w >> 1)) + + (D_STATE * (tid / (w >> 1))) + + j * D_STATE * (D_STATE / (w >> 1)); + if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) { + stateC[k] += stateC[k + (w >> 1)]; + } + } + barrier(); + } + + [[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) { + const uint idx = (tid % SUBGROUP_SIZE) + + D_STATE * (tid / SUBGROUP_SIZE) + + j * D_STATE * (D_STATE / SUBGROUP_SIZE); + + uint lane = tid % SUBGROUP_SIZE; + + [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { + if (idx + offset < SPLIT_H * D_STATE) { + stateC[idx] += stateC[idx + offset]; + } + barrier(); + } + + if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) { + const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); + d[y_base_idx + i * stride_y + k] = stateC[idx]; + } + } + + barrier(); + } + + [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { + d[s_base_idx + j * D_STATE + tid] = state[j]; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index a1dc3f4c3..ee33e3903 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -934,6 +934,10 @@ void process_shaders() { string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); + string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}}); + + string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}}); + for (auto &c : compiles) { c.wait(); } @@ -977,7 +981,7 @@ void write_output_files() { } std::string suffixes[2] = {"_f32", "_f16"}; - for (auto op : {"add", "sub", "mul", "div", "add_rms"}) { + for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) { hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4876021b1..e2fd7c096 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -426,11 +426,8 @@ struct llama_model::impl { llama_mlocks mlock_bufs; llama_mlocks mlock_mmaps; - // contexts where the model tensors metadata is stored - std::vector ctxs; - - // the model memory buffers for the tensor data - std::vector bufs; + // contexts where the model tensors metadata is stored as well ass the corresponding buffers: + std::vector> ctxs_bufs; buft_list_t cpu_buft_list; std::map gpu_buft_list; @@ -2194,7 +2191,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { max_n_tensors += n_layer*2; // duplicated rope freq tensors const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; - std::map ctx_map; + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs); + } + }; + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { @@ -2209,12 +2213,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { throw std::runtime_error(format("failed to create ggml context")); } - ctx_map[buft] = ctx; - pimpl->ctxs.emplace_back(ctx); + ctx_map.emplace(buft, ctx); return ctx; } - return it->second; + return it->second.get(); }; const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; @@ -6096,16 +6099,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { pimpl->mappings.reserve(ml.mappings.size()); // create the backend buffers - std::vector> ctx_bufs; - ctx_bufs.reserve(ctx_map.size()); + std::vector> ctx_buf_maps; + ctx_buf_maps.reserve(ctx_map.size()); // Ensure we have enough capacity for the maximum backend buffer we will potentially create const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size(); - pimpl->bufs.reserve(n_max_backend_buffer); + pimpl->ctxs_bufs.reserve(n_max_backend_buffer); - for (auto & it : ctx_map) { - ggml_backend_buffer_type_t buft = it.first; - ggml_context * ctx = it.second; + for (auto & [buft, ctx_ptr] : ctx_map) { + ggml_context * ctx = ctx_ptr.get(); // skip contexts without tensors if (ggml_get_first_tensor(ctx) == nullptr) { @@ -6129,6 +6131,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); + ggml_backend_buffer_t buf = nullptr; if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { for (uint32_t idx = 0; idx < ml.files.size(); idx++) { // only the mmap region containing the tensors in the model is mapped to the backend buffer @@ -6141,20 +6144,18 @@ bool llama_model::load_tensors(llama_model_loader & ml) { continue; } const size_t max_size = ggml_get_max_tensor_size(ctx); - ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); + buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } - pimpl->bufs.emplace_back(buf); buf_map.emplace(idx, buf); } } else { - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } - pimpl->bufs.emplace_back(buf); if (use_mlock && ggml_backend_buffer_is_host(buf)) { pimpl->mlock_bufs.emplace_back(new llama_mlock); auto & mlock_buf = pimpl->mlock_bufs.back(); @@ -6165,10 +6166,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { buf_map.emplace(idx, buf); } } - - if (pimpl->bufs.empty()) { - throw std::runtime_error("failed to allocate buffer"); - } + pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf); for (auto & buf : buf_map) { // indicate that this buffer contains weights @@ -6176,7 +6174,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); } - ctx_bufs.emplace_back(ctx, buf_map); + ctx_buf_maps.emplace_back(ctx, buf_map); } if (llama_supports_gpu_offload()) { @@ -6194,22 +6192,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // print memory requirements per buffer type - for (auto & buf : pimpl->bufs) { + for (auto & [_, buf] : pimpl->ctxs_bufs) { LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); } // populate tensors_by_name - for (auto & ctx : pimpl->ctxs) { + for (auto & [ctx, _] : pimpl->ctxs_bufs) { for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) { tensors_by_name.emplace_back(ggml_get_name(cur), cur); } } // load tensor data - for (auto & it : ctx_bufs) { - ggml_context * ctx = it.first; - auto & bufs = it.second; - if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { + for (auto & [ctx, buf_map] : ctx_buf_maps) { + if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { return false; } } @@ -6249,8 +6245,8 @@ size_t llama_model::n_devices() const { std::map llama_model::memory_breakdown() const { std::map ret; - for (const ggml_backend_buffer_ptr & buf_ptr : pimpl->bufs) { - ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + for (const auto & [_, buf] : pimpl->ctxs_bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); } return ret; } diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 34818012c..c76f5778b 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte index d5d4c7fe3..bf1763309 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte @@ -4,7 +4,7 @@ Funnel, AlertTriangle, Brain, - Cog, + Code, Monitor, Sun, Moon, @@ -88,9 +88,59 @@ ] }, { - title: 'Samplers', + title: 'Sampling', icon: Funnel, fields: [ + { + key: 'temperature', + label: 'Temperature', + type: 'input' + }, + { + key: 'dynatemp_range', + label: 'Dynamic temperature range', + type: 'input' + }, + { + key: 'dynatemp_exponent', + label: 'Dynamic temperature exponent', + type: 'input' + }, + { + key: 'top_k', + label: 'Top K', + type: 'input' + }, + { + key: 'top_p', + label: 'Top P', + type: 'input' + }, + { + key: 'min_p', + label: 'Min P', + type: 'input' + }, + { + key: 'xtc_probability', + label: 'XTC probability', + type: 'input' + }, + { + key: 'xtc_threshold', + label: 'XTC threshold', + type: 'input' + }, + { + key: 'typ_p', + label: 'Typical P', + type: 'input' + }, + { + key: 'max_tokens', + label: 'Max tokens', + type: 'input' + }, { key: 'samplers', label: 'Samplers', @@ -152,68 +202,17 @@ key: 'showThoughtInProgress', label: 'Show thought in progress', type: 'checkbox' - }, - { - key: 'disableReasoningFormat', - label: - 'Show raw LLM output without backend parsing and frontend Markdown rendering to inspect streaming across different models.', - type: 'checkbox' } ] }, { - title: 'Advanced', - icon: Cog, + title: 'Developer', + icon: Code, fields: [ { - key: 'temperature', - label: 'Temperature', - type: 'input' - }, - { - key: 'dynatemp_range', - label: 'Dynamic temperature range', - type: 'input' - }, - { - key: 'dynatemp_exponent', - label: 'Dynamic temperature exponent', - type: 'input' - }, - { - key: 'top_k', - label: 'Top K', - type: 'input' - }, - { - key: 'top_p', - label: 'Top P', - type: 'input' - }, - { - key: 'min_p', - label: 'Min P', - type: 'input' - }, - { - key: 'xtc_probability', - label: 'XTC probability', - type: 'input' - }, - { - key: 'xtc_threshold', - label: 'XTC threshold', - type: 'input' - }, - { - key: 'typ_p', - label: 'Typical P', - type: 'input' - }, - { - key: 'max_tokens', - label: 'Max tokens', - type: 'input' + key: 'disableReasoningFormat', + label: 'Show raw LLM output', + type: 'checkbox' }, { key: 'custom',