mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-21 02:19:27 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # README.md # docs/backend/zDNN.md # docs/build.md # docs/ops.md # ggml/CMakeLists.txt # ggml/src/CMakeLists.txt # ggml/src/ggml-cann/ggml-cann.cpp # ggml/src/ggml-opencl/ggml-opencl.cpp # ggml/src/ggml-rpc/ggml-rpc.cpp # ggml/src/ggml-sycl/convert.cpp # ggml/src/ggml-sycl/ggml-sycl.cpp # src/llama-quant.cpp # tests/test-backend-ops.cpp # tools/llama-bench/llama-bench.cpp # tools/server/README.md
This commit is contained in:
commit
17c0c8d55d
40 changed files with 1161 additions and 326 deletions
|
|
@ -710,6 +710,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.use_jinja = true;
|
||||
}
|
||||
|
||||
params.use_color = tty_can_use_colors();
|
||||
|
||||
// load dynamic backends
|
||||
ggml_backend_load_all();
|
||||
|
||||
|
|
@ -792,10 +794,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN}));
|
||||
add_opt(common_arg(
|
||||
{"-co", "--color"},
|
||||
string_format("colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false"),
|
||||
[](common_params & params) {
|
||||
params.use_color = true;
|
||||
{"-co", "--color"}, "[on|off|auto]",
|
||||
"Colorize output to distinguish prompt and user input from generations ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
"'auto' enables colors when output is to a terminal",
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (is_truthy(value)) {
|
||||
params.use_color = true;
|
||||
} else if (is_falsey(value)) {
|
||||
params.use_color = false;
|
||||
} else if (is_autoy(value)) {
|
||||
params.use_color = tty_can_use_colors();
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unknown value for --color: '%s'\n", value.c_str()));
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
|
||||
add_opt(common_arg(
|
||||
|
|
@ -1024,7 +1036,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
|
||||
string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
|
||||
}
|
||||
}).set_env("LLAMA_ARG_FLASH_ATTN"));
|
||||
add_opt(common_arg(
|
||||
|
|
@ -2698,7 +2710,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
|
||||
string_format("error: unknown value for --log-colors: '%s'\n", value.c_str()));
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_LOG_COLORS"));
|
||||
|
|
|
|||
|
|
@ -990,6 +990,32 @@ std::vector<common_file_info> fs_list(const std::string & path, bool include_dir
|
|||
return files;
|
||||
}
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
|
||||
bool tty_can_use_colors() {
|
||||
// Check NO_COLOR environment variable (https://no-color.org/)
|
||||
if (const char * no_color = std::getenv("NO_COLOR")) {
|
||||
if (no_color[0] != '\0') {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check TERM environment variable
|
||||
if (const char * term = std::getenv("TERM")) {
|
||||
if (std::strcmp(term, "dumb") == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if stdout and stderr are connected to a terminal
|
||||
// We check both because log messages can go to either
|
||||
bool stdout_is_tty = isatty(fileno(stdout));
|
||||
bool stderr_is_tty = isatty(fileno(stderr));
|
||||
|
||||
return stdout_is_tty || stderr_is_tty;
|
||||
}
|
||||
|
||||
//
|
||||
// Model utils
|
||||
|
|
|
|||
|
|
@ -651,6 +651,13 @@ struct common_file_info {
|
|||
};
|
||||
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
|
||||
// Auto-detect if colors can be enabled based on terminal and environment
|
||||
bool tty_can_use_colors();
|
||||
|
||||
//
|
||||
// Model utils
|
||||
//
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <chrono>
|
||||
|
|
@ -26,30 +27,6 @@ void common_log_set_verbosity_thold(int verbosity) {
|
|||
common_log_verbosity_thold = verbosity;
|
||||
}
|
||||
|
||||
// Auto-detect if colors should be enabled based on terminal and environment
|
||||
static bool common_log_should_use_colors_auto() {
|
||||
// Check NO_COLOR environment variable (https://no-color.org/)
|
||||
if (const char * no_color = std::getenv("NO_COLOR")) {
|
||||
if (no_color[0] != '\0') {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check TERM environment variable
|
||||
if (const char * term = std::getenv("TERM")) {
|
||||
if (std::strcmp(term, "dumb") == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if stdout and stderr are connected to a terminal
|
||||
// We check both because log messages can go to either
|
||||
bool stdout_is_tty = isatty(fileno(stdout));
|
||||
bool stderr_is_tty = isatty(fileno(stderr));
|
||||
|
||||
return stdout_is_tty || stderr_is_tty;
|
||||
}
|
||||
|
||||
static int64_t t_us() {
|
||||
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
||||
}
|
||||
|
|
@ -391,7 +368,7 @@ struct common_log * common_log_main() {
|
|||
static std::once_flag init_flag;
|
||||
std::call_once(init_flag, [&]() {
|
||||
// Set default to auto-detect colors
|
||||
log.set_colors(common_log_should_use_colors_auto());
|
||||
log.set_colors(tty_can_use_colors());
|
||||
});
|
||||
|
||||
return &log;
|
||||
|
|
@ -422,7 +399,7 @@ void common_log_set_file(struct common_log * log, const char * file) {
|
|||
|
||||
void common_log_set_colors(struct common_log * log, log_colors colors) {
|
||||
if (colors == LOG_COLORS_AUTO) {
|
||||
log->set_colors(common_log_should_use_colors_auto());
|
||||
log->set_colors(tty_can_use_colors());
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1524,6 +1524,79 @@ class TextModel(ModelBase):
|
|||
special_vocab._set_special_token("bos", 151643)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def _set_vocab_mistral(self):
|
||||
if not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
|
||||
vocab = MistralVocab(self.dir_model)
|
||||
logger.info(
|
||||
f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
|
||||
)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
|
||||
|
||||
tokens = []
|
||||
scores = []
|
||||
toktypes = []
|
||||
|
||||
for text, score, toktype in vocab.all_tokens():
|
||||
tokens.append(text)
|
||||
scores.append(score)
|
||||
toktypes.append(toktype)
|
||||
|
||||
assert len(tokens) == vocab.vocab_size, (
|
||||
f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
|
||||
)
|
||||
|
||||
if vocab.tokenizer_type == MistralTokenizerType.tekken:
|
||||
self.gguf_writer.add_tokenizer_pre("tekken")
|
||||
self.gguf_writer.add_token_merges(
|
||||
vocab.extract_vocab_merges_from_model()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
|
||||
)
|
||||
|
||||
self.gguf_writer.add_bos_token_id(vocab.bos_id)
|
||||
self.gguf_writer.add_eos_token_id(vocab.eos_id)
|
||||
self.gguf_writer.add_unk_token_id(vocab.unk_id)
|
||||
self.gguf_writer.add_pad_token_id(vocab.pad_id)
|
||||
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
self.gguf_writer.add_vocab_size(vocab.vocab_size)
|
||||
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(False)
|
||||
|
||||
local_template_file_path = self.dir_model / "chat_template.jinja"
|
||||
|
||||
if self.is_mistral_format and local_template_file_path.is_file():
|
||||
# Ministral-3 and other new Mistral models come with chat templates.
|
||||
# ref: https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512/tree/main
|
||||
logger.info("Using an existing Mistral local chat template.")
|
||||
|
||||
with open(local_template_file_path, "r", encoding="utf-8") as f:
|
||||
template = f.read()
|
||||
elif not self.is_mistral_format or not self.disable_mistral_community_chat_template:
|
||||
template_dir = Path(__file__).parent / "models/templates/"
|
||||
|
||||
# Log only for Mistral format that the official tokenization and detokenization is via `mistral-common`.
|
||||
if self.is_mistral_format:
|
||||
logger.info(
|
||||
"Using a Mistral community chat template. These templates can be subject to errors in early days or weeks after a release. "
|
||||
"Mistral recommends to use `mistral-common` to perform tokenization and detokenization."
|
||||
)
|
||||
template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
|
||||
else:
|
||||
logger.info("Not using a Mistral local or community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
|
||||
template = None
|
||||
|
||||
if template is not None:
|
||||
self.gguf_writer.add_chat_template(template)
|
||||
|
||||
|
||||
class MmprojModel(ModelBase):
|
||||
model_type = ModelType.MMPROJ
|
||||
|
|
@ -2294,79 +2367,6 @@ class LlamaModel(TextModel):
|
|||
if self.hf_arch == "VLlama3ForCausalLM":
|
||||
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
|
||||
|
||||
def _set_vocab_mistral(self):
|
||||
if not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
|
||||
vocab = MistralVocab(self.dir_model)
|
||||
logger.info(
|
||||
f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
|
||||
)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
|
||||
|
||||
tokens = []
|
||||
scores = []
|
||||
toktypes = []
|
||||
|
||||
for text, score, toktype in vocab.all_tokens():
|
||||
tokens.append(text)
|
||||
scores.append(score)
|
||||
toktypes.append(toktype)
|
||||
|
||||
assert len(tokens) == vocab.vocab_size, (
|
||||
f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
|
||||
)
|
||||
|
||||
if vocab.tokenizer_type == MistralTokenizerType.tekken:
|
||||
self.gguf_writer.add_tokenizer_pre("tekken")
|
||||
self.gguf_writer.add_token_merges(
|
||||
vocab.extract_vocab_merges_from_model()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
|
||||
)
|
||||
|
||||
self.gguf_writer.add_bos_token_id(vocab.bos_id)
|
||||
self.gguf_writer.add_eos_token_id(vocab.eos_id)
|
||||
self.gguf_writer.add_unk_token_id(vocab.unk_id)
|
||||
self.gguf_writer.add_pad_token_id(vocab.pad_id)
|
||||
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
self.gguf_writer.add_vocab_size(vocab.vocab_size)
|
||||
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(False)
|
||||
|
||||
local_template_file_path = self.dir_model / "chat_template.jinja"
|
||||
|
||||
if self.is_mistral_format and local_template_file_path.is_file():
|
||||
# Ministral-3 and other new Mistral models come with chat templates.
|
||||
# ref: https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512/tree/main
|
||||
logger.info("Using an existing Mistral local chat template.")
|
||||
|
||||
with open(local_template_file_path, "r", encoding="utf-8") as f:
|
||||
template = f.read()
|
||||
elif not self.is_mistral_format or not self.disable_mistral_community_chat_template:
|
||||
template_dir = Path(__file__).parent / "models/templates/"
|
||||
|
||||
# Log only for Mistral format that the official tokenization and detokenization is via `mistral-common`.
|
||||
if self.is_mistral_format:
|
||||
logger.info(
|
||||
"Using a Mistral community chat template. These templates can be subject to errors in early days or weeks after a release. "
|
||||
"Mistral recommends to use `mistral-common` to perform tokenization and detokenization."
|
||||
)
|
||||
template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
|
||||
else:
|
||||
logger.info("Not using a Mistral local or community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
|
||||
template = None
|
||||
|
||||
if template is not None:
|
||||
self.gguf_writer.add_chat_template(template)
|
||||
|
||||
def set_vocab(self):
|
||||
if self.is_mistral_format:
|
||||
return self._set_vocab_mistral()
|
||||
|
|
@ -9924,17 +9924,109 @@ class MistralModel(LlamaModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if "yarn" in self.hparams:
|
||||
yarn_params = self.hparams["yarn"]
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
|
||||
MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
|
||||
|
||||
if "llama_4_scaling" in self.hparams:
|
||||
self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"])
|
||||
@staticmethod
|
||||
def set_mistral_config(gguf_writer: gguf.GGUFWriter, hparams: dict):
|
||||
if "yarn" in hparams:
|
||||
yarn_params = hparams["yarn"]
|
||||
gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
|
||||
gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
|
||||
gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
|
||||
gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
|
||||
gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
|
||||
|
||||
if "llama_4_scaling" in hparams:
|
||||
gguf_writer.add_attn_temperature_scale(hparams["llama_4_scaling"]["beta"])
|
||||
|
||||
|
||||
class MistralMoeModel(DeepseekV2Model):
|
||||
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
|
||||
model_name = "Mistral"
|
||||
hf_arch = ""
|
||||
is_mistral_format = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
logger.info("Using MistralMoeModel")
|
||||
# remap hparams from Mistral MoE format to DeepseekV2 format
|
||||
# we do this way to be able to reuse DeepseekV2Model set_gguf_parameters logic
|
||||
# ref: https://github.com/vllm-project/vllm/blob/b294e28db2c5dee61bc25157664edcada8b90b31/vllm/transformers_utils/configs/mistral.py
|
||||
config = self.hparams
|
||||
# Mistral key -> HF key
|
||||
config_mapping = {
|
||||
"dim": "hidden_size",
|
||||
"norm_eps": "rms_norm_eps",
|
||||
"n_kv_heads": "num_key_value_heads",
|
||||
"n_layers": "num_hidden_layers",
|
||||
"n_heads": "num_attention_heads",
|
||||
"hidden_dim": "intermediate_size",
|
||||
}
|
||||
# HF key -> (Mistral key, default value)
|
||||
top_level_mapping_with_default = {
|
||||
"model_type": ("model_type", "transformer"),
|
||||
"hidden_act": ("activation", "silu"),
|
||||
"tie_word_embeddings": ("tied_embeddings", False),
|
||||
"max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
|
||||
"max_position_embeddings": ("max_position_embeddings", 128_000),
|
||||
}
|
||||
# mapping top-level keys
|
||||
for key, new_key in config_mapping.items():
|
||||
if key in config:
|
||||
config[new_key] = config[key]
|
||||
for new_key, (key, default_value) in top_level_mapping_with_default.items():
|
||||
config[new_key] = config.get(key, default_value)
|
||||
# mapping MoE-specific keys
|
||||
moe_config_map = {
|
||||
"route_every_n": "moe_layer_freq",
|
||||
"first_k_dense_replace": "first_k_dense_replace",
|
||||
"num_experts_per_tok": "num_experts_per_tok",
|
||||
"num_experts": "n_routed_experts",
|
||||
"expert_hidden_dim": "moe_intermediate_size",
|
||||
"routed_scale": "routed_scaling_factor",
|
||||
"num_shared_experts": "n_shared_experts",
|
||||
"num_expert_groups": "n_group",
|
||||
"num_expert_groups_per_tok": "topk_group",
|
||||
}
|
||||
moe = config["moe"]
|
||||
for key, new_key in moe_config_map.items():
|
||||
if key in moe:
|
||||
config[new_key] = moe[key]
|
||||
# provide missing values
|
||||
config["topk_method"] = None
|
||||
config["norm_topk_prob"] = True
|
||||
config["scoring_func"] = "softmax"
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_mistral()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
|
||||
yarn_params = self.hparams["yarn"]
|
||||
self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
if name.startswith("vision_") or name.startswith("patch_merger.") or "mm_projector" in name:
|
||||
return []
|
||||
|
||||
# rename certain tensors so that we can reuse DeepseekV2Model modify_tensors logic
|
||||
if name.endswith(".qscale_act"):
|
||||
name = name.replace(".qscale_act", ".input_scale")
|
||||
if name.endswith(".qscale_weight"):
|
||||
name = name.replace(".qscale_weight", ".weight_scale")
|
||||
if ".wkv_b." in name:
|
||||
name = name.replace(".wkv_b.", ".kv_b_proj.")
|
||||
if ".experts." in name:
|
||||
name = name.replace(".experts.", ".mlp.experts.")
|
||||
name = name.replace(".w1.", ".gate_proj.")
|
||||
name = name.replace(".w2.", ".down_proj.")
|
||||
name = name.replace(".w3.", ".up_proj.")
|
||||
name = "model." + name
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
class PixtralModel(LlavaVisionModel):
|
||||
|
|
@ -10490,6 +10582,8 @@ def main() -> None:
|
|||
elif args.mmproj:
|
||||
assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
|
||||
model_class = PixtralModel
|
||||
elif "moe" in hparams:
|
||||
model_class = MistralMoeModel
|
||||
else:
|
||||
model_class = MistralModel
|
||||
|
||||
|
|
|
|||
22
ggml/include/ggml-zendnn.h
Normal file
22
ggml/include/ggml-zendnn.h
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// backend API
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_zendnn_init(void);
|
||||
|
||||
GGML_BACKEND_API bool ggml_backend_is_zendnn(ggml_backend_t backend);
|
||||
|
||||
// number of threads used for zendnn operations
|
||||
GGML_BACKEND_API void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zendnn_reg(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
@ -2221,6 +2221,15 @@ extern "C" {
|
|||
int p2,
|
||||
int p3);
|
||||
|
||||
// pad each dimension with values on the other side of the torus (looping around)
|
||||
GGML_API struct ggml_tensor * ggml_pad_circular(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int p0,
|
||||
int p1,
|
||||
int p2,
|
||||
int p3);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_pad_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
|
|
@ -2234,6 +2243,19 @@ extern "C" {
|
|||
int rp3
|
||||
);
|
||||
|
||||
// pad each dimension with values on the other side of the torus (looping around)
|
||||
GGML_API struct ggml_tensor * ggml_pad_ext_circular(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int lp0,
|
||||
int rp0,
|
||||
int lp1,
|
||||
int rp1,
|
||||
int lp2,
|
||||
int rp2,
|
||||
int lp3,
|
||||
int rp3);
|
||||
|
||||
// pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
|
||||
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
|
||||
struct ggml_context * ctx,
|
||||
|
|
|
|||
|
|
@ -73,6 +73,10 @@
|
|||
#include "ggml-cann.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_ZENDNN
|
||||
#include "ggml-zendnn.h"
|
||||
#endif
|
||||
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic push
|
||||
|
|
@ -203,6 +207,9 @@ struct ggml_backend_registry {
|
|||
#ifdef GGML_USE_OPENCL
|
||||
register_backend(ggml_backend_opencl_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_ZENDNN
|
||||
register_backend(ggml_backend_zendnn_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_HEXAGON
|
||||
register_backend(ggml_backend_hexagon_reg());
|
||||
#endif
|
||||
|
|
@ -535,8 +542,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
|||
fs::path best_path;
|
||||
|
||||
for (const auto & search_path : search_paths) {
|
||||
if (!fs::exists(search_path)) {
|
||||
GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
|
||||
if (std::error_code ec; !fs::exists(search_path, ec)) {
|
||||
if (ec) {
|
||||
GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str());
|
||||
} else {
|
||||
GGML_LOG_DEBUG("%s: search path %s does not exist\n", __func__, path_str(search_path).c_str());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||
|
|
@ -576,8 +587,12 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
|||
for (const auto & search_path : search_paths) {
|
||||
fs::path filename = backend_filename_prefix().native() + name_path.native() + backend_filename_extension().native();
|
||||
fs::path path = search_path / filename;
|
||||
if (fs::exists(path)) {
|
||||
if (std::error_code ec; fs::exists(path, ec)) {
|
||||
return get_reg().load_backend(path, silent);
|
||||
} else {
|
||||
if (ec) {
|
||||
GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(path).c_str(), ec.message().c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
|
|
@ -598,6 +613,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
|||
#endif
|
||||
|
||||
ggml_backend_load_best("blas", silent, dir_path);
|
||||
ggml_backend_load_best("zendnn", silent, dir_path);
|
||||
ggml_backend_load_best("cann", silent, dir_path);
|
||||
ggml_backend_load_best("cuda", silent, dir_path);
|
||||
ggml_backend_load_best("hip", silent, dir_path);
|
||||
|
|
|
|||
|
|
@ -6554,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
|
|||
ggml_compute_forward_mul_mat(params, &dst);
|
||||
}
|
||||
|
||||
static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
|
||||
return (coord + size) % size; // adding size avoids negative number weirdness
|
||||
}
|
||||
|
||||
// ggml_compute_forward_conv_2d
|
||||
|
||||
|
||||
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
|
||||
const ggml_tensor * kernel, // [KW, KH, IC, OC]
|
||||
const ggml_tensor * src, // [W, H, C, N]
|
||||
|
|
@ -7591,6 +7596,7 @@ void ggml_compute_forward_upscale(
|
|||
|
||||
// ggml_compute_forward_pad
|
||||
|
||||
template<bool circular_t>
|
||||
static void ggml_compute_forward_pad_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
|
@ -7615,23 +7621,40 @@ static void ggml_compute_forward_pad_f32(
|
|||
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
|
||||
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
|
||||
|
||||
|
||||
// TODO: optimize
|
||||
|
||||
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
||||
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
|
||||
for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
||||
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||||
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
||||
if ((i0 >= lp0 && i0 < ne0 - rp0) \
|
||||
&& (i1 >= lp1 && i1 < ne1 - rp1) \
|
||||
&& (i2 >= lp2 && i2 < ne2 - rp2) \
|
||||
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
|
||||
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
|
||||
// circular means wrap around on a torus, so x and y loop around
|
||||
if constexpr (circular_t) {
|
||||
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
||||
const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
|
||||
const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
|
||||
const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
|
||||
const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
|
||||
|
||||
const int64_t src_idx =
|
||||
src_i3*nb03 +
|
||||
src_i2*nb02 +
|
||||
src_i1*nb01 +
|
||||
src_i0*nb00;
|
||||
|
||||
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
||||
dst_ptr[dst_idx] = *src_ptr;
|
||||
} else {
|
||||
dst_ptr[dst_idx] = 0;
|
||||
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
||||
if ((i0 >= lp0 && i0 < ne0 - rp0) \
|
||||
&& (i1 >= lp1 && i1 < ne1 - rp1) \
|
||||
&& (i2 >= lp2 && i2 < ne2 - rp2) \
|
||||
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
|
||||
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
|
||||
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
|
||||
dst_ptr[dst_idx] = *src_ptr;
|
||||
} else {
|
||||
dst_ptr[dst_idx] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -7639,16 +7662,20 @@ static void ggml_compute_forward_pad_f32(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
void ggml_compute_forward_pad(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_pad_f32(params, dst);
|
||||
if (circular) {
|
||||
ggml_compute_forward_pad_f32<true>(params, dst);
|
||||
} else {
|
||||
ggml_compute_forward_pad_f32<false>(params, dst);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
|
|
|||
|
|
@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
|||
case GGML_TYPE_F32:
|
||||
return ampere_mma_available(cc);
|
||||
case GGML_TYPE_F16:
|
||||
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
|
||||
return volta_mma_available(cc) || turing_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
|
||||
case GGML_TYPE_BF16:
|
||||
return ampere_mma_available(cc) || amd_wmma_available(cc);
|
||||
return ampere_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc));
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,17 @@
|
|||
#include "pad.cuh"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
|
||||
// + size ensures negatives are handled properly
|
||||
return (coord + size) % size;
|
||||
}
|
||||
|
||||
static __global__ void pad_f32(const float * src, float * dst,
|
||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3) {
|
||||
const int ne0, const int ne1, const int ne2, const int ne3,
|
||||
const bool circular) {
|
||||
// blockIdx.z: i3*ne2+i2
|
||||
// blockIdx.y: i1
|
||||
// blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
|
||||
|
|
@ -12,61 +20,84 @@ static __global__ void pad_f32(const float * src, float * dst,
|
|||
int i1 = blockIdx.y;
|
||||
int i2 = blockIdx.z % ne2;
|
||||
int i3 = blockIdx.z / ne2;
|
||||
|
||||
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
// operation
|
||||
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
||||
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
|
||||
(i1 >= lp1 && i1 < ne1 - rp1) &&
|
||||
(i2 >= lp2 && i2 < ne2 - rp2) &&
|
||||
(i3 >= lp3 && i3 < ne3 - rp3)) {
|
||||
const int64_t i00 = i0 - lp0;
|
||||
const int64_t i01 = i1 - lp1;
|
||||
const int64_t i02 = i2 - lp2;
|
||||
const int64_t i03 = i3 - lp3;
|
||||
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||
const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0;
|
||||
|
||||
const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
|
||||
if (!circular) {
|
||||
if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) &&
|
||||
(i3 >= lp3 && i3 < ne3 - rp3)) {
|
||||
const int64_t i00 = i0 - lp0;
|
||||
const int64_t i01 = i1 - lp1;
|
||||
const int64_t i02 = i2 - lp2;
|
||||
const int64_t i03 = i3 - lp3;
|
||||
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||
|
||||
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
dst[dst_idx] = 0.0f;
|
||||
}
|
||||
}
|
||||
// circular means on a torus, so x and y wrap around
|
||||
else {
|
||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||
const int64_t ne03 = ne3 - lp3 - rp3;
|
||||
|
||||
const int64_t i00 = wrap_around(i0 - lp0, ne00);
|
||||
const int64_t i01 = wrap_around(i1 - lp1, ne01);
|
||||
const int64_t i02 = wrap_around(i2 - lp2, ne02);
|
||||
const int64_t i03 = wrap_around(i3 - lp3, ne03);
|
||||
|
||||
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
dst[dst_idx] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void pad_f32_cuda(const float * src, float * dst,
|
||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
|
||||
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
|
||||
dim3 gridDim(num_blocks, ne1, ne2*ne3);
|
||||
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
|
||||
const int ne0, const int ne1, const int ne2, const int ne3,
|
||||
const bool circular, cudaStream_t stream) {
|
||||
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
|
||||
dim3 gridDim(num_blocks, ne1, ne2 * ne3);
|
||||
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
|
||||
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
||||
ne0, ne1, ne2, ne3, circular);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
|
||||
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
|
||||
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
|
||||
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
|
||||
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
|
||||
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
|
||||
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
|
||||
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
|
||||
const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
|
||||
const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
|
||||
const int32_t lp1 = ((const int32_t *) (dst->op_params))[2];
|
||||
const int32_t rp1 = ((const int32_t *) (dst->op_params))[3];
|
||||
const int32_t lp2 = ((const int32_t *) (dst->op_params))[4];
|
||||
const int32_t rp2 = ((const int32_t *) (dst->op_params))[5];
|
||||
const int32_t lp3 = ((const int32_t *) (dst->op_params))[6];
|
||||
const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
|
||||
const int32_t circular = ((const int32_t *) (dst->op_params))[8];
|
||||
|
||||
pad_f32_cuda(src0_d, dst_d,
|
||||
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
(bool) circular, stream);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1037,6 +1037,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_OP_POOL_2D:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_PAD:
|
||||
// TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985
|
||||
if (ggml_get_op_params_i32(op, 8) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
|
||||
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
|
|
|
|||
|
|
@ -793,11 +793,6 @@ struct vk_device_struct {
|
|||
std::unique_ptr<vk_memory_logger> memory_logger;
|
||||
#endif
|
||||
|
||||
// for GGML_VK_PERF_LOGGER
|
||||
std::unique_ptr<vk_perf_logger> perf_logger;
|
||||
vk::QueryPool query_pool;
|
||||
int32_t num_queries;
|
||||
|
||||
~vk_device_struct() {
|
||||
VK_LOG_DEBUG("destroy device " << name);
|
||||
|
||||
|
|
@ -1066,6 +1061,7 @@ struct vk_op_pad_push_constants {
|
|||
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
|
||||
uint32_t misalign_offsets;
|
||||
uint32_t circular;
|
||||
|
||||
uint32_t lp0; uint32_t rp0;
|
||||
uint32_t lp1; uint32_t rp1;
|
||||
|
|
@ -1108,6 +1104,7 @@ static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor
|
|||
p.rp2 = dst->op_params[5];
|
||||
p.lp3 = dst->op_params[6];
|
||||
p.rp3 = dst->op_params[7];
|
||||
p.circular = dst->op_params[8];
|
||||
|
||||
return p; // fastdiv values and offsets are initialized later in ggml_vk_op
|
||||
}
|
||||
|
|
@ -1537,12 +1534,21 @@ private:
|
|||
#define VK_LOG_MEMORY(msg) ((void) 0)
|
||||
#endif // GGML_VULKAN_MEMORY_DEBUG
|
||||
|
||||
static bool vk_perf_logger_enabled = false;
|
||||
// number of calls between perf logger prints
|
||||
static uint32_t vk_perf_logger_frequency = 1;
|
||||
|
||||
class vk_perf_logger {
|
||||
public:
|
||||
void print_timings() {
|
||||
void print_timings(bool force = false) {
|
||||
if (timings.empty()) {
|
||||
return;
|
||||
}
|
||||
print_count++;
|
||||
if ((print_count % vk_perf_logger_frequency) != 0 && !force) {
|
||||
return;
|
||||
}
|
||||
print_count = 0;
|
||||
uint64_t total_all_op_times = 0;
|
||||
std::cerr << "----------------\nVulkan Timings:" << std::endl;
|
||||
for (const auto & t : timings) {
|
||||
|
|
@ -1579,16 +1585,20 @@ class vk_perf_logger {
|
|||
flops.clear();
|
||||
}
|
||||
|
||||
void log_timing(const ggml_tensor * node, uint64_t time) {
|
||||
void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
|
||||
std::string fusion_str;
|
||||
if (fusion_name) {
|
||||
fusion_str = fusion_name + std::string(" ");
|
||||
}
|
||||
if (node->op == GGML_OP_UNARY) {
|
||||
timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
|
||||
timings[fusion_str + ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
|
||||
return;
|
||||
}
|
||||
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
|
||||
const uint64_t m = node->src[0]->ne[1];
|
||||
const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2];
|
||||
const uint64_t m = node->ne[0];
|
||||
const uint64_t n = node->ne[1];
|
||||
const uint64_t k = node->src[1]->ne[0];
|
||||
const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
|
||||
const uint64_t batch = node->ne[2] * node->ne[3];
|
||||
std::string name = ggml_op_name(node->op);
|
||||
if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) ||
|
||||
(node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) {
|
||||
|
|
@ -1597,9 +1607,13 @@ class vk_perf_logger {
|
|||
name += " ";
|
||||
name += ggml_type_name(node->src[0]->type);
|
||||
name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
|
||||
if (node->op == GGML_OP_MUL_MAT_ID) {
|
||||
name += " n_expert=" + std::to_string(node->src[0]->ne[2]);
|
||||
}
|
||||
if (batch > 1) {
|
||||
name += " batch=" + std::to_string(batch);
|
||||
}
|
||||
name = fusion_str + name;
|
||||
timings[name].push_back(time);
|
||||
flops[name].push_back(m * n * (k + (k - 1)) * batch);
|
||||
return;
|
||||
|
|
@ -1621,6 +1635,7 @@ class vk_perf_logger {
|
|||
uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
|
||||
name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
|
||||
", N=N*OW*OH=" + std::to_string(size_N);
|
||||
name = fusion_str + name;
|
||||
flops[name].push_back(n_flops);
|
||||
timings[name].push_back(time);
|
||||
return;
|
||||
|
|
@ -1628,6 +1643,7 @@ class vk_perf_logger {
|
|||
if (node->op == GGML_OP_RMS_NORM) {
|
||||
std::string name = ggml_op_name(node->op);
|
||||
name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
|
||||
name = fusion_str + name;
|
||||
timings[name].push_back(time);
|
||||
return;
|
||||
}
|
||||
|
|
@ -1638,6 +1654,7 @@ class vk_perf_logger {
|
|||
const ggml_tensor * v = node->src[2];
|
||||
const ggml_tensor * m = node->src[3];
|
||||
std::stringstream name;
|
||||
name << fusion_str;
|
||||
name << ggml_op_name(node->op) <<
|
||||
" dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " <<
|
||||
" q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " <<
|
||||
|
|
@ -1649,17 +1666,19 @@ class vk_perf_logger {
|
|||
}
|
||||
if (node->op == GGML_OP_TOP_K) {
|
||||
std::stringstream name;
|
||||
name << fusion_str;
|
||||
name << ggml_op_name(node->op) <<
|
||||
" K=" << node->ne[0] <<
|
||||
" (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
|
||||
timings[name.str()].push_back(time);
|
||||
return;
|
||||
}
|
||||
timings[ggml_op_name(node->op)].push_back(time);
|
||||
timings[fusion_str + ggml_op_name(node->op)].push_back(time);
|
||||
}
|
||||
private:
|
||||
std::map<std::string, std::vector<uint64_t>> timings;
|
||||
std::map<std::string, std::vector<uint64_t>> flops;
|
||||
uint32_t print_count {};
|
||||
};
|
||||
|
||||
struct ggml_backend_vk_context {
|
||||
|
|
@ -1713,6 +1732,14 @@ struct ggml_backend_vk_context {
|
|||
// Bit 'i' means nodes[start_of_fusion + i] writes to memory.
|
||||
// If there's no fusion, bit 0 is still set.
|
||||
int fused_ops_write_mask {};
|
||||
|
||||
// for GGML_VK_PERF_LOGGER
|
||||
std::unique_ptr<vk_perf_logger> perf_logger;
|
||||
vk::QueryPool query_pool;
|
||||
std::vector<const char *> query_fusion_names;
|
||||
std::vector<ggml_tensor *> query_nodes;
|
||||
int32_t num_queries {};
|
||||
int32_t query_idx {};
|
||||
};
|
||||
|
||||
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
|
||||
|
|
@ -1838,8 +1865,6 @@ struct vk_instance_t {
|
|||
static bool vk_instance_initialized = false;
|
||||
static vk_instance_t vk_instance;
|
||||
|
||||
static bool vk_perf_logger_enabled = false;
|
||||
|
||||
#ifdef GGML_VULKAN_CHECK_RESULTS
|
||||
static size_t vk_skip_checks;
|
||||
static size_t vk_output_tensor;
|
||||
|
|
@ -3553,7 +3578,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
SHADER_REDUCTION_MODE_SHMEM;
|
||||
|
||||
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
||||
|
|
@ -3577,7 +3602,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
||||
|
|
@ -3623,7 +3648,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
||||
|
|
@ -4219,9 +4244,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
||||
device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
|
||||
#endif
|
||||
if (vk_perf_logger_enabled) {
|
||||
device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
|
||||
}
|
||||
|
||||
size_t dev_num = vk_instance.device_indices[idx];
|
||||
|
||||
|
|
@ -5181,6 +5203,11 @@ static void ggml_vk_instance_init() {
|
|||
}
|
||||
|
||||
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
|
||||
const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
|
||||
|
||||
if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
|
||||
vk_perf_logger_frequency = std::stoul(GGML_VK_PERF_LOGGER_FREQUENCY);
|
||||
}
|
||||
|
||||
// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
|
||||
VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance);
|
||||
|
|
@ -5358,6 +5385,10 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
|
|||
ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
|
||||
ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
|
||||
|
||||
if (vk_perf_logger_enabled) {
|
||||
ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
|
||||
}
|
||||
|
||||
#ifdef GGML_VULKAN_CHECK_RESULTS
|
||||
const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
|
||||
vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
|
||||
|
|
@ -12233,6 +12264,9 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
|||
|
||||
ctx->compute_cmd_pool.destroy(ctx->device->device);
|
||||
ctx->transfer_cmd_pool.destroy(ctx->device->device);
|
||||
if (vk_perf_logger_enabled) {
|
||||
ctx->perf_logger->print_timings(true);
|
||||
}
|
||||
}
|
||||
|
||||
static int ggml_vk_get_device_count() {
|
||||
|
|
@ -13031,24 +13065,29 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
vk_context compute_ctx;
|
||||
if (vk_perf_logger_enabled) {
|
||||
// allocate/resize the query pool
|
||||
if (ctx->device->num_queries < cgraph->n_nodes + 1) {
|
||||
if (ctx->device->query_pool) {
|
||||
ctx->device->device.destroyQueryPool(ctx->device->query_pool);
|
||||
if (ctx->num_queries < cgraph->n_nodes + 1) {
|
||||
if (ctx->query_pool) {
|
||||
ctx->device->device.destroyQueryPool(ctx->query_pool);
|
||||
}
|
||||
vk::QueryPoolCreateInfo query_create_info;
|
||||
query_create_info.queryType = vk::QueryType::eTimestamp;
|
||||
query_create_info.queryCount = cgraph->n_nodes + 100;
|
||||
ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info);
|
||||
ctx->device->num_queries = query_create_info.queryCount;
|
||||
ctx->query_pool = ctx->device->device.createQueryPool(query_create_info);
|
||||
ctx->num_queries = query_create_info.queryCount;
|
||||
ctx->query_fusion_names.resize(ctx->num_queries);
|
||||
ctx->query_nodes.resize(ctx->num_queries);
|
||||
}
|
||||
|
||||
ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1);
|
||||
ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1);
|
||||
std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr);
|
||||
std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr);
|
||||
|
||||
GGML_ASSERT(ctx->compute_ctx.expired());
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
|
||||
ctx->query_idx = 0;
|
||||
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
||||
}
|
||||
|
||||
ctx->prealloc_y_last_pipeline_used = nullptr;
|
||||
|
|
@ -13089,52 +13128,66 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
total_mul_mat_bytes += bytes;
|
||||
}
|
||||
|
||||
const char *fusion_string {};
|
||||
if (!ctx->device->disable_fusion) {
|
||||
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
|
||||
if (num_adds) {
|
||||
ctx->num_additional_fused_ops = num_adds - 1;
|
||||
fusion_string = "MULTI_ADD";
|
||||
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
|
||||
ctx->num_additional_fused_ops = 2;
|
||||
fusion_string = "MUL_MAT_ADD_ADD";
|
||||
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
fusion_string = "MUL_MAT_ADD";
|
||||
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
|
||||
ctx->num_additional_fused_ops = 2;
|
||||
fusion_string = "MUL_MAT_ID_ADD_ID_MUL";
|
||||
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
fusion_string = "MUL_MAT_ID_ADD_ID";
|
||||
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
fusion_string = "MUL_MAT_ID_MUL";
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
|
||||
ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
|
||||
ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
|
||||
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
|
||||
ctx->num_additional_fused_ops = 4;
|
||||
fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS";
|
||||
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&
|
||||
ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
|
||||
ctx->num_additional_fused_ops = 2;
|
||||
fusion_string = "RMS_NORM_MUL_ROPE";
|
||||
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
fusion_string = "RMS_NORM_MUL";
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
|
||||
ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
|
||||
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
|
||||
ctx->num_additional_fused_ops = 2;
|
||||
fusion_string = "ROPE_VIEW_SET_ROWS";
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
|
||||
// view of argsort writes to memory
|
||||
ctx->fused_ops_write_mask |= 1 << 3;
|
||||
fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
|
||||
// view of argsort writes to memory
|
||||
ctx->fused_ops_write_mask |= 1 << 3;
|
||||
fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
|
||||
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
|
||||
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
|
||||
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
|
||||
// view of argsort writes to memory
|
||||
ctx->fused_ops_write_mask |= 1 << 1;
|
||||
fusion_string = "TOPK_MOE_LATE_SOFTMAX";
|
||||
}
|
||||
}
|
||||
ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
|
||||
|
|
@ -13148,7 +13201,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
|
||||
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
|
||||
|
||||
if (vk_perf_logger_enabled) {
|
||||
if (vk_perf_logger_enabled && enqueued) {
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
|
|
@ -13156,10 +13209,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
} else {
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
// If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
|
||||
for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
|
||||
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
|
||||
}
|
||||
ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
|
||||
ctx->query_fusion_names[ctx->query_idx] = fusion_string;
|
||||
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
||||
}
|
||||
|
||||
if (enqueued) {
|
||||
|
|
@ -13200,14 +13252,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
|
||||
// Get the results and pass them to the logger
|
||||
std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
|
||||
VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (!ggml_vk_is_empty(cgraph->nodes[i])) {
|
||||
ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod));
|
||||
}
|
||||
VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
|
||||
for (int i = 1; i < ctx->query_idx; i++) {
|
||||
auto node = ctx->query_nodes[i];
|
||||
auto name = ctx->query_fusion_names[i];
|
||||
ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
|
||||
}
|
||||
|
||||
ctx->device->perf_logger->print_timings();
|
||||
ctx->perf_logger->print_timings();
|
||||
}
|
||||
|
||||
if (!ctx->device->support_async) {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ layout (push_constant) uniform parameter
|
|||
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
|
||||
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
|
||||
uint misalign_offsets;
|
||||
uint circular;
|
||||
|
||||
uint lp0; uint rp0;
|
||||
uint lp1; uint rp1;
|
||||
|
|
@ -18,6 +19,10 @@ layout (push_constant) uniform parameter
|
|||
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
||||
|
||||
uint wrap_around(int coord, uint size) {
|
||||
return (uint(coord + int(size))) % size; // add size to avoid issues with negative
|
||||
}
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
|
|
@ -40,10 +45,20 @@ void main() {
|
|||
const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;
|
||||
const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;
|
||||
|
||||
const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
|
||||
i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
|
||||
i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
|
||||
i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
|
||||
if (p.circular != 0u) {
|
||||
const uint ci0 = wrap_around(int(i0) - int(p.lp0), p.ne00);
|
||||
const uint ci1 = wrap_around(int(i1) - int(p.lp1), p.ne01);
|
||||
const uint ci2 = wrap_around(int(i2) - int(p.lp2), p.ne02);
|
||||
const uint ci3 = wrap_around(int(i3) - int(p.lp3), p.ne03);
|
||||
const uint circular_src_idx = ci3*p.nb03 + ci2*p.nb02 + ci1*p.nb01 + ci0*p.nb00;
|
||||
data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + circular_src_idx]);
|
||||
} else {
|
||||
const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
|
||||
i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
|
||||
i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
|
||||
i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
|
||||
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
|
||||
}
|
||||
|
||||
|
||||
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4963,6 +4963,18 @@ struct ggml_tensor * ggml_pad(
|
|||
return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
|
||||
}
|
||||
|
||||
// ggml_pad_circular
|
||||
|
||||
struct ggml_tensor * ggml_pad_circular(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int p0,
|
||||
int p1,
|
||||
int p2,
|
||||
int p3) {
|
||||
return ggml_pad_ext_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_pad_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
|
|
@ -4989,6 +5001,7 @@ struct ggml_tensor * ggml_pad_ext(
|
|||
ggml_set_op_params_i32(result, 5, rp2);
|
||||
ggml_set_op_params_i32(result, 6, lp3);
|
||||
ggml_set_op_params_i32(result, 7, rp3);
|
||||
ggml_set_op_params_i32(result, 8, 0); // not circular by default
|
||||
|
||||
|
||||
result->op = GGML_OP_PAD;
|
||||
|
|
@ -4997,6 +5010,25 @@ struct ggml_tensor * ggml_pad_ext(
|
|||
return result;
|
||||
}
|
||||
|
||||
// ggml_pad_ext_circular
|
||||
|
||||
struct ggml_tensor * ggml_pad_ext_circular(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int lp0,
|
||||
int rp0,
|
||||
int lp1,
|
||||
int rp1,
|
||||
int lp2,
|
||||
int rp2,
|
||||
int lp3,
|
||||
int rp3
|
||||
) {
|
||||
struct ggml_tensor * result = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
|
||||
ggml_set_op_params_i32(result, 8, 1); // circular
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_pad_reflect_1d
|
||||
|
||||
struct ggml_tensor * ggml_pad_reflect_1d(
|
||||
|
|
|
|||
|
|
@ -376,6 +376,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
|
||||
"model.layers.{bid}.feed_forward.gate", # lfm2moe
|
||||
"model.layers.{bid}.mlp.router.gate", # afmoe
|
||||
"layers.{bid}.gate", # mistral-large
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
|
|
@ -450,6 +451,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.down_proj",
|
||||
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
|
||||
"layers.{bid}.shared_experts.w3", # mistral-large
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_CHEXP: (
|
||||
|
|
@ -496,6 +498,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
|
||||
"layers.{bid}.shared_experts.w1", # mistral-large
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_CHEXP: (
|
||||
|
|
@ -557,6 +560,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
||||
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
|
||||
"layers.{bid}.shared_experts.w2", # mistral-large
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_CHEXP: (
|
||||
|
|
@ -924,14 +928,17 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.ATTN_Q_A: (
|
||||
"model.layers.{bid}.self_attn.q_a_proj", # deepseek2
|
||||
"layers.{bid}.attention.wq_a", # mistral-large
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_B: (
|
||||
"model.layers.{bid}.self_attn.q_b_proj", # deepseek2
|
||||
"layers.{bid}.attention.wq_b", # mistral-large
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA: (
|
||||
"model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
|
||||
"layers.{bid}.attention.wkv_a_with_mqa", # mistral-large
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_B: (
|
||||
|
|
@ -940,18 +947,22 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.ATTN_K_B: (
|
||||
"model.layers.{bid}.self_attn.k_b_proj", # deepseek2
|
||||
"layers.{bid}.attention.k_b_proj", # mistral-large
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_V_B: (
|
||||
"model.layers.{bid}.self_attn.v_b_proj", # deepseek2
|
||||
"layers.{bid}.attention.v_b_proj", # mistral-large
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: (
|
||||
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
|
||||
"layers.{bid}.attention.q_a_norm", # mistral-large
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: (
|
||||
"model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
|
||||
"layers.{bid}.attention.kv_a_norm", # mistral-large
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_SUB_NORM: (
|
||||
|
|
|
|||
|
|
@ -669,7 +669,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
|
||||
std::map<int, std::string> mapped;
|
||||
int blk_id = 0;
|
||||
int pruned_attention_w = 0;
|
||||
|
||||
// make a list of weights
|
||||
std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
|
||||
|
|
@ -677,11 +676,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
for (const auto & it : ml.weights_map) {
|
||||
const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
|
||||
if (remapped_name.empty()) {
|
||||
if (it.first.find("attn_v.weight") != std::string::npos ||
|
||||
it.first.find("attn_qkv.weight") != std::string::npos ||
|
||||
it.first.find("attn_kv_b.weight") != std::string::npos) {
|
||||
pruned_attention_w++;
|
||||
}
|
||||
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
|
||||
continue;
|
||||
}
|
||||
|
|
@ -706,7 +700,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
});
|
||||
}
|
||||
|
||||
bool is_clip_model = false;
|
||||
for (const auto * it : tensors) {
|
||||
const struct ggml_tensor * tensor = it->tensor;
|
||||
|
||||
|
|
@ -720,30 +713,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
|
||||
qs.has_output = true;
|
||||
}
|
||||
|
||||
is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
|
||||
}
|
||||
|
||||
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
|
||||
|
||||
// sanity checks for models that have attention layers
|
||||
if (qs.n_attention_wv != 0 && !is_clip_model)
|
||||
{
|
||||
int32_t n_layer_all = model.hparams.n_layer;
|
||||
if (llama_model_has_encoder(&model)) {
|
||||
// now n_layer_all is the number of attention layers in the encoder
|
||||
// for each decoder block, there are 2 attention layers
|
||||
n_layer_all += 2 * model.hparams.dec_n_layer;
|
||||
}
|
||||
|
||||
// note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
|
||||
const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
|
||||
|
||||
LLAMA_LOG_INFO("%s: n_layer_all = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_all, n_layer_recr, pruned_attention_w);
|
||||
|
||||
GGML_ASSERT_CONTINUE((qs.n_attention_wv == n_layer_all - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
|
||||
}
|
||||
|
||||
size_t total_size_org = 0;
|
||||
size_t total_size_new = 0;
|
||||
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -494,6 +494,18 @@ int32_t server_tokens::process_chunk(
|
|||
return 0;
|
||||
}
|
||||
|
||||
server_tokens server_tokens::clone() const {
|
||||
server_tokens res;
|
||||
res.has_mtmd = has_mtmd;
|
||||
res.tokens = tokens;
|
||||
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
|
||||
size_t idx = it->first;
|
||||
const mtmd::input_chunk_ptr & chunk = it->second;
|
||||
res.map_idx_to_media[idx] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get()));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
//
|
||||
// tokenizer and input processing utils
|
||||
//
|
||||
|
|
@ -745,12 +757,6 @@ json oaicompat_completion_params_parse(const json & body) {
|
|||
llama_params["stop"] = json_value(body, "stop", json::array());
|
||||
}
|
||||
|
||||
// Handle "n" field
|
||||
int n_choices = json_value(body, "n", 1);
|
||||
if (n_choices != 1) {
|
||||
throw std::runtime_error("Only one completion choice is allowed");
|
||||
}
|
||||
|
||||
// Handle "echo" field
|
||||
if (json_value(body, "echo", false)) {
|
||||
throw std::runtime_error("Only no echo is supported");
|
||||
|
|
@ -1049,12 +1055,6 @@ json oaicompat_chat_params_parse(
|
|||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
// Handle "n" field
|
||||
int n_choices = json_value(body, "n", 1);
|
||||
if (n_choices != 1) {
|
||||
throw std::invalid_argument("Only one completion choice is allowed");
|
||||
}
|
||||
|
||||
// Handle "logprobs" field
|
||||
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
|
||||
if (json_value(body, "logprobs", false)) {
|
||||
|
|
|
|||
|
|
@ -215,6 +215,8 @@ public:
|
|||
llama_pos pos,
|
||||
int32_t seq_id,
|
||||
size_t & n_tokens_out) const;
|
||||
|
||||
server_tokens clone() const;
|
||||
};
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
|
|||
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
|
||||
enum slot_state {
|
||||
SLOT_STATE_IDLE,
|
||||
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
|
||||
SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
|
||||
SLOT_STATE_STARTED, // after assigning a task and about to process prompt
|
||||
SLOT_STATE_PROCESSING_PROMPT,
|
||||
SLOT_STATE_DONE_PROMPT,
|
||||
SLOT_STATE_GENERATING,
|
||||
|
|
@ -254,6 +255,15 @@ struct server_slot {
|
|||
generated_token_probs.push_back(token);
|
||||
}
|
||||
|
||||
// note: a slot can also be either a parent or a child
|
||||
bool is_parent() const {
|
||||
return is_processing() && task->n_children > 0;
|
||||
}
|
||||
|
||||
bool is_child() const {
|
||||
return is_processing() && task->id_parent >= 0;
|
||||
}
|
||||
|
||||
void release() {
|
||||
if (is_processing()) {
|
||||
GGML_ASSERT(task);
|
||||
|
|
@ -383,6 +393,17 @@ struct server_slot {
|
|||
|
||||
return res;
|
||||
}
|
||||
|
||||
void copy_state_to(server_slot & other) const {
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
|
||||
llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
|
||||
other.n_decoded = n_decoded;
|
||||
other.n_remaining = n_remaining;
|
||||
other.i_batch = i_batch;
|
||||
other.n_prompt_tokens_cache = n_prompt_tokens_cache;
|
||||
other.n_prompt_tokens_processed = n_prompt_tokens_processed;
|
||||
other.prompt = prompt.clone();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
|
@ -1022,7 +1043,9 @@ struct server_context_impl {
|
|||
|
||||
slot.task = std::make_unique<const server_task>(std::move(task));
|
||||
|
||||
slot.state = SLOT_STATE_STARTED;
|
||||
slot.state = slot.is_child()
|
||||
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
|
||||
: SLOT_STATE_STARTED;
|
||||
|
||||
SLT_INF(slot, "%s", "processing task\n");
|
||||
|
||||
|
|
@ -1684,6 +1707,12 @@ struct server_context_impl {
|
|||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
if (slot.is_parent() || slot.is_child()) {
|
||||
send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Shift context
|
||||
int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
|
||||
|
||||
|
|
@ -2308,6 +2337,26 @@ struct server_context_impl {
|
|||
n_batch = llama_n_batch(ctx);
|
||||
|
||||
for (auto & slot : slots) {
|
||||
// may need to copy state to other slots
|
||||
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
|
||||
std::vector<server_slot *> child_slots;
|
||||
for (auto & other : slots) {
|
||||
if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
|
||||
child_slots.push_back(&other);
|
||||
}
|
||||
}
|
||||
|
||||
// we can only proceed if all child slots are having the correct tasks
|
||||
if (child_slots.size() == slot.task->n_children) {
|
||||
// copy state to the child slots
|
||||
for (auto & child : child_slots) {
|
||||
SLT_INF(slot, "copying state to child %d\n", child->id);
|
||||
slot.copy_state_to(*child);
|
||||
child->state = SLOT_STATE_DONE_PROMPT;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// optionally send prompt processing progress
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||
if (slot.task->params.stream && slot.task->params.return_progress) {
|
||||
|
|
@ -2593,11 +2642,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
|||
}
|
||||
tasks.reserve(inputs.size());
|
||||
states.reserve(inputs.size());
|
||||
int idx = 0;
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.index = idx++;
|
||||
|
||||
task.tokens = std::move(inputs[i]);
|
||||
task.params = server_task::params_from_json_cmpl(
|
||||
|
|
@ -2612,6 +2662,18 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
|||
task.params.oaicompat_model = ctx_server.model_name;
|
||||
states.push_back(task.params.oaicompat_chat_syntax);
|
||||
|
||||
if (task.params.n_cmpl > 1) {
|
||||
task.n_children = task.params.n_cmpl - 1;
|
||||
for (size_t j = 0; j < task.n_children; j++) {
|
||||
server_task child = task.create_child(
|
||||
task.id,
|
||||
ctx_server.queue_tasks.get_new_id(),
|
||||
idx++);
|
||||
states.push_back(child.params.oaicompat_chat_syntax);
|
||||
tasks.push_back(std::move(child));
|
||||
}
|
||||
}
|
||||
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
|
|
@ -2638,8 +2700,21 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
|||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
|
||||
arr.push_back(res->to_json());
|
||||
}
|
||||
// if single request, return single object instead of array
|
||||
res->ok(arr.size() == 1 ? arr[0] : arr);
|
||||
GGML_ASSERT(!arr.empty() && "empty results");
|
||||
if (arr.size() == 1) {
|
||||
// if single request, return single object instead of array
|
||||
res->ok(arr[0]);
|
||||
} else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) {
|
||||
// if multiple results in OAI format, we need to re-format them
|
||||
json & choices = arr[0]["choices"];
|
||||
for (size_t i = 1; i < arr.size(); i++) {
|
||||
choices.push_back(std::move(arr[i]["choices"][0]));
|
||||
}
|
||||
res->ok(arr[0]);
|
||||
} else {
|
||||
// multi-results, non-OAI compat
|
||||
res->ok(arr);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// in streaming mode, the first error must be treated as non-stream response
|
||||
|
|
|
|||
|
|
@ -175,6 +175,7 @@ task_params server_task::params_from_json_cmpl(
|
|||
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
|
||||
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
|
||||
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||
params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1));
|
||||
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
||||
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
||||
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
|
||||
|
|
@ -453,6 +454,10 @@ task_params server_task::params_from_json_cmpl(
|
|||
}
|
||||
}
|
||||
|
||||
if (params.n_cmpl > params_base.n_parallel) {
|
||||
throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
|
|
@ -664,7 +669,7 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
|
|||
|
||||
json choice {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"index", index},
|
||||
{"message", msg.to_json_oaicompat<json>()},
|
||||
};
|
||||
|
||||
|
|
@ -1064,7 +1069,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
|
|||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"index", index},
|
||||
{"delta", delta},
|
||||
},
|
||||
})},
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ struct task_params {
|
|||
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
|
||||
int32_t n_cmpl = 1; // number of completions to generate from this prompt
|
||||
|
||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
|
@ -89,6 +90,10 @@ struct server_task {
|
|||
int id_target = -1;
|
||||
int id_slot = -1;
|
||||
|
||||
// used by parallel sampling (multiple completions from same prompt)
|
||||
size_t n_children = 0; // number of tasks reusing this prompt
|
||||
int id_parent = -1;
|
||||
|
||||
// used by SERVER_TASK_TYPE_INFERENCE
|
||||
task_params params;
|
||||
server_tokens tokens;
|
||||
|
|
@ -130,6 +135,17 @@ struct server_task {
|
|||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
server_task create_child(int id_parent, int id_child, int idx) const {
|
||||
server_task copy;
|
||||
copy.id = id_child;
|
||||
copy.index = idx;
|
||||
copy.id_parent = id_parent;
|
||||
copy.params = params;
|
||||
copy.type = type;
|
||||
copy.tokens = tokens.clone();
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
|
||||
struct result_timings {
|
||||
|
|
@ -466,6 +482,14 @@ struct server_prompt {
|
|||
int n_tokens() const {
|
||||
return tokens.size();
|
||||
}
|
||||
|
||||
server_prompt clone() const {
|
||||
return server_prompt {
|
||||
tokens.clone(),
|
||||
data,
|
||||
checkpoints
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct server_prompt_cache {
|
||||
|
|
|
|||
|
|
@ -477,3 +477,22 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
|
|||
assert last_progress["total"] > 0
|
||||
assert last_progress["processed"] == last_progress["total"]
|
||||
assert total_batch_count == batch_count
|
||||
|
||||
|
||||
def test_chat_completions_multiple_choices():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 8,
|
||||
"n": 2,
|
||||
"messages": [
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["choices"]) == 2
|
||||
for choice in res.body["choices"]:
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert match_regex("Suddenly", choice["message"]["content"])
|
||||
assert choice["finish_reason"] == "length"
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import { copyToClipboard, isIMEComposing } from '$lib/utils';
|
||||
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
|
||||
import ChatMessageUser from './ChatMessageUser.svelte';
|
||||
import ChatMessageSystem from './ChatMessageSystem.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
|
|
@ -140,8 +141,7 @@
|
|||
}
|
||||
|
||||
function handleSaveEdit() {
|
||||
if (message.role === 'user') {
|
||||
// For user messages, trim to avoid accidental whitespace
|
||||
if (message.role === 'user' || message.role === 'system') {
|
||||
onEditWithBranching?.(message, editedContent.trim());
|
||||
} else {
|
||||
// For assistant messages, preserve exact content including trailing whitespace
|
||||
|
|
@ -167,7 +167,28 @@
|
|||
}
|
||||
</script>
|
||||
|
||||
{#if message.role === 'user'}
|
||||
{#if message.role === 'system'}
|
||||
<ChatMessageSystem
|
||||
bind:textareaElement
|
||||
class={className}
|
||||
{deletionInfo}
|
||||
{editedContent}
|
||||
{isEditing}
|
||||
{message}
|
||||
onCancelEdit={handleCancelEdit}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
onEditKeydown={handleEditKeydown}
|
||||
onEditedContentChange={handleEditedContentChange}
|
||||
{onNavigateToSibling}
|
||||
onSaveEdit={handleSaveEdit}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
/>
|
||||
{:else if message.role === 'user'}
|
||||
<ChatMessageUser
|
||||
bind:textareaElement
|
||||
class={className}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,216 @@
|
|||
<script lang="ts">
|
||||
import { Check, X } from '@lucide/svelte';
|
||||
import { Card } from '$lib/components/ui/card';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { MarkdownContent } from '$lib/components/app';
|
||||
import { INPUT_CLASSES } from '$lib/constants/input-classes';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import ChatMessageActions from './ChatMessageActions.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
message: DatabaseMessage;
|
||||
isEditing: boolean;
|
||||
editedContent: string;
|
||||
siblingInfo?: ChatMessageSiblingInfo | null;
|
||||
showDeleteDialog: boolean;
|
||||
deletionInfo: {
|
||||
totalCount: number;
|
||||
userMessages: number;
|
||||
assistantMessages: number;
|
||||
messageTypes: string[];
|
||||
} | null;
|
||||
onCancelEdit: () => void;
|
||||
onSaveEdit: () => void;
|
||||
onEditKeydown: (event: KeyboardEvent) => void;
|
||||
onEditedContentChange: (content: string) => void;
|
||||
onCopy: () => void;
|
||||
onEdit: () => void;
|
||||
onDelete: () => void;
|
||||
onConfirmDelete: () => void;
|
||||
onNavigateToSibling?: (siblingId: string) => void;
|
||||
onShowDeleteDialogChange: (show: boolean) => void;
|
||||
textareaElement?: HTMLTextAreaElement;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
message,
|
||||
isEditing,
|
||||
editedContent,
|
||||
siblingInfo = null,
|
||||
showDeleteDialog,
|
||||
deletionInfo,
|
||||
onCancelEdit,
|
||||
onSaveEdit,
|
||||
onEditKeydown,
|
||||
onEditedContentChange,
|
||||
onCopy,
|
||||
onEdit,
|
||||
onDelete,
|
||||
onConfirmDelete,
|
||||
onNavigateToSibling,
|
||||
onShowDeleteDialogChange,
|
||||
textareaElement = $bindable()
|
||||
}: Props = $props();
|
||||
|
||||
let isMultiline = $state(false);
|
||||
let messageElement: HTMLElement | undefined = $state();
|
||||
let isExpanded = $state(false);
|
||||
let contentHeight = $state(0);
|
||||
const MAX_HEIGHT = 200; // pixels
|
||||
const currentConfig = config();
|
||||
|
||||
let showExpandButton = $derived(contentHeight > MAX_HEIGHT);
|
||||
|
||||
$effect(() => {
|
||||
if (!messageElement || !message.content.trim()) return;
|
||||
|
||||
if (message.content.includes('\n')) {
|
||||
isMultiline = true;
|
||||
}
|
||||
|
||||
const resizeObserver = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
const element = entry.target as HTMLElement;
|
||||
const estimatedSingleLineHeight = 24;
|
||||
|
||||
isMultiline = element.offsetHeight > estimatedSingleLineHeight * 1.5;
|
||||
contentHeight = element.scrollHeight;
|
||||
}
|
||||
});
|
||||
|
||||
resizeObserver.observe(messageElement);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
});
|
||||
|
||||
function toggleExpand() {
|
||||
isExpanded = !isExpanded;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div
|
||||
aria-label="System message with actions"
|
||||
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
|
||||
role="group"
|
||||
>
|
||||
{#if isEditing}
|
||||
<div class="w-full max-w-[80%]">
|
||||
<textarea
|
||||
bind:this={textareaElement}
|
||||
bind:value={editedContent}
|
||||
class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
|
||||
onkeydown={onEditKeydown}
|
||||
oninput={(e) => onEditedContentChange(e.currentTarget.value)}
|
||||
placeholder="Edit system message..."
|
||||
></textarea>
|
||||
|
||||
<div class="mt-2 flex justify-end gap-2">
|
||||
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="outline">
|
||||
<X class="mr-1 h-3 w-3" />
|
||||
Cancel
|
||||
</Button>
|
||||
|
||||
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
|
||||
<Check class="mr-1 h-3 w-3" />
|
||||
Send
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
{#if message.content.trim()}
|
||||
<div class="relative max-w-[80%]">
|
||||
<button
|
||||
class="group/expand w-full text-left {!isExpanded && showExpandButton
|
||||
? 'cursor-pointer'
|
||||
: 'cursor-auto'}"
|
||||
onclick={showExpandButton && !isExpanded ? toggleExpand : undefined}
|
||||
type="button"
|
||||
>
|
||||
<Card
|
||||
class="rounded-[1.125rem] !border-2 !border-dashed !border-border/50 bg-muted px-3.75 py-1.5 data-[multiline]:py-2.5"
|
||||
data-multiline={isMultiline ? '' : undefined}
|
||||
style="border: 2px dashed hsl(var(--border));"
|
||||
>
|
||||
<div
|
||||
class="relative overflow-hidden transition-all duration-300 {isExpanded
|
||||
? 'cursor-text select-text'
|
||||
: 'select-none'}"
|
||||
style={!isExpanded && showExpandButton
|
||||
? `max-height: ${MAX_HEIGHT}px;`
|
||||
: 'max-height: none;'}
|
||||
>
|
||||
{#if currentConfig.renderUserContentAsMarkdown}
|
||||
<div bind:this={messageElement} class="text-md {isExpanded ? 'cursor-text' : ''}">
|
||||
<MarkdownContent class="markdown-system-content" content={message.content} />
|
||||
</div>
|
||||
{:else}
|
||||
<span
|
||||
bind:this={messageElement}
|
||||
class="text-md whitespace-pre-wrap {isExpanded ? 'cursor-text' : ''}"
|
||||
>
|
||||
{message.content}
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
{#if !isExpanded && showExpandButton}
|
||||
<div
|
||||
class="pointer-events-none absolute right-0 bottom-0 left-0 h-48 bg-gradient-to-t from-muted to-transparent"
|
||||
></div>
|
||||
<div
|
||||
class="pointer-events-none absolute right-0 bottom-4 left-0 flex justify-center opacity-0 transition-opacity group-hover/expand:opacity-100"
|
||||
>
|
||||
<Button
|
||||
class="rounded-full px-4 py-1.5 text-xs shadow-md"
|
||||
size="sm"
|
||||
variant="outline"
|
||||
>
|
||||
Show full system message
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if isExpanded && showExpandButton}
|
||||
<div class="mb-2 flex justify-center">
|
||||
<Button
|
||||
class="rounded-full px-4 py-1.5 text-xs"
|
||||
onclick={(e) => {
|
||||
e.stopPropagation();
|
||||
toggleExpand();
|
||||
}}
|
||||
size="sm"
|
||||
variant="outline"
|
||||
>
|
||||
Collapse System Message
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
</Card>
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if message.timestamp}
|
||||
<div class="max-w-[80%]">
|
||||
<ChatMessageActions
|
||||
actionsPosition="right"
|
||||
{deletionInfo}
|
||||
justify="end"
|
||||
{onConfirmDelete}
|
||||
{onCopy}
|
||||
{onDelete}
|
||||
{onEdit}
|
||||
{onNavigateToSibling}
|
||||
{onShowDeleteDialogChange}
|
||||
{siblingInfo}
|
||||
{showDeleteDialog}
|
||||
role="user"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
|
|
@ -145,7 +145,7 @@
|
|||
|
||||
{#if message.content.trim()}
|
||||
<Card
|
||||
class="max-w-[80%] rounded-[1.125rem] bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5"
|
||||
class="max-w-[80%] rounded-[1.125rem] border-none bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5"
|
||||
data-multiline={isMultiline ? '' : undefined}
|
||||
>
|
||||
{#if currentConfig.renderUserContentAsMarkdown}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
import { ChatMessage } from '$lib/components/app';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { getMessageSiblings } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
|
|
@ -13,6 +14,7 @@
|
|||
let { class: className, messages = [], onUserAction }: Props = $props();
|
||||
|
||||
let allConversationMessages = $state<DatabaseMessage[]>([]);
|
||||
const currentConfig = config();
|
||||
|
||||
function refreshAllMessages() {
|
||||
const conversation = activeConversation();
|
||||
|
|
@ -40,7 +42,12 @@
|
|||
return [];
|
||||
}
|
||||
|
||||
return messages.map((message) => {
|
||||
// Filter out system messages if showSystemMessage is false
|
||||
const filteredMessages = currentConfig.showSystemMessage
|
||||
? messages
|
||||
: messages.filter((msg) => msg.type !== 'system');
|
||||
|
||||
return filteredMessages.map((message) => {
|
||||
const siblingInfo = getMessageSiblings(allConversationMessages, message.id);
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -36,12 +36,6 @@
|
|||
title: 'General',
|
||||
icon: Settings,
|
||||
fields: [
|
||||
{ key: 'apiKey', label: 'API Key', type: 'input' },
|
||||
{
|
||||
key: 'systemMessage',
|
||||
label: 'System Message (will be disabled if left empty)',
|
||||
type: 'textarea'
|
||||
},
|
||||
{
|
||||
key: 'theme',
|
||||
label: 'Theme',
|
||||
|
|
@ -52,6 +46,12 @@
|
|||
{ value: 'dark', label: 'Dark', icon: Moon }
|
||||
]
|
||||
},
|
||||
{ key: 'apiKey', label: 'API Key', type: 'input' },
|
||||
{
|
||||
key: 'systemMessage',
|
||||
label: 'System Message',
|
||||
type: 'textarea'
|
||||
},
|
||||
{
|
||||
key: 'pasteLongTextToFileLen',
|
||||
label: 'Paste long text to file length',
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@
|
|||
</div>
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
{@html field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
</p>
|
||||
{/if}
|
||||
{:else if field.type === 'textarea'}
|
||||
|
|
@ -112,13 +112,28 @@
|
|||
value={String(localConfig[field.key] ?? '')}
|
||||
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
|
||||
class="min-h-[100px] w-full md:max-w-2xl"
|
||||
class="min-h-[10rem] w-full md:max-w-2xl"
|
||||
/>
|
||||
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
</p>
|
||||
{/if}
|
||||
|
||||
{#if field.key === 'systemMessage'}
|
||||
<div class="mt-3 flex items-center gap-2">
|
||||
<Checkbox
|
||||
id="showSystemMessage"
|
||||
checked={Boolean(localConfig.showSystemMessage ?? true)}
|
||||
onCheckedChange={(checked) => onConfigChange('showSystemMessage', Boolean(checked))}
|
||||
/>
|
||||
|
||||
<Label for="showSystemMessage" class="cursor-pointer text-sm font-normal">
|
||||
Show system message in conversations
|
||||
</Label>
|
||||
</div>
|
||||
{/if}
|
||||
{:else if field.type === 'select'}
|
||||
{@const selectedOption = field.options?.find(
|
||||
(opt: { value: string; label: string; icon?: Component }) =>
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import Input from '$lib/components/ui/input/input.svelte';
|
||||
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import ChatSidebarActions from './ChatSidebarActions.svelte';
|
||||
|
||||
const sidebar = Sidebar.useSidebar();
|
||||
|
|
@ -98,6 +99,10 @@
|
|||
|
||||
await goto(`#/chat/${id}`);
|
||||
}
|
||||
|
||||
function handleStopGeneration(id: string) {
|
||||
chatStore.stopGenerationForChat(id);
|
||||
}
|
||||
</script>
|
||||
|
||||
<ScrollArea class="h-[100vh]">
|
||||
|
|
@ -132,6 +137,7 @@
|
|||
onSelect={selectConversation}
|
||||
onEdit={handleEditConversation}
|
||||
onDelete={handleDeleteConversation}
|
||||
onStop={handleStopGeneration}
|
||||
/>
|
||||
</Sidebar.MenuItem>
|
||||
{/each}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
<script lang="ts">
|
||||
import { Trash2, Pencil, MoreHorizontal, Download, Loader2 } from '@lucide/svelte';
|
||||
import { Trash2, Pencil, MoreHorizontal, Download, Loader2, Square } from '@lucide/svelte';
|
||||
import { ActionDropdown } from '$lib/components/app';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { getAllLoadingChats } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { onMount } from 'svelte';
|
||||
|
|
@ -12,6 +13,7 @@
|
|||
onDelete?: (id: string) => void;
|
||||
onEdit?: (id: string) => void;
|
||||
onSelect?: (id: string) => void;
|
||||
onStop?: (id: string) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
|
|
@ -20,6 +22,7 @@
|
|||
onDelete,
|
||||
onEdit,
|
||||
onSelect,
|
||||
onStop,
|
||||
isActive = false
|
||||
}: Props = $props();
|
||||
|
||||
|
|
@ -38,8 +41,14 @@
|
|||
onDelete?.(conversation.id);
|
||||
}
|
||||
|
||||
function handleStop(event: Event) {
|
||||
event.stopPropagation();
|
||||
onStop?.(conversation.id);
|
||||
}
|
||||
|
||||
function handleGlobalEditEvent(event: Event) {
|
||||
const customEvent = event as CustomEvent<{ conversationId: string }>;
|
||||
|
||||
if (customEvent.detail.conversationId === conversation.id && isActive) {
|
||||
handleEdit(event);
|
||||
}
|
||||
|
|
@ -88,8 +97,28 @@
|
|||
>
|
||||
<div class="flex min-w-0 flex-1 items-center gap-2">
|
||||
{#if isLoading}
|
||||
<Loader2 class="h-3.5 w-3.5 shrink-0 animate-spin text-muted-foreground" />
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<div
|
||||
class="stop-button flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||
onclick={handleStop}
|
||||
onkeydown={(e) => e.key === 'Enter' && handleStop(e)}
|
||||
role="button"
|
||||
tabindex="0"
|
||||
aria-label="Stop generation"
|
||||
>
|
||||
<Loader2 class="loading-icon h-3.5 w-3.5 animate-spin" />
|
||||
|
||||
<Square class="stop-icon hidden h-3 w-3 fill-current text-destructive" />
|
||||
</div>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>Stop generation</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
|
||||
<!-- svelte-ignore a11y_click_events_have_key_events -->
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<span class="truncate text-sm font-medium" onclick={handleMobileSidebarItemClick}>
|
||||
|
|
@ -147,5 +176,25 @@
|
|||
opacity: 1 !important;
|
||||
}
|
||||
}
|
||||
|
||||
.stop-button {
|
||||
:global(.stop-icon) {
|
||||
display: none;
|
||||
}
|
||||
|
||||
:global(.loading-icon) {
|
||||
display: block;
|
||||
}
|
||||
}
|
||||
|
||||
&:is(:hover) .stop-button {
|
||||
:global(.stop-icon) {
|
||||
display: block;
|
||||
}
|
||||
|
||||
:global(.loading-icon) {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
|
|||
|
|
@ -19,8 +19,10 @@ export { default as ChatMessage } from './chat/ChatMessages/ChatMessage.svelte';
|
|||
export { default as ChatMessageActions } from './chat/ChatMessages/ChatMessageActions.svelte';
|
||||
export { default as ChatMessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
|
||||
export { default as ChatMessageStatistics } from './chat/ChatMessages/ChatMessageStatistics.svelte';
|
||||
export { default as ChatMessageSystem } from './chat/ChatMessages/ChatMessageSystem.svelte';
|
||||
export { default as ChatMessageThinkingBlock } from './chat/ChatMessages/ChatMessageThinkingBlock.svelte';
|
||||
export { default as ChatMessages } from './chat/ChatMessages/ChatMessages.svelte';
|
||||
export { default as MessageBranchingControls } from './chat/ChatMessages/ChatMessageBranchingControls.svelte';
|
||||
|
||||
export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
|
||||
export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.svelte';
|
||||
|
|
|
|||
|
|
@ -337,19 +337,23 @@
|
|||
line-height: 1.75;
|
||||
}
|
||||
|
||||
div :global(:is(h1, h2, h3, h4, h5, h6):first-child) {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
/* Headers with consistent spacing */
|
||||
div :global(h1) {
|
||||
font-size: 1.875rem;
|
||||
font-weight: 700;
|
||||
margin: 1.5rem 0 0.75rem 0;
|
||||
line-height: 1.2;
|
||||
margin: 1.5rem 0 0.75rem 0;
|
||||
}
|
||||
|
||||
div :global(h2) {
|
||||
font-size: 1.5rem;
|
||||
font-weight: 600;
|
||||
margin: 1.25rem 0 0.5rem 0;
|
||||
line-height: 1.3;
|
||||
margin: 1.25rem 0 0.5rem 0;
|
||||
}
|
||||
|
||||
div :global(h3) {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
|||
// Do not use nested objects, keep it single level. Prefix the key if you need to group them.
|
||||
apiKey: '',
|
||||
systemMessage: '',
|
||||
showSystemMessage: true,
|
||||
theme: 'system',
|
||||
showThoughtInProgress: false,
|
||||
showToolCalls: false,
|
||||
|
|
@ -42,8 +43,9 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
|||
};
|
||||
|
||||
export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
apiKey: 'Set the API Key if you are using --api-key option for the server.',
|
||||
apiKey: 'Set the API Key if you are using <code>--api-key</code> option for the server.',
|
||||
systemMessage: 'The starting message that defines how model should behave.',
|
||||
showSystemMessage: 'Display the system message at the top of each conversation.',
|
||||
theme:
|
||||
'Choose the color theme for the interface. You can choose between System (follows your device settings), Light, or Dark.',
|
||||
pasteLongTextToFileLen:
|
||||
|
|
|
|||
|
|
@ -89,7 +89,6 @@ export class ChatService {
|
|||
custom,
|
||||
timings_per_token,
|
||||
// Config options
|
||||
systemMessage,
|
||||
disableReasoningFormat
|
||||
} = options;
|
||||
|
||||
|
|
@ -103,6 +102,7 @@ export class ChatService {
|
|||
}
|
||||
})
|
||||
.filter((msg) => {
|
||||
// Filter out empty system messages
|
||||
if (msg.role === 'system') {
|
||||
const content = typeof msg.content === 'string' ? msg.content : '';
|
||||
|
||||
|
|
@ -112,10 +112,8 @@ export class ChatService {
|
|||
return true;
|
||||
});
|
||||
|
||||
const processedMessages = ChatService.injectSystemMessage(normalizedMessages, systemMessage);
|
||||
|
||||
const requestBody: ApiChatCompletionRequest = {
|
||||
messages: processedMessages.map((msg: ApiChatMessageData) => ({
|
||||
messages: normalizedMessages.map((msg: ApiChatMessageData) => ({
|
||||
role: msg.role,
|
||||
content: msg.content
|
||||
})),
|
||||
|
|
@ -681,46 +679,6 @@ export class ChatService {
|
|||
// Utilities
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Injects a system message at the beginning of the conversation if provided.
|
||||
* Checks for existing system messages to avoid duplication.
|
||||
*
|
||||
* @param messages - Array of chat messages to process
|
||||
* @param systemMessage - Optional system message to inject
|
||||
* @returns Array of messages with system message injected at the beginning if provided
|
||||
* @private
|
||||
*/
|
||||
private static injectSystemMessage(
|
||||
messages: ApiChatMessageData[],
|
||||
systemMessage?: string
|
||||
): ApiChatMessageData[] {
|
||||
const trimmedSystemMessage = systemMessage?.trim();
|
||||
|
||||
if (!trimmedSystemMessage) {
|
||||
return messages;
|
||||
}
|
||||
|
||||
if (messages.length > 0 && messages[0].role === 'system') {
|
||||
if (messages[0].content !== trimmedSystemMessage) {
|
||||
const updatedMessages = [...messages];
|
||||
updatedMessages[0] = {
|
||||
role: 'system',
|
||||
content: trimmedSystemMessage
|
||||
};
|
||||
return updatedMessages;
|
||||
}
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
const systemMsg: ApiChatMessageData = {
|
||||
role: 'system',
|
||||
content: trimmedSystemMessage
|
||||
};
|
||||
|
||||
return [systemMsg, ...messages];
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses error response and creates appropriate error with context information
|
||||
* @param response - HTTP response object
|
||||
|
|
|
|||
|
|
@ -166,6 +166,49 @@ export class DatabaseService {
|
|||
return rootMessage.id;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a system prompt message for a conversation.
|
||||
*
|
||||
* @param convId - Conversation ID
|
||||
* @param systemPrompt - The system prompt content (must be non-empty)
|
||||
* @param parentId - Parent message ID (typically the root message)
|
||||
* @returns The created system message
|
||||
* @throws Error if systemPrompt is empty
|
||||
*/
|
||||
static async createSystemMessage(
|
||||
convId: string,
|
||||
systemPrompt: string,
|
||||
parentId: string
|
||||
): Promise<DatabaseMessage> {
|
||||
const trimmedPrompt = systemPrompt.trim();
|
||||
if (!trimmedPrompt) {
|
||||
throw new Error('Cannot create system message with empty content');
|
||||
}
|
||||
|
||||
const systemMessage: DatabaseMessage = {
|
||||
id: uuid(),
|
||||
convId,
|
||||
type: 'system',
|
||||
timestamp: Date.now(),
|
||||
role: 'system',
|
||||
content: trimmedPrompt,
|
||||
parent: parentId,
|
||||
thinking: '',
|
||||
children: []
|
||||
};
|
||||
|
||||
await db.messages.add(systemMessage);
|
||||
|
||||
const parentMessage = await db.messages.get(parentId);
|
||||
if (parentMessage) {
|
||||
await db.messages.update(parentId, {
|
||||
children: [...parentMessage.children, systemMessage.id]
|
||||
});
|
||||
}
|
||||
|
||||
return systemMessage;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a conversation and all its messages.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -2,7 +2,11 @@ import { DatabaseService, ChatService } from '$lib/services';
|
|||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { contextSize, isRouterMode } from '$lib/stores/server.svelte';
|
||||
import { selectedModelName, modelsStore } from '$lib/stores/models.svelte';
|
||||
import {
|
||||
selectedModelName,
|
||||
modelsStore,
|
||||
selectedModelContextSize
|
||||
} from '$lib/stores/models.svelte';
|
||||
import {
|
||||
normalizeModelName,
|
||||
filterByLeafNodeId,
|
||||
|
|
@ -261,6 +265,13 @@ class ChatStore {
|
|||
return activeState.contextTotal;
|
||||
}
|
||||
|
||||
if (isRouterMode()) {
|
||||
const modelContextSize = selectedModelContextSize();
|
||||
if (modelContextSize && modelContextSize > 0) {
|
||||
return modelContextSize;
|
||||
}
|
||||
}
|
||||
|
||||
const propsContextSize = contextSize();
|
||||
if (propsContextSize && propsContextSize > 0) {
|
||||
return propsContextSize;
|
||||
|
|
@ -458,6 +469,14 @@ class ChatStore {
|
|||
onError?: (error: Error) => void,
|
||||
modelOverride?: string | null
|
||||
): Promise<void> {
|
||||
// Ensure model props are cached before streaming (for correct n_ctx in processing info)
|
||||
if (isRouterMode()) {
|
||||
const modelName = modelOverride || selectedModelName();
|
||||
if (modelName && !modelsStore.getModelProps(modelName)) {
|
||||
await modelsStore.fetchModelProps(modelName);
|
||||
}
|
||||
}
|
||||
|
||||
let streamedContent = '';
|
||||
let streamedReasoningContent = '';
|
||||
let streamedToolCallContent = '';
|
||||
|
|
@ -624,6 +643,22 @@ class ChatStore {
|
|||
this.clearChatStreaming(currentConv.id);
|
||||
|
||||
try {
|
||||
if (isNewConversation) {
|
||||
const rootId = await DatabaseService.createRootMessage(currentConv.id);
|
||||
const currentConfig = config();
|
||||
const systemPrompt = currentConfig.systemMessage?.toString().trim();
|
||||
|
||||
if (systemPrompt) {
|
||||
const systemMessage = await DatabaseService.createSystemMessage(
|
||||
currentConv.id,
|
||||
systemPrompt,
|
||||
rootId
|
||||
);
|
||||
|
||||
conversationsStore.addMessageToActive(systemMessage);
|
||||
}
|
||||
}
|
||||
|
||||
const userMessage = await this.addMessage('user', content, 'text', '-1', extras);
|
||||
if (!userMessage) throw new Error('Failed to add user message');
|
||||
if (isNewConversation && content)
|
||||
|
|
@ -666,13 +701,17 @@ class ChatStore {
|
|||
|
||||
if (!activeConv) return;
|
||||
|
||||
await this.savePartialResponseIfNeeded(activeConv.id);
|
||||
await this.stopGenerationForChat(activeConv.id);
|
||||
}
|
||||
|
||||
async stopGenerationForChat(convId: string): Promise<void> {
|
||||
await this.savePartialResponseIfNeeded(convId);
|
||||
|
||||
this.stopStreaming();
|
||||
this.abortRequest(activeConv.id);
|
||||
this.setChatLoading(activeConv.id, false);
|
||||
this.clearChatStreaming(activeConv.id);
|
||||
this.clearProcessingState(activeConv.id);
|
||||
this.abortRequest(convId);
|
||||
this.setChatLoading(convId, false);
|
||||
this.clearChatStreaming(convId);
|
||||
this.clearProcessingState(convId);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -999,14 +1038,20 @@ class ChatStore {
|
|||
const activeConv = conversationsStore.activeConversation;
|
||||
if (!activeConv || this.isLoading) return;
|
||||
|
||||
const result = this.getMessageByIdWithRole(messageId, 'user');
|
||||
let result = this.getMessageByIdWithRole(messageId, 'user');
|
||||
|
||||
if (!result) {
|
||||
result = this.getMessageByIdWithRole(messageId, 'system');
|
||||
}
|
||||
|
||||
if (!result) return;
|
||||
const { message: msg } = result;
|
||||
|
||||
try {
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
|
||||
const isFirstUserMessage = rootMessage && msg.parent === rootMessage.id;
|
||||
const isFirstUserMessage =
|
||||
msg.role === 'user' && rootMessage && msg.parent === rootMessage.id;
|
||||
|
||||
const parentId = msg.parent || rootMessage?.id;
|
||||
if (!parentId) return;
|
||||
|
|
@ -1037,7 +1082,10 @@ class ChatStore {
|
|||
);
|
||||
}
|
||||
await conversationsStore.refreshActiveMessages();
|
||||
await this.generateResponseForMessage(newMessage.id);
|
||||
|
||||
if (msg.role === 'user') {
|
||||
await this.generateResponseForMessage(newMessage.id);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to edit message with branching:', error);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -158,6 +158,22 @@ class ModelsStore {
|
|||
return this.modelPropsCache.get(modelId) ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get context size (n_ctx) for a specific model from cached props
|
||||
*/
|
||||
getModelContextSize(modelId: string): number | null {
|
||||
const props = this.modelPropsCache.get(modelId);
|
||||
return props?.default_generation_settings?.n_ctx ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get context size for the currently selected model or null if no model is selected
|
||||
*/
|
||||
get selectedModelContextSize(): number | null {
|
||||
if (!this.selectedModelName) return null;
|
||||
return this.getModelContextSize(this.selectedModelName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if props are being fetched for a model
|
||||
*/
|
||||
|
|
@ -579,3 +595,4 @@ export const loadedModelIds = () => modelsStore.loadedModelIds;
|
|||
export const loadingModelIds = () => modelsStore.loadingModelIds;
|
||||
export const propsCacheVersion = () => modelsStore.propsCacheVersion;
|
||||
export const singleModelName = () => modelsStore.singleModelName;
|
||||
export const selectedModelContextSize = () => modelsStore.selectedModelContextSize;
|
||||
|
|
|
|||
2
tools/server/webui/src/lib/types/chat.d.ts
vendored
2
tools/server/webui/src/lib/types/chat.d.ts
vendored
|
|
@ -1,4 +1,4 @@
|
|||
export type ChatMessageType = 'root' | 'text' | 'think';
|
||||
export type ChatMessageType = 'root' | 'text' | 'think' | 'system';
|
||||
export type ChatRole = 'user' | 'assistant' | 'system';
|
||||
|
||||
export interface ChatUploadedFile {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue