mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-28 03:30:20 +00:00
Merge commit 'ddf9f94389' into concedo_experimental
# Conflicts: # examples/model-conversion/scripts/causal/run-converted-model.sh # examples/model-conversion/scripts/causal/run-org-model.py # src/CMakeLists.txt # src/llama-quant.cpp # tools/server/README.md
This commit is contained in:
commit
0ccb298087
23 changed files with 2813 additions and 86 deletions
|
|
@ -4183,6 +4183,36 @@ class Qwen3MoeModel(Qwen2MoeModel):
|
|||
super().set_vocab()
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3NextForCausalLM")
|
||||
class Qwen3NextModel(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3NEXT
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"])
|
||||
self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"])
|
||||
self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"])
|
||||
self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"])
|
||||
self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"])
|
||||
if (rope_dim := self.hparams.get("head_dim")) is None:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.startswith("mtp"):
|
||||
return [] # ignore MTP layers for now
|
||||
if name.endswith(".A_log"):
|
||||
data_torch = -torch.exp(data_torch)
|
||||
elif name.endswith(".dt_bias"):
|
||||
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
|
||||
elif "conv1d" in name:
|
||||
data_torch = data_torch.squeeze()
|
||||
elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
|
||||
data_torch = data_torch + 1
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("RND1")
|
||||
class RND1Model(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.RND1
|
||||
|
|
|
|||
|
|
@ -9766,7 +9766,8 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
|
|||
}
|
||||
|
||||
const float diag = A_batch[i00 * n + i00];
|
||||
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
|
||||
assert(diag != 0.0f && "Zero diagonal in triangular matrix");
|
||||
|
||||
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -366,6 +366,7 @@ class MODEL_ARCH(IntEnum):
|
|||
QWEN2VL = auto()
|
||||
QWEN3 = auto()
|
||||
QWEN3MOE = auto()
|
||||
QWEN3NEXT = auto()
|
||||
QWEN3VL = auto()
|
||||
QWEN3VLMOE = auto()
|
||||
PHI2 = auto()
|
||||
|
|
@ -531,6 +532,7 @@ class MODEL_TENSOR(IntEnum):
|
|||
SSM_D = auto()
|
||||
SSM_NORM = auto()
|
||||
SSM_OUT = auto()
|
||||
SSM_BETA_ALPHA = auto() # qwen3next
|
||||
TIME_MIX_W0 = auto()
|
||||
TIME_MIX_W1 = auto()
|
||||
TIME_MIX_W2 = auto()
|
||||
|
|
@ -736,6 +738,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.QWEN2VL: "qwen2vl",
|
||||
MODEL_ARCH.QWEN3: "qwen3",
|
||||
MODEL_ARCH.QWEN3MOE: "qwen3moe",
|
||||
MODEL_ARCH.QWEN3NEXT: "qwen3next",
|
||||
MODEL_ARCH.QWEN3VL: "qwen3vl",
|
||||
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
|
||||
MODEL_ARCH.PHI2: "phi2",
|
||||
|
|
@ -900,6 +903,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
|
||||
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
|
||||
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
|
||||
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
|
||||
|
|
@ -1569,6 +1573,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.QWEN3NEXT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.ATTN_GATE,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.SSM_A,
|
||||
MODEL_TENSOR.SSM_CONV1D,
|
||||
MODEL_TENSOR.SSM_DT,
|
||||
MODEL_TENSOR.SSM_NORM,
|
||||
MODEL_TENSOR.SSM_IN,
|
||||
MODEL_TENSOR.SSM_BETA_ALPHA,
|
||||
MODEL_TENSOR.SSM_OUT
|
||||
],
|
||||
MODEL_ARCH.QWEN3VL: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
|
|
|||
|
|
@ -672,10 +672,11 @@ class TensorNameMap:
|
|||
),
|
||||
|
||||
MODEL_TENSOR.SSM_IN: (
|
||||
"model.layers.{bid}.in_proj", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.in_proj", # mamba
|
||||
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
|
||||
"model.layers.{bid}.in_proj", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.in_proj", # mamba
|
||||
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
|
||||
"model.layers.{bid}.linear_attn.in_proj_qkvz", # qwen3next
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_CONV1D: (
|
||||
|
|
@ -683,6 +684,7 @@ class TensorNameMap:
|
|||
"backbone.layers.{bid}.mixer.conv1d", # mamba
|
||||
"model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.conv1d", # plamo2
|
||||
"model.layers.{bid}.linear_attn.conv1d", # qwen3next
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_X: (
|
||||
|
|
@ -697,6 +699,7 @@ class TensorNameMap:
|
|||
"backbone.layers.{bid}.mixer.dt_proj", # mamba
|
||||
"model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.dt_proj", # plamo2
|
||||
"model.layers.{bid}.linear_attn.dt_proj", # qwen3next
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_DT_NORM: (
|
||||
|
|
@ -709,6 +712,7 @@ class TensorNameMap:
|
|||
"backbone.layers.{bid}.mixer.A_log", # mamba
|
||||
"model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.A_log", # plamo2
|
||||
"model.layers.{bid}.linear_attn.A_log", # qwen3next
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_B_NORM: (
|
||||
|
|
@ -731,17 +735,23 @@ class TensorNameMap:
|
|||
),
|
||||
|
||||
MODEL_TENSOR.SSM_NORM: (
|
||||
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
|
||||
"backbone.layers.{bid}.mixer.norm", # mamba2
|
||||
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
|
||||
"model.layers.{bid}.linear_attn.norm", # qwen3next
|
||||
"backbone.layers.{bid}.mixer.norm", # mamba2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_OUT: (
|
||||
"model.layers.{bid}.out_proj", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.out_proj", # mamba
|
||||
"model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.{bid}.linear_attn.out_proj", # qwen3next
|
||||
"model.layers.layers.{bid}.mixer.out_proj", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_BETA_ALPHA: (
|
||||
"model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W0: (
|
||||
"model.layers.{bid}.attention.w0", # rwkv7
|
||||
),
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
|
||||
{ LLM_ARCH_QWEN3, "qwen3" },
|
||||
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
|
||||
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
|
||||
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
|
||||
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
|
||||
{ LLM_ARCH_PHI2, "phi2" },
|
||||
|
|
@ -829,6 +830,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_QWEN3NEXT,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_QWEN3VL,
|
||||
{
|
||||
|
|
@ -2556,6 +2589,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
|||
{LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
|
|
@ -2754,6 +2788,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
|
|||
case LLM_ARCH_LFM2:
|
||||
case LLM_ARCH_LFM2MOE:
|
||||
case LLM_ARCH_NEMOTRON_H:
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ enum llm_arch {
|
|||
LLM_ARCH_QWEN2VL,
|
||||
LLM_ARCH_QWEN3,
|
||||
LLM_ARCH_QWEN3MOE,
|
||||
LLM_ARCH_QWEN3NEXT,
|
||||
LLM_ARCH_QWEN3VL,
|
||||
LLM_ARCH_QWEN3VLMOE,
|
||||
LLM_ARCH_PHI2,
|
||||
|
|
@ -381,6 +382,7 @@ enum llm_tensor {
|
|||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_NORM,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next
|
||||
LLM_TENSOR_TIME_MIX_W0,
|
||||
LLM_TENSOR_TIME_MIX_W1,
|
||||
LLM_TENSOR_TIME_MIX_W2,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#include "llama-context.h"
|
||||
|
||||
#include "llama-arch.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-io.h"
|
||||
|
|
@ -1388,6 +1389,9 @@ void llama_context::output_reorder() {
|
|||
//
|
||||
|
||||
uint32_t llama_context::graph_max_nodes() const {
|
||||
if (model.arch == LLM_ARCH_QWEN3NEXT) {
|
||||
return std::max<uint32_t>(8192u, 32u*model.n_tensors());
|
||||
}
|
||||
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
// bump if necessary
|
||||
#define LLAMA_MAX_LAYERS 512
|
||||
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
|
||||
#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
|
||||
|
||||
enum llama_expert_gating_func_type {
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-mmap.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-cparams.h"
|
||||
#include "llama-model-loader.h"
|
||||
|
||||
|
|
@ -106,6 +105,7 @@
|
|||
#include "models/qwen3vl.cpp"
|
||||
#include "models/qwen3vl-moe.cpp"
|
||||
#include "models/qwen3moe.cpp"
|
||||
#include "models/qwen3next.cpp"
|
||||
#include "models/refact.cpp"
|
||||
#include "models/rnd1.cpp"
|
||||
#include "models/rwkv6-base.cpp"
|
||||
|
|
@ -2328,6 +2328,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
{
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
// Load linear attention (gated delta net) parameters
|
||||
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
||||
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
||||
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
|
||||
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
||||
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
|
||||
|
||||
// Mark recurrent layers (linear attention layers)
|
||||
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
|
||||
hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval"
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 80: type = LLM_TYPE_80B_A3B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
|
|
@ -6571,6 +6594,74 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
|
||||
// Calculate dimensions from hyperparameters
|
||||
const int64_t head_k_dim = hparams.ssm_d_state;
|
||||
const int64_t head_v_dim = hparams.ssm_d_state;
|
||||
const int64_t n_k_heads = hparams.ssm_n_group;
|
||||
const int64_t n_v_heads = hparams.ssm_dt_rank;
|
||||
const int64_t key_dim = head_k_dim * n_k_heads;
|
||||
const int64_t value_dim = head_v_dim * n_v_heads;
|
||||
const int64_t conv_dim = key_dim * 2 + value_dim;
|
||||
|
||||
// Calculate projection sizes
|
||||
const int64_t qkvz_dim = key_dim * 2 + value_dim * 2;
|
||||
const int64_t ba_dim = n_v_heads * 2;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
|
||||
if (!hparams.is_recurrent(i)) {
|
||||
// Attention layers
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
|
||||
// Q/K normalization for attention layers
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
} else {
|
||||
// Linear attention (gated delta net) specific tensors
|
||||
// Create tensors with calculated dimensions
|
||||
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0);
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
|
||||
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0);
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
|
||||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
|
||||
}
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
|
||||
|
||||
// Shared experts
|
||||
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
|
|
@ -6842,6 +6933,7 @@ void llama_model::print_info() const {
|
|||
arch == LLM_ARCH_FALCON_H1 ||
|
||||
arch == LLM_ARCH_PLAMO2 ||
|
||||
arch == LLM_ARCH_GRANITE_HYBRID ||
|
||||
arch == LLM_ARCH_QWEN3NEXT ||
|
||||
arch == LLM_ARCH_NEMOTRON_H) {
|
||||
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
|
||||
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
|
||||
|
|
@ -7586,7 +7678,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
case LLM_ARCH_PANGU_EMBED:
|
||||
{
|
||||
llm = std::make_unique<llm_build_pangu_embedded>(*this, params);
|
||||
}break;
|
||||
} break;
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
{
|
||||
llm = std::make_unique<llm_build_qwen3next>(*this, params);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
@ -7813,6 +7909,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_COGVLM:
|
||||
case LLM_ARCH_PANGU_EMBED:
|
||||
case LLM_ARCH_AFMOE:
|
||||
case LLM_ARCH_QWEN3NEXT:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ enum llm_type {
|
|||
LLM_TYPE_16B_A1B,
|
||||
LLM_TYPE_21B_A3B, // Ernie MoE small
|
||||
LLM_TYPE_30B_A3B,
|
||||
LLM_TYPE_80B_A3B, // Qwen3 Next
|
||||
LLM_TYPE_100B_A6B,
|
||||
LLM_TYPE_106B_A12B, // GLM-4.5-Air
|
||||
LLM_TYPE_230B_A10B, // Minimax M2
|
||||
|
|
@ -309,6 +310,9 @@ struct llama_layer {
|
|||
struct ggml_tensor * ssm_conv1d_b = nullptr;
|
||||
struct ggml_tensor * ssm_dt_b = nullptr;
|
||||
|
||||
// qwen3next
|
||||
struct ggml_tensor * ssm_beta_alpha = nullptr;
|
||||
|
||||
// rwkv
|
||||
struct ggml_tensor * time_mix_w1 = nullptr;
|
||||
struct ggml_tensor * time_mix_w2 = nullptr;
|
||||
|
|
|
|||
|
|
@ -684,7 +684,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
}
|
||||
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
|
||||
continue;
|
||||
} else if (remapped_name != it.first) {
|
||||
}
|
||||
|
||||
if (remapped_name != it.first) {
|
||||
ggml_set_name(it.second.tensor, remapped_name.c_str());
|
||||
LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
|
||||
}
|
||||
|
|
@ -729,13 +731,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
{
|
||||
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
|
||||
// attention layers have a non-zero number of kv heads
|
||||
int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
|
||||
int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
|
||||
if (llama_model_has_encoder(&model)) {
|
||||
// now n_attn_layer is the number of attention layers in the encoder
|
||||
// now n_layer_attn is the number of attention layers in the encoder
|
||||
// for each decoder block, there are 2 attention layers
|
||||
n_attn_layer += 2 * model.hparams.dec_n_layer;
|
||||
n_layer_attn += 2 * model.hparams.dec_n_layer;
|
||||
}
|
||||
GGML_ASSERT_CONTINUE((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
|
||||
|
||||
// note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
|
||||
const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
|
||||
|
||||
LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w);
|
||||
|
||||
GGML_ASSERT_CONTINUE((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
|
||||
}
|
||||
|
||||
size_t total_size_org = 0;
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@
|
|||
|
||||
#include "../llama-model.h"
|
||||
#include "../llama-graph.h"
|
||||
#include "../llama-memory-recurrent.h"
|
||||
|
||||
// TODO: remove in follow-up PR - move to .cpp files
|
||||
#include "../llama-memory-recurrent.h"
|
||||
#include <cmath>
|
||||
|
||||
struct llm_graph_context_mamba : public llm_graph_context {
|
||||
|
|
@ -421,7 +422,56 @@ struct llm_build_qwen3vl : public llm_graph_context {
|
|||
struct llm_build_qwen3vlmoe : public llm_graph_context {
|
||||
llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
struct llm_build_qwen3next : public llm_graph_context_mamba {
|
||||
llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
|
||||
private:
|
||||
ggml_tensor * build_layer_attn(
|
||||
llm_graph_input_attn_kv * inp_attn,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * inp_pos,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_layer_attn_linear(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * causal_mask,
|
||||
ggml_tensor * identity,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_layer_ffn(
|
||||
ggml_tensor * cur,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_delta_net_recurrent(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * beta,
|
||||
ggml_tensor * state,
|
||||
ggml_tensor * causal_mask,
|
||||
ggml_tensor * identity,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_delta_net_chunking(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * beta,
|
||||
ggml_tensor * state,
|
||||
ggml_tensor * causal_mask,
|
||||
ggml_tensor * identity,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_norm_gated(
|
||||
ggml_tensor * input,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * gate,
|
||||
int layer);
|
||||
|
||||
const llama_model & model;
|
||||
};
|
||||
|
||||
struct llm_build_qwen : public llm_graph_context {
|
||||
llm_build_qwen(const llama_model & model, const llm_graph_params & params);
|
||||
|
|
|
|||
1042
src/models/qwen3next.cpp
Normal file
1042
src/models/qwen3next.cpp
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -725,7 +725,6 @@ std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * vocab, mtm
|
|||
return result;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// OAI utils
|
||||
//
|
||||
|
|
@ -1048,6 +1047,222 @@ json oaicompat_chat_params_parse(
|
|||
return llama_params;
|
||||
}
|
||||
|
||||
json convert_anthropic_to_oai(const json & body) {
|
||||
json oai_body;
|
||||
|
||||
// Convert system prompt
|
||||
json oai_messages = json::array();
|
||||
auto system_param = json_value(body, "system", json());
|
||||
if (!system_param.is_null()) {
|
||||
std::string system_content;
|
||||
|
||||
if (system_param.is_string()) {
|
||||
system_content = system_param.get<std::string>();
|
||||
} else if (system_param.is_array()) {
|
||||
for (const auto & block : system_param) {
|
||||
if (json_value(block, "type", std::string()) == "text") {
|
||||
system_content += json_value(block, "text", std::string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
oai_messages.push_back({
|
||||
{"role", "system"},
|
||||
{"content", system_content}
|
||||
});
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
if (!body.contains("messages")) {
|
||||
throw std::runtime_error("'messages' is required");
|
||||
}
|
||||
const json & messages = body.at("messages");
|
||||
if (messages.is_array()) {
|
||||
for (const auto & msg : messages) {
|
||||
std::string role = json_value(msg, "role", std::string());
|
||||
|
||||
if (!msg.contains("content")) {
|
||||
if (role == "assistant") {
|
||||
continue;
|
||||
}
|
||||
oai_messages.push_back(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
const json & content = msg.at("content");
|
||||
|
||||
if (content.is_string()) {
|
||||
oai_messages.push_back(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!content.is_array()) {
|
||||
oai_messages.push_back(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
json tool_calls = json::array();
|
||||
json converted_content = json::array();
|
||||
json tool_results = json::array();
|
||||
bool has_tool_calls = false;
|
||||
|
||||
for (const auto & block : content) {
|
||||
std::string type = json_value(block, "type", std::string());
|
||||
|
||||
if (type == "text") {
|
||||
converted_content.push_back(block);
|
||||
} else if (type == "image") {
|
||||
json source = json_value(block, "source", json::object());
|
||||
std::string source_type = json_value(source, "type", std::string());
|
||||
|
||||
if (source_type == "base64") {
|
||||
std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
|
||||
std::string data = json_value(source, "data", std::string());
|
||||
std::ostringstream ss;
|
||||
ss << "data:" << media_type << ";base64," << data;
|
||||
|
||||
converted_content.push_back({
|
||||
{"type", "image_url"},
|
||||
{"image_url", {
|
||||
{"url", ss.str()}
|
||||
}}
|
||||
});
|
||||
} else if (source_type == "url") {
|
||||
std::string url = json_value(source, "url", std::string());
|
||||
converted_content.push_back({
|
||||
{"type", "image_url"},
|
||||
{"image_url", {
|
||||
{"url", url}
|
||||
}}
|
||||
});
|
||||
}
|
||||
} else if (type == "tool_use") {
|
||||
tool_calls.push_back({
|
||||
{"id", json_value(block, "id", std::string())},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", json_value(block, "name", std::string())},
|
||||
{"arguments", json_value(block, "input", json::object()).dump()}
|
||||
}}
|
||||
});
|
||||
has_tool_calls = true;
|
||||
} else if (type == "tool_result") {
|
||||
std::string tool_use_id = json_value(block, "tool_use_id", std::string());
|
||||
|
||||
auto result_content = json_value(block, "content", json());
|
||||
std::string result_text;
|
||||
if (result_content.is_string()) {
|
||||
result_text = result_content.get<std::string>();
|
||||
} else if (result_content.is_array()) {
|
||||
for (const auto & c : result_content) {
|
||||
if (json_value(c, "type", std::string()) == "text") {
|
||||
result_text += json_value(c, "text", std::string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tool_results.push_back({
|
||||
{"role", "tool"},
|
||||
{"tool_call_id", tool_use_id},
|
||||
{"content", result_text}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (!converted_content.empty() || has_tool_calls) {
|
||||
json new_msg = {{"role", role}};
|
||||
if (!converted_content.empty()) {
|
||||
new_msg["content"] = converted_content;
|
||||
} else if (has_tool_calls) {
|
||||
new_msg["content"] = "";
|
||||
}
|
||||
if (!tool_calls.empty()) {
|
||||
new_msg["tool_calls"] = tool_calls;
|
||||
}
|
||||
oai_messages.push_back(new_msg);
|
||||
}
|
||||
|
||||
for (const auto & tool_msg : tool_results) {
|
||||
oai_messages.push_back(tool_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
oai_body["messages"] = oai_messages;
|
||||
|
||||
// Convert tools
|
||||
if (body.contains("tools")) {
|
||||
const json & tools = body.at("tools");
|
||||
if (tools.is_array()) {
|
||||
json oai_tools = json::array();
|
||||
for (const auto & tool : tools) {
|
||||
oai_tools.push_back({
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", json_value(tool, "name", std::string())},
|
||||
{"description", json_value(tool, "description", std::string())},
|
||||
{"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()}
|
||||
}}
|
||||
});
|
||||
}
|
||||
oai_body["tools"] = oai_tools;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tool_choice
|
||||
if (body.contains("tool_choice")) {
|
||||
const json & tc = body.at("tool_choice");
|
||||
if (tc.is_object()) {
|
||||
std::string type = json_value(tc, "type", std::string());
|
||||
if (type == "auto") {
|
||||
oai_body["tool_choice"] = "auto";
|
||||
} else if (type == "any" || type == "tool") {
|
||||
oai_body["tool_choice"] = "required";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert stop_sequences to stop
|
||||
if (body.contains("stop_sequences")) {
|
||||
oai_body["stop"] = body.at("stop_sequences");
|
||||
}
|
||||
|
||||
// Handle max_tokens (required in Anthropic, but we're permissive)
|
||||
if (body.contains("max_tokens")) {
|
||||
oai_body["max_tokens"] = body.at("max_tokens");
|
||||
} else {
|
||||
oai_body["max_tokens"] = 4096;
|
||||
}
|
||||
|
||||
// Pass through common params
|
||||
for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) {
|
||||
if (body.contains(key)) {
|
||||
oai_body[key] = body.at(key);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Anthropic-specific thinking param
|
||||
if (body.contains("thinking")) {
|
||||
json thinking = json_value(body, "thinking", json::object());
|
||||
std::string thinking_type = json_value(thinking, "type", std::string());
|
||||
if (thinking_type == "enabled") {
|
||||
int budget_tokens = json_value(thinking, "budget_tokens", 10000);
|
||||
oai_body["thinking_budget_tokens"] = budget_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Anthropic-specific metadata param
|
||||
if (body.contains("metadata")) {
|
||||
json metadata = json_value(body, "metadata", json::object());
|
||||
std::string user_id = json_value(metadata, "user_id", std::string());
|
||||
if (!user_id.empty()) {
|
||||
oai_body["__metadata_user_id"] = user_id;
|
||||
}
|
||||
}
|
||||
|
||||
return oai_body;
|
||||
}
|
||||
|
||||
json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) {
|
||||
json data = json::array();
|
||||
int32_t n_tokens = 0;
|
||||
|
|
@ -1211,7 +1426,7 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l
|
|||
|
||||
// format server-sent event (SSE), return the formatted string to send
|
||||
// note: if data is a json array, it will be sent as multiple events, one per item
|
||||
std::string format_sse(const json & data) {
|
||||
std::string format_oai_sse(const json & data) {
|
||||
std::ostringstream ss;
|
||||
auto send_single = [&ss](const json & data) {
|
||||
ss << "data: " <<
|
||||
|
|
@ -1230,6 +1445,29 @@ std::string format_sse(const json & data) {
|
|||
return ss.str();
|
||||
}
|
||||
|
||||
std::string format_anthropic_sse(const json & data) {
|
||||
std::ostringstream ss;
|
||||
|
||||
auto send_event = [&ss](const json & event_obj) {
|
||||
if (event_obj.contains("event") && event_obj.contains("data")) {
|
||||
ss << "event: " << event_obj.at("event").get<std::string>() << "\n";
|
||||
ss << "data: " << safe_json_to_str(event_obj.at("data")) << "\n\n";
|
||||
} else {
|
||||
ss << "data: " << safe_json_to_str(event_obj) << "\n\n";
|
||||
}
|
||||
};
|
||||
|
||||
if (data.is_array()) {
|
||||
for (const auto & event : data) {
|
||||
send_event(event);
|
||||
}
|
||||
} else {
|
||||
send_event(data);
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool is_valid_utf8(const std::string & str) {
|
||||
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
|
||||
const unsigned char* end = bytes + str.length();
|
||||
|
|
|
|||
|
|
@ -294,6 +294,9 @@ json oaicompat_chat_params_parse(
|
|||
const oaicompat_parser_options & opt,
|
||||
std::vector<raw_buffer> & out_files);
|
||||
|
||||
// convert Anthropic Messages API format to OpenAI Chat Completions API format
|
||||
json convert_anthropic_to_oai(const json & body);
|
||||
|
||||
// TODO: move it to server-task.cpp
|
||||
json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);
|
||||
|
||||
|
|
@ -320,7 +323,10 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l
|
|||
|
||||
// format server-sent event (SSE), return the formatted string to send
|
||||
// note: if data is a json array, it will be sent as multiple events, one per item
|
||||
std::string format_sse(const json & data);
|
||||
std::string format_oai_sse(const json & data);
|
||||
|
||||
// format Anthropic-style SSE with event types
|
||||
std::string format_anthropic_sse(const json & data);
|
||||
|
||||
bool is_valid_utf8(const std::string & str);
|
||||
|
||||
|
|
|
|||
|
|
@ -136,15 +136,22 @@ bool server_http_context::init(const common_params & params) {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Check for API key in the header
|
||||
auto auth_header = req.get_header_value("Authorization");
|
||||
// Check for API key in the Authorization header
|
||||
std::string req_api_key = req.get_header_value("Authorization");
|
||||
if (req_api_key.empty()) {
|
||||
// retry with anthropic header
|
||||
req_api_key = req.get_header_value("X-Api-Key");
|
||||
}
|
||||
|
||||
// remove the "Bearer " prefix if needed
|
||||
std::string prefix = "Bearer ";
|
||||
if (auth_header.substr(0, prefix.size()) == prefix) {
|
||||
std::string received_api_key = auth_header.substr(prefix.size());
|
||||
if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) {
|
||||
return true; // API key is valid
|
||||
}
|
||||
if (req_api_key.substr(0, prefix.size()) == prefix) {
|
||||
req_api_key = req_api_key.substr(prefix.size());
|
||||
}
|
||||
|
||||
// validate the API key
|
||||
if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) {
|
||||
return true; // API key is valid
|
||||
}
|
||||
|
||||
// API key is invalid or not provided
|
||||
|
|
|
|||
|
|
@ -565,15 +565,17 @@ std::vector<unsigned char> completion_token_output::str_to_bytes(const std::stri
|
|||
// server_task_result_cmpl_final
|
||||
//
|
||||
json server_task_result_cmpl_final::to_json() {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
switch (res_type) {
|
||||
case TASK_RESPONSE_TYPE_NONE:
|
||||
return to_json_non_oaicompat();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
case TASK_RESPONSE_TYPE_OAI_CMPL:
|
||||
return to_json_oaicompat();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
case TASK_RESPONSE_TYPE_OAI_CHAT:
|
||||
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
|
||||
case TASK_RESPONSE_TYPE_ANTHROPIC:
|
||||
return stream ? to_json_anthropic_stream() : to_json_anthropic();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
GGML_ASSERT(false && "Invalid task_response_type");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -768,19 +770,203 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
|
|||
return deltas;
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_final::to_json_anthropic() {
|
||||
std::string stop_reason = "max_tokens";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
|
||||
}
|
||||
|
||||
json content_blocks = json::array();
|
||||
|
||||
common_chat_msg msg;
|
||||
if (!oaicompat_msg.empty()) {
|
||||
msg = oaicompat_msg;
|
||||
} else {
|
||||
msg.role = "assistant";
|
||||
msg.content = content;
|
||||
}
|
||||
|
||||
if (!msg.content.empty()) {
|
||||
content_blocks.push_back({
|
||||
{"type", "text"},
|
||||
{"text", msg.content}
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto & tool_call : msg.tool_calls) {
|
||||
json tool_use_block = {
|
||||
{"type", "tool_use"},
|
||||
{"id", tool_call.id},
|
||||
{"name", tool_call.name}
|
||||
};
|
||||
|
||||
try {
|
||||
tool_use_block["input"] = json::parse(tool_call.arguments);
|
||||
} catch (const std::exception &) {
|
||||
tool_use_block["input"] = json::object();
|
||||
}
|
||||
|
||||
content_blocks.push_back(tool_use_block);
|
||||
}
|
||||
|
||||
json res = {
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"type", "message"},
|
||||
{"role", "assistant"},
|
||||
{"content", content_blocks},
|
||||
{"model", oaicompat_model},
|
||||
{"stop_reason", stop_reason},
|
||||
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)},
|
||||
{"usage", {
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"output_tokens", n_decoded}
|
||||
}}
|
||||
};
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_final::to_json_anthropic_stream() {
|
||||
json events = json::array();
|
||||
|
||||
std::string stop_reason = "max_tokens";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
|
||||
}
|
||||
|
||||
bool has_text = !oaicompat_msg.content.empty();
|
||||
size_t num_tool_calls = oaicompat_msg.tool_calls.size();
|
||||
|
||||
bool text_block_started = false;
|
||||
std::unordered_set<size_t> tool_calls_started;
|
||||
|
||||
for (const auto & diff : oaicompat_msg_diffs) {
|
||||
if (!diff.content_delta.empty()) {
|
||||
if (!text_block_started) {
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", 0},
|
||||
{"content_block", {
|
||||
{"type", "text"},
|
||||
{"text", ""}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
text_block_started = true;
|
||||
}
|
||||
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", 0},
|
||||
{"delta", {
|
||||
{"type", "text_delta"},
|
||||
{"text", diff.content_delta}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
if (diff.tool_call_index != std::string::npos) {
|
||||
size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index;
|
||||
|
||||
if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) {
|
||||
const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index];
|
||||
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", content_block_index},
|
||||
{"content_block", {
|
||||
{"type", "tool_use"},
|
||||
{"id", full_tool_call.id},
|
||||
{"name", full_tool_call.name}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
tool_calls_started.insert(diff.tool_call_index);
|
||||
}
|
||||
|
||||
if (!diff.tool_call_delta.arguments.empty()) {
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", content_block_index},
|
||||
{"delta", {
|
||||
{"type", "input_json_delta"},
|
||||
{"partial_json", diff.tool_call_delta.arguments}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (has_text) {
|
||||
events.push_back({
|
||||
{"event", "content_block_stop"},
|
||||
{"data", {
|
||||
{"type", "content_block_stop"},
|
||||
{"index", 0}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < num_tool_calls; i++) {
|
||||
size_t content_block_index = (has_text ? 1 : 0) + i;
|
||||
events.push_back({
|
||||
{"event", "content_block_stop"},
|
||||
{"data", {
|
||||
{"type", "content_block_stop"},
|
||||
{"index", content_block_index}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
events.push_back({
|
||||
{"event", "message_delta"},
|
||||
{"data", {
|
||||
{"type", "message_delta"},
|
||||
{"delta", {
|
||||
{"stop_reason", stop_reason},
|
||||
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}
|
||||
}},
|
||||
{"usage", {
|
||||
{"output_tokens", n_decoded}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
|
||||
events.push_back({
|
||||
{"event", "message_stop"},
|
||||
{"data", {
|
||||
{"type", "message_stop"}
|
||||
}}
|
||||
});
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
//
|
||||
// server_task_result_cmpl_partial
|
||||
//
|
||||
json server_task_result_cmpl_partial::to_json() {
|
||||
switch (oaicompat) {
|
||||
case OAICOMPAT_TYPE_NONE:
|
||||
switch (res_type) {
|
||||
case TASK_RESPONSE_TYPE_NONE:
|
||||
return to_json_non_oaicompat();
|
||||
case OAICOMPAT_TYPE_COMPLETION:
|
||||
case TASK_RESPONSE_TYPE_OAI_CMPL:
|
||||
return to_json_oaicompat();
|
||||
case OAICOMPAT_TYPE_CHAT:
|
||||
case TASK_RESPONSE_TYPE_OAI_CHAT:
|
||||
return to_json_oaicompat_chat();
|
||||
case TASK_RESPONSE_TYPE_ANTHROPIC:
|
||||
return to_json_anthropic();
|
||||
default:
|
||||
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||
GGML_ASSERT(false && "Invalid task_response_type");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -905,7 +1091,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
|
|||
// server_task_result_embd
|
||||
//
|
||||
json server_task_result_embd::to_json() {
|
||||
return oaicompat == OAICOMPAT_TYPE_EMBEDDING
|
||||
return res_type == TASK_RESPONSE_TYPE_OAI_EMBD
|
||||
? to_json_oaicompat()
|
||||
: to_json_non_oaicompat();
|
||||
}
|
||||
|
|
@ -936,6 +1122,102 @@ json server_task_result_rerank::to_json() {
|
|||
};
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_partial::to_json_anthropic() {
|
||||
json events = json::array();
|
||||
bool first = (n_decoded == 1);
|
||||
static bool text_block_started = false;
|
||||
|
||||
if (first) {
|
||||
text_block_started = false;
|
||||
|
||||
events.push_back({
|
||||
{"event", "message_start"},
|
||||
{"data", {
|
||||
{"type", "message_start"},
|
||||
{"message", {
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"type", "message"},
|
||||
{"role", "assistant"},
|
||||
{"content", json::array()},
|
||||
{"model", oaicompat_model},
|
||||
{"stop_reason", nullptr},
|
||||
{"stop_sequence", nullptr},
|
||||
{"usage", {
|
||||
{"input_tokens", n_prompt_tokens},
|
||||
{"output_tokens", 0}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto & diff : oaicompat_msg_diffs) {
|
||||
if (!diff.content_delta.empty()) {
|
||||
if (!text_block_started) {
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", 0},
|
||||
{"content_block", {
|
||||
{"type", "text"},
|
||||
{"text", ""}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
text_block_started = true;
|
||||
}
|
||||
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", 0},
|
||||
{"delta", {
|
||||
{"type", "text_delta"},
|
||||
{"text", diff.content_delta}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
if (diff.tool_call_index != std::string::npos) {
|
||||
size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index;
|
||||
|
||||
if (!diff.tool_call_delta.name.empty()) {
|
||||
events.push_back({
|
||||
{"event", "content_block_start"},
|
||||
{"data", {
|
||||
{"type", "content_block_start"},
|
||||
{"index", content_block_index},
|
||||
{"content_block", {
|
||||
{"type", "tool_use"},
|
||||
{"id", diff.tool_call_delta.id},
|
||||
{"name", diff.tool_call_delta.name}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
|
||||
if (!diff.tool_call_delta.arguments.empty()) {
|
||||
events.push_back({
|
||||
{"event", "content_block_delta"},
|
||||
{"data", {
|
||||
{"type", "content_block_delta"},
|
||||
{"index", content_block_index},
|
||||
{"delta", {
|
||||
{"type", "input_json_delta"},
|
||||
{"partial_json", diff.tool_call_delta.arguments}
|
||||
}}
|
||||
}}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
//
|
||||
// server_task_result_error
|
||||
//
|
||||
|
|
|
|||
|
|
@ -27,11 +27,12 @@ enum server_task_type {
|
|||
};
|
||||
|
||||
// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
|
||||
enum oaicompat_type {
|
||||
OAICOMPAT_TYPE_NONE,
|
||||
OAICOMPAT_TYPE_CHAT,
|
||||
OAICOMPAT_TYPE_COMPLETION,
|
||||
OAICOMPAT_TYPE_EMBEDDING,
|
||||
enum task_response_type {
|
||||
TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
|
||||
TASK_RESPONSE_TYPE_OAI_CHAT,
|
||||
TASK_RESPONSE_TYPE_OAI_CMPL,
|
||||
TASK_RESPONSE_TYPE_OAI_EMBD,
|
||||
TASK_RESPONSE_TYPE_ANTHROPIC,
|
||||
};
|
||||
|
||||
enum stop_type {
|
||||
|
|
@ -66,9 +67,9 @@ struct task_params {
|
|||
struct common_params_sampling sampling;
|
||||
struct common_params_speculative speculative;
|
||||
|
||||
// OAI-compat fields
|
||||
// response formatting
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
|
|
@ -227,12 +228,12 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
|
||||
task_params generation_params;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_msg oaicompat_msg;
|
||||
// response formatting
|
||||
bool verbose = false;
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_msg oaicompat_msg;
|
||||
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
|
||||
|
|
@ -253,6 +254,10 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
json to_json_oaicompat_chat();
|
||||
|
||||
json to_json_oaicompat_chat_stream();
|
||||
|
||||
json to_json_anthropic();
|
||||
|
||||
json to_json_anthropic_stream();
|
||||
};
|
||||
|
||||
struct server_task_result_cmpl_partial : server_task_result {
|
||||
|
|
@ -270,11 +275,11 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
result_timings timings;
|
||||
result_prompt_progress progress;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
// response formatting
|
||||
bool verbose = false;
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
|
||||
virtual int get_index() override {
|
||||
|
|
@ -292,6 +297,8 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
json to_json_oaicompat();
|
||||
|
||||
json to_json_oaicompat_chat();
|
||||
|
||||
json to_json_anthropic();
|
||||
};
|
||||
|
||||
struct server_task_result_embd : server_task_result {
|
||||
|
|
@ -300,8 +307,8 @@ struct server_task_result_embd : server_task_result {
|
|||
|
||||
int32_t n_tokens;
|
||||
|
||||
// OAI-compat fields
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
// response formatting
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
|
|
|
|||
|
|
@ -1255,7 +1255,7 @@ struct server_context {
|
|||
res->post_sampling_probs = slot.task->params.post_sampling_probs;
|
||||
|
||||
res->verbose = slot.task->params.verbose;
|
||||
res->oaicompat = slot.task->params.oaicompat;
|
||||
res->res_type = slot.task->params.res_type;
|
||||
res->oaicompat_model = slot.task->params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
|
||||
|
||||
|
|
@ -1297,7 +1297,7 @@ struct server_context {
|
|||
res->verbose = slot.task->params.verbose;
|
||||
res->stream = slot.task->params.stream;
|
||||
res->include_usage = slot.task->params.include_usage;
|
||||
res->oaicompat = slot.task->params.oaicompat;
|
||||
res->res_type = slot.task->params.res_type;
|
||||
res->oaicompat_model = slot.task->params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
|
||||
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
|
||||
|
|
@ -1328,7 +1328,7 @@ struct server_context {
|
|||
res->id = slot.task->id;
|
||||
res->index = slot.task->index;
|
||||
res->n_tokens = slot.task->n_tokens();
|
||||
res->oaicompat = slot.task->params.oaicompat;
|
||||
res->res_type = slot.task->params.res_type;
|
||||
|
||||
const int n_embd = llama_model_n_embd(model);
|
||||
|
||||
|
|
@ -2951,7 +2951,7 @@ public:
|
|||
data,
|
||||
files,
|
||||
req.should_stop,
|
||||
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
|
||||
TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible
|
||||
};
|
||||
|
||||
server_http_context::handler_t post_completions = [this](const server_http_req & req) {
|
||||
|
|
@ -2962,7 +2962,7 @@ public:
|
|||
body,
|
||||
files,
|
||||
req.should_stop,
|
||||
OAICOMPAT_TYPE_NONE);
|
||||
TASK_RESPONSE_TYPE_NONE);
|
||||
};
|
||||
|
||||
server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) {
|
||||
|
|
@ -2973,7 +2973,7 @@ public:
|
|||
body,
|
||||
files,
|
||||
req.should_stop,
|
||||
OAICOMPAT_TYPE_COMPLETION);
|
||||
TASK_RESPONSE_TYPE_OAI_CMPL);
|
||||
};
|
||||
|
||||
server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) {
|
||||
|
|
@ -2988,7 +2988,38 @@ public:
|
|||
body_parsed,
|
||||
files,
|
||||
req.should_stop,
|
||||
OAICOMPAT_TYPE_CHAT);
|
||||
TASK_RESPONSE_TYPE_OAI_CHAT);
|
||||
};
|
||||
|
||||
server_http_context::handler_t post_anthropic_messages = [this](const server_http_req & req) {
|
||||
std::vector<raw_buffer> files;
|
||||
json body = convert_anthropic_to_oai(json::parse(req.body));
|
||||
json body_parsed = oaicompat_chat_params_parse(
|
||||
body,
|
||||
ctx_server.oai_parser_opt,
|
||||
files);
|
||||
return handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
body_parsed,
|
||||
files,
|
||||
req.should_stop,
|
||||
TASK_RESPONSE_TYPE_ANTHROPIC);
|
||||
};
|
||||
|
||||
server_http_context::handler_t post_anthropic_count_tokens = [this](const server_http_req & req) {
|
||||
auto res = std::make_unique<server_res_generator>(ctx_server);
|
||||
std::vector<raw_buffer> files;
|
||||
json body = convert_anthropic_to_oai(json::parse(req.body));
|
||||
json body_parsed = oaicompat_chat_params_parse(
|
||||
body,
|
||||
ctx_server.oai_parser_opt,
|
||||
files);
|
||||
|
||||
json prompt = body_parsed.at("prompt");
|
||||
llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true);
|
||||
|
||||
res->ok({{"input_tokens", static_cast<int>(tokens.size())}});
|
||||
return res;
|
||||
};
|
||||
|
||||
// same with handle_chat_completions, but without inference part
|
||||
|
|
@ -3107,11 +3138,11 @@ public:
|
|||
};
|
||||
|
||||
server_http_context::handler_t post_embeddings = [this](const server_http_req & req) {
|
||||
return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE);
|
||||
return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE);
|
||||
};
|
||||
|
||||
server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) {
|
||||
return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING);
|
||||
return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD);
|
||||
};
|
||||
|
||||
server_http_context::handler_t post_rerank = [this](const server_http_req & req) {
|
||||
|
|
@ -3262,7 +3293,7 @@ private:
|
|||
const json & data,
|
||||
const std::vector<raw_buffer> & files,
|
||||
const std::function<bool()> & should_stop,
|
||||
oaicompat_type oaicompat) {
|
||||
task_response_type res_type) {
|
||||
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
|
||||
|
||||
auto res = std::make_unique<server_res_generator>(ctx_server);
|
||||
|
|
@ -3279,7 +3310,7 @@ private:
|
|||
// process prompt
|
||||
std::vector<server_tokens> inputs;
|
||||
|
||||
if (oaicompat && ctx_server.mctx != nullptr) {
|
||||
if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) {
|
||||
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
|
||||
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
|
||||
} else {
|
||||
|
|
@ -3301,8 +3332,8 @@ private:
|
|||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
task.params.res_type = res_type;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
||||
tasks.push_back(std::move(task));
|
||||
|
|
@ -3352,10 +3383,14 @@ private:
|
|||
}
|
||||
|
||||
// next responses are streamed
|
||||
res->data = format_sse(first_result->to_json()); // to be sent immediately
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
res->data = format_anthropic_sse(first_result->to_json());
|
||||
} else {
|
||||
res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
|
||||
}
|
||||
res->status = 200;
|
||||
res->content_type = "text/event-stream";
|
||||
res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool {
|
||||
res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
|
||||
if (should_stop()) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
return false; // should_stop condition met
|
||||
|
|
@ -3372,7 +3407,10 @@ private:
|
|||
|
||||
// check if there is more data
|
||||
if (!rd.has_next()) {
|
||||
if (oaicompat != OAICOMPAT_TYPE_NONE) {
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
// Anthropic doesn't send [DONE], message_stop was already sent
|
||||
output = "";
|
||||
} else if (res_type != TASK_RESPONSE_TYPE_NONE) {
|
||||
output = "data: [DONE]\n\n";
|
||||
} else {
|
||||
output = "";
|
||||
|
|
@ -3391,7 +3429,14 @@ private:
|
|||
// send the results
|
||||
json res_json = result->to_json();
|
||||
if (result->is_error()) {
|
||||
output = format_sse(json {{ "error", res_json }});
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
output = format_anthropic_sse({
|
||||
{"event", "error"},
|
||||
{"data", res_json},
|
||||
});
|
||||
} else {
|
||||
output = format_oai_sse(json {{ "error", res_json }});
|
||||
}
|
||||
SRV_DBG("%s", "error received during streaming, terminating stream\n");
|
||||
return false; // terminate on error
|
||||
} else {
|
||||
|
|
@ -3399,7 +3444,11 @@ private:
|
|||
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|
||||
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
|
||||
);
|
||||
output = format_sse(res_json);
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
output = format_anthropic_sse(res_json);
|
||||
} else {
|
||||
output = format_oai_sse(res_json);
|
||||
}
|
||||
}
|
||||
|
||||
// has next data, continue
|
||||
|
|
@ -3507,14 +3556,14 @@ private:
|
|||
return res;
|
||||
}
|
||||
|
||||
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) {
|
||||
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type) {
|
||||
auto res = std::make_unique<server_res_generator>(ctx_server);
|
||||
if (!ctx_server.params_base.embedding) {
|
||||
res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return res;
|
||||
}
|
||||
|
||||
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
|
|
@ -3526,7 +3575,7 @@ private:
|
|||
if (body.count("input") != 0) {
|
||||
prompt = body.at("input");
|
||||
} else if (body.contains("content")) {
|
||||
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
|
||||
res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible
|
||||
prompt = body.at("content");
|
||||
} else {
|
||||
res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
|
|
@ -3574,7 +3623,7 @@ private:
|
|||
task.tokens = std::move(tokenized_prompts[i]);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.res_type = res_type;
|
||||
task.params.embd_normalize = embd_normalize;
|
||||
|
||||
tasks.push_back(std::move(task));
|
||||
|
|
@ -3599,7 +3648,7 @@ private:
|
|||
}
|
||||
|
||||
// write JSON response
|
||||
json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
|
||||
json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
|
||||
? format_embeddings_response_oaicompat(body, responses, use_base64)
|
||||
: json(responses);
|
||||
res->ok(root);
|
||||
|
|
@ -3712,6 +3761,8 @@ int main(int argc, char ** argv) {
|
|||
ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
|
||||
ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
|
||||
ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
|
||||
ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
|
||||
ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
|
||||
ctx_http.post("/infill", ex_wrapper(routes.post_infill));
|
||||
ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
|
||||
ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));
|
||||
|
|
|
|||
|
|
@ -13,3 +13,9 @@ def stop_server_after_each_test():
|
|||
) # copy the set to prevent 'Set changed size during iteration'
|
||||
for server in instances:
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def do_something():
|
||||
# this will be run once per test session, before any tests
|
||||
ServerPreset.load_all()
|
||||
|
|
|
|||
|
|
@ -5,12 +5,6 @@ from utils import *
|
|||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def do_something():
|
||||
# this will be run once per test session, before any tests
|
||||
ServerPreset.load_all()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
|
|
|
|||
807
tools/server/tests/unit/test_compat_anthropic.py
Normal file
807
tools/server/tests/unit/test_compat_anthropic.py
Normal file
|
|
@ -0,0 +1,807 @@
|
|||
#!/usr/bin/env python3
|
||||
import pytest
|
||||
import base64
|
||||
import requests
|
||||
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
|
||||
def get_test_image_base64() -> str:
|
||||
"""Get a test image in base64 format"""
|
||||
# Use the same test image as test_vision_api.py
|
||||
IMG_URL = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
|
||||
response = requests.get(IMG_URL)
|
||||
response.raise_for_status()
|
||||
return base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.model_alias = "tinyllama-2-anthropic"
|
||||
server.server_port = 8082
|
||||
server.n_slots = 1
|
||||
server.n_ctx = 8192
|
||||
server.n_batch = 2048
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vision_server():
|
||||
"""Separate fixture for vision tests that require multimodal support"""
|
||||
global server
|
||||
server = ServerPreset.tinygemma3()
|
||||
server.offline = False # Allow downloading the model
|
||||
server.model_alias = "tinygemma3-anthropic"
|
||||
server.server_port = 8083 # Different port to avoid conflicts
|
||||
server.n_slots = 1
|
||||
return server
|
||||
|
||||
|
||||
# Basic message tests
|
||||
|
||||
def test_anthropic_messages_basic():
|
||||
"""Test basic Anthropic messages endpoint"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Say hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200, f"Expected 200, got {res.status_code}"
|
||||
assert res.body["type"] == "message", f"Expected type 'message', got {res.body.get('type')}"
|
||||
assert res.body["role"] == "assistant", f"Expected role 'assistant', got {res.body.get('role')}"
|
||||
assert "content" in res.body, "Missing 'content' field"
|
||||
assert isinstance(res.body["content"], list), "Content should be an array"
|
||||
assert len(res.body["content"]) > 0, "Content array should not be empty"
|
||||
assert res.body["content"][0]["type"] == "text", "First content block should be text"
|
||||
assert "text" in res.body["content"][0], "Text content block missing 'text' field"
|
||||
assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}"
|
||||
assert "usage" in res.body, "Missing 'usage' field"
|
||||
assert "input_tokens" in res.body["usage"], "Missing usage.input_tokens"
|
||||
assert "output_tokens" in res.body["usage"], "Missing usage.output_tokens"
|
||||
assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer"
|
||||
assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer"
|
||||
assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens"
|
||||
# Anthropic API should NOT include timings
|
||||
assert "timings" not in res.body, "Anthropic API should not include timings field"
|
||||
|
||||
|
||||
def test_anthropic_messages_with_system():
|
||||
"""Test messages with system prompt"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"system": "You are a helpful assistant.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
assert len(res.body["content"]) > 0
|
||||
|
||||
|
||||
def test_anthropic_messages_multipart_content():
|
||||
"""Test messages with multipart content blocks"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is"},
|
||||
{"type": "text", "text": " the answer?"}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_messages_conversation():
|
||||
"""Test multi-turn conversation"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
# Streaming tests
|
||||
|
||||
def test_anthropic_messages_streaming():
|
||||
"""Test streaming messages"""
|
||||
server.start()
|
||||
|
||||
res = server.make_stream_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 30,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Say hello"}
|
||||
],
|
||||
"stream": True
|
||||
})
|
||||
|
||||
events = []
|
||||
for data in res:
|
||||
# Each event should have type and other fields
|
||||
assert "type" in data, f"Missing 'type' in event: {data}"
|
||||
events.append(data)
|
||||
|
||||
# Verify event sequence
|
||||
event_types = [e["type"] for e in events]
|
||||
assert "message_start" in event_types, "Missing message_start event"
|
||||
assert "content_block_start" in event_types, "Missing content_block_start event"
|
||||
assert "content_block_delta" in event_types, "Missing content_block_delta event"
|
||||
assert "content_block_stop" in event_types, "Missing content_block_stop event"
|
||||
assert "message_delta" in event_types, "Missing message_delta event"
|
||||
assert "message_stop" in event_types, "Missing message_stop event"
|
||||
|
||||
# Check message_start structure
|
||||
message_start = next(e for e in events if e["type"] == "message_start")
|
||||
assert "message" in message_start, "message_start missing 'message' field"
|
||||
assert message_start["message"]["type"] == "message"
|
||||
assert message_start["message"]["role"] == "assistant"
|
||||
assert message_start["message"]["content"] == []
|
||||
assert "usage" in message_start["message"]
|
||||
assert message_start["message"]["usage"]["input_tokens"] > 0
|
||||
|
||||
# Check content_block_start
|
||||
block_start = next(e for e in events if e["type"] == "content_block_start")
|
||||
assert "index" in block_start, "content_block_start missing 'index'"
|
||||
assert block_start["index"] == 0, "First content block should be at index 0"
|
||||
assert "content_block" in block_start
|
||||
assert block_start["content_block"]["type"] == "text"
|
||||
|
||||
# Check content_block_delta
|
||||
deltas = [e for e in events if e["type"] == "content_block_delta"]
|
||||
assert len(deltas) > 0, "Should have at least one content_block_delta"
|
||||
for delta in deltas:
|
||||
assert "index" in delta
|
||||
assert "delta" in delta
|
||||
assert delta["delta"]["type"] == "text_delta"
|
||||
assert "text" in delta["delta"]
|
||||
|
||||
# Check content_block_stop
|
||||
block_stop = next(e for e in events if e["type"] == "content_block_stop")
|
||||
assert "index" in block_stop
|
||||
assert block_stop["index"] == 0
|
||||
|
||||
# Check message_delta
|
||||
message_delta = next(e for e in events if e["type"] == "message_delta")
|
||||
assert "delta" in message_delta
|
||||
assert "stop_reason" in message_delta["delta"]
|
||||
assert message_delta["delta"]["stop_reason"] in ["end_turn", "max_tokens"]
|
||||
assert "usage" in message_delta
|
||||
assert message_delta["usage"]["output_tokens"] > 0
|
||||
|
||||
# Check message_stop
|
||||
message_stop = next(e for e in events if e["type"] == "message_stop")
|
||||
# message_stop should NOT have timings for Anthropic API
|
||||
assert "timings" not in message_stop, "Anthropic streaming should not include timings"
|
||||
|
||||
|
||||
# Token counting tests
|
||||
|
||||
def test_anthropic_count_tokens():
|
||||
"""Test token counting endpoint"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages/count_tokens", data={
|
||||
"model": "test",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello world"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert "input_tokens" in res.body
|
||||
assert isinstance(res.body["input_tokens"], int)
|
||||
assert res.body["input_tokens"] > 0
|
||||
# Should only have input_tokens, no other fields
|
||||
assert "output_tokens" not in res.body
|
||||
|
||||
|
||||
def test_anthropic_count_tokens_with_system():
|
||||
"""Test token counting with system prompt"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages/count_tokens", data={
|
||||
"model": "test",
|
||||
"system": "You are a helpful assistant.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["input_tokens"] > 0
|
||||
|
||||
|
||||
def test_anthropic_count_tokens_no_max_tokens():
|
||||
"""Test that count_tokens doesn't require max_tokens"""
|
||||
server.start()
|
||||
|
||||
# max_tokens is NOT required for count_tokens
|
||||
res = server.make_request("POST", "/v1/messages/count_tokens", data={
|
||||
"model": "test",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert "input_tokens" in res.body
|
||||
|
||||
|
||||
# Tool use tests
|
||||
|
||||
def test_anthropic_tool_use_basic():
|
||||
"""Test basic tool use"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 200,
|
||||
"tools": [{
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a location",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "City name"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}],
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather in Paris?"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
assert len(res.body["content"]) > 0
|
||||
|
||||
# Check if model used the tool (it might not always, depending on the model)
|
||||
content_types = [block.get("type") for block in res.body["content"]]
|
||||
|
||||
if "tool_use" in content_types:
|
||||
# Model used the tool
|
||||
assert res.body["stop_reason"] == "tool_use"
|
||||
|
||||
# Find the tool_use block
|
||||
tool_block = next(b for b in res.body["content"] if b.get("type") == "tool_use")
|
||||
assert "id" in tool_block
|
||||
assert "name" in tool_block
|
||||
assert tool_block["name"] == "get_weather"
|
||||
assert "input" in tool_block
|
||||
assert isinstance(tool_block["input"], dict)
|
||||
|
||||
|
||||
def test_anthropic_tool_result():
|
||||
"""Test sending tool results back
|
||||
|
||||
This test verifies that tool_result blocks are properly converted to
|
||||
role="tool" messages internally. Without proper conversion, this would
|
||||
fail with a 500 error: "unsupported content[].type" because tool_result
|
||||
blocks would remain in the user message content array.
|
||||
"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 100,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "test123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "Paris"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "test123",
|
||||
"content": "The weather is sunny, 25°C"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# This would be 500 with the old bug where tool_result blocks weren't converted
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
# Model should respond to the tool result
|
||||
assert len(res.body["content"]) > 0
|
||||
assert res.body["content"][0]["type"] == "text"
|
||||
|
||||
|
||||
def test_anthropic_tool_result_with_text():
|
||||
"""Test tool result mixed with text content
|
||||
|
||||
This tests the edge case where a user message contains both text and
|
||||
tool_result blocks. The server must properly split these into separate
|
||||
messages: a user message with text, followed by tool messages.
|
||||
Without proper handling, this would fail with 500: "unsupported content[].type"
|
||||
"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 100,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tool_1",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "Paris"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Here are the results:"},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "tool_1",
|
||||
"content": "Sunny, 25°C"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
assert len(res.body["content"]) > 0
|
||||
|
||||
|
||||
def test_anthropic_tool_result_error():
|
||||
"""Test tool result with error flag"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 100,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Get the weather"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "test123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "InvalidCity"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "test123",
|
||||
"is_error": True,
|
||||
"content": "City not found"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_tool_streaming():
|
||||
"""Test streaming with tool use"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_stream_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 200,
|
||||
"stream": True,
|
||||
"tools": [{
|
||||
"name": "calculator",
|
||||
"description": "Calculate math",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string"}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
}],
|
||||
"messages": [
|
||||
{"role": "user", "content": "Calculate 2+2"}
|
||||
]
|
||||
})
|
||||
|
||||
events = []
|
||||
for data in res:
|
||||
events.append(data)
|
||||
|
||||
event_types = [e["type"] for e in events]
|
||||
|
||||
# Should have basic events
|
||||
assert "message_start" in event_types
|
||||
assert "message_stop" in event_types
|
||||
|
||||
# If tool was used, check for proper tool streaming
|
||||
if any(e.get("type") == "content_block_start" and
|
||||
e.get("content_block", {}).get("type") == "tool_use"
|
||||
for e in events):
|
||||
# Find tool use block start
|
||||
tool_starts = [e for e in events if
|
||||
e.get("type") == "content_block_start" and
|
||||
e.get("content_block", {}).get("type") == "tool_use"]
|
||||
|
||||
assert len(tool_starts) > 0, "Should have tool_use content_block_start"
|
||||
|
||||
# Check index is correct (should be 0 if no text, 1 if there's text)
|
||||
tool_start = tool_starts[0]
|
||||
assert "index" in tool_start
|
||||
assert tool_start["content_block"]["type"] == "tool_use"
|
||||
assert "name" in tool_start["content_block"]
|
||||
|
||||
|
||||
# Vision/multimodal tests
|
||||
|
||||
def test_anthropic_vision_format_accepted():
|
||||
"""Test that Anthropic vision format is accepted (format validation only)"""
|
||||
server.start()
|
||||
|
||||
# Small 1x1 red PNG image in base64
|
||||
red_pixel_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 10,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": red_pixel_png
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is this?"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# Server accepts the format but tinyllama doesn't support images
|
||||
# So it should return 500 with clear error message about missing mmproj
|
||||
assert res.status_code == 500
|
||||
assert "image input is not supported" in res.body.get("error", {}).get("message", "").lower()
|
||||
|
||||
|
||||
def test_anthropic_vision_base64_with_multimodal_model(vision_server):
|
||||
"""Test vision with base64 image using Anthropic format with multimodal model"""
|
||||
global server
|
||||
server = vision_server
|
||||
server.start()
|
||||
|
||||
# Get test image in base64 format
|
||||
image_base64 = get_test_image_base64()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 10,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": image_base64
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is this:\n"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200, f"Expected 200, got {res.status_code}: {res.body}"
|
||||
assert res.body["type"] == "message"
|
||||
assert len(res.body["content"]) > 0
|
||||
assert res.body["content"][0]["type"] == "text"
|
||||
# The model should generate some response about the image
|
||||
assert len(res.body["content"][0]["text"]) > 0
|
||||
|
||||
|
||||
# Parameter tests
|
||||
|
||||
def test_anthropic_stop_sequences():
|
||||
"""Test stop_sequences parameter"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 100,
|
||||
"stop_sequences": ["\n", "END"],
|
||||
"messages": [
|
||||
{"role": "user", "content": "Count to 10"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_temperature():
|
||||
"""Test temperature parameter"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.5,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_top_p():
|
||||
"""Test top_p parameter"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"top_p": 0.9,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_top_k():
|
||||
"""Test top_k parameter (llama.cpp specific)"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"top_k": 40,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
# Error handling tests
|
||||
|
||||
def test_anthropic_missing_messages():
|
||||
"""Test error when messages are missing"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50
|
||||
# missing "messages" field
|
||||
})
|
||||
|
||||
# Should return an error (400 or 500)
|
||||
assert res.status_code >= 400
|
||||
|
||||
|
||||
def test_anthropic_empty_messages():
|
||||
"""Test permissive handling of empty messages array"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": []
|
||||
})
|
||||
|
||||
# Server is permissive and accepts empty messages (provides defaults)
|
||||
# This matches the permissive validation design choice
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
# Content block index tests
|
||||
|
||||
def test_anthropic_streaming_content_block_indices():
|
||||
"""Test that content block indices are correct in streaming"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
# Request that might produce both text and tool use
|
||||
res = server.make_stream_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 200,
|
||||
"stream": True,
|
||||
"tools": [{
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param": {"type": "string"}
|
||||
},
|
||||
"required": ["param"]
|
||||
}
|
||||
}],
|
||||
"messages": [
|
||||
{"role": "user", "content": "Use the test tool"}
|
||||
]
|
||||
})
|
||||
|
||||
events = []
|
||||
for data in res:
|
||||
events.append(data)
|
||||
|
||||
# Check content_block_start events have sequential indices
|
||||
block_starts = [e for e in events if e.get("type") == "content_block_start"]
|
||||
if len(block_starts) > 1:
|
||||
# If there are multiple blocks, indices should be sequential
|
||||
indices = [e["index"] for e in block_starts]
|
||||
expected_indices = list(range(len(block_starts)))
|
||||
assert indices == expected_indices, f"Expected indices {expected_indices}, got {indices}"
|
||||
|
||||
# Check content_block_stop events match the starts
|
||||
block_stops = [e for e in events if e.get("type") == "content_block_stop"]
|
||||
start_indices = set(e["index"] for e in block_starts)
|
||||
stop_indices = set(e["index"] for e in block_stops)
|
||||
assert start_indices == stop_indices, "content_block_stop indices should match content_block_start indices"
|
||||
|
||||
|
||||
# Extended features tests
|
||||
|
||||
def test_anthropic_thinking():
|
||||
"""Test extended thinking parameter"""
|
||||
server.jinja = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 100,
|
||||
"thinking": {
|
||||
"type": "enabled",
|
||||
"budget_tokens": 50
|
||||
},
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is 2+2?"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
def test_anthropic_metadata():
|
||||
"""Test metadata parameter"""
|
||||
server.start()
|
||||
|
||||
res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"metadata": {
|
||||
"user_id": "test_user_123"
|
||||
},
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["type"] == "message"
|
||||
|
||||
|
||||
# Compatibility tests
|
||||
|
||||
def test_anthropic_vs_openai_different_response_format():
|
||||
"""Verify Anthropic format is different from OpenAI format"""
|
||||
server.start()
|
||||
|
||||
# Make OpenAI request
|
||||
openai_res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
# Make Anthropic request
|
||||
anthropic_res = server.make_request("POST", "/v1/messages", data={
|
||||
"model": "test",
|
||||
"max_tokens": 50,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
})
|
||||
|
||||
assert openai_res.status_code == 200
|
||||
assert anthropic_res.status_code == 200
|
||||
|
||||
# OpenAI has "object", Anthropic has "type"
|
||||
assert "object" in openai_res.body
|
||||
assert "type" in anthropic_res.body
|
||||
assert openai_res.body["object"] == "chat.completion"
|
||||
assert anthropic_res.body["type"] == "message"
|
||||
|
||||
# OpenAI has "choices", Anthropic has "content"
|
||||
assert "choices" in openai_res.body
|
||||
assert "content" in anthropic_res.body
|
||||
|
||||
# Different usage field names
|
||||
assert "prompt_tokens" in openai_res.body["usage"]
|
||||
assert "input_tokens" in anthropic_res.body["usage"]
|
||||
assert "completion_tokens" in openai_res.body["usage"]
|
||||
assert "output_tokens" in anthropic_res.body["usage"]
|
||||
|
|
@ -49,6 +49,19 @@ def test_correct_api_key():
|
|||
assert "content" in res.body
|
||||
|
||||
|
||||
def test_correct_api_key_anthropic_header():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completions", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
}, headers={
|
||||
"X-Api-Key": TEST_API_KEY,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "error" not in res.body
|
||||
assert "content" in res.body
|
||||
|
||||
|
||||
def test_openai_library_correct_api_key():
|
||||
global server
|
||||
server.start()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue