Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 01:24:36 +00:00)

Commit b2c1ff7a13: Merge branch 'upstream' into concedo_experimental

# Conflicts:
#   .ecrc
#   CMakePresets.json
#   ci/run.sh
#   docs/backend/SYCL.md
#   ggml/src/CMakeLists.txt
#   src/llama.cpp
#   tests/test-backend-ops.cpp
#   tests/test-sampling.cpp

30 changed files with 7666 additions and 6889 deletions
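From the hunks that follow, this merge pulls in several distinct upstream changes: environment-variable configuration for the server (new `get_env` helpers, `gpt_params_parse_from_env`, a README list of `LLAMA_ARG_*` variables, and a call in the server's `main`), a single-pass rewrite of the `string_replace_all`/`replace_all` helpers, an `is_lora` flag threaded through the GGUF converter scripts so rope-factor tensors are not written into LoRA files, `head_dim`-aware llama3 rope-factor computation, a slimmed-down `ggml_ssm_conv`/`ggml_ssm_scan` API, fixes to the quantize usage text and a draft-model CLI alias, and a new `logit_softcap` parameter for `ggml_flash_attn_ext` that is threaded through every CUDA FlashAttention kernel.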
```diff
@@ -78,6 +78,41 @@
 using json = nlohmann::ordered_json;
 
+//
+// Environment variable utils
+//
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::string(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stoi(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_floating_point<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stof(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, bool>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    if (value) {
+        std::string val(value);
+        target = val == "1" || val == "true";
+    }
+}
+
 //
 // CPU utils
 //
```
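As a quick sanity check of the pattern above, here is a minimal, self-contained sketch of how these SFINAE-selected overloads behave; the `main` function and its variable names are illustrative, not part of the patch:

```cpp
// Standalone sketch of the get_env helpers added above. Each overload is
// enabled for exactly one family of types; when the variable is unset the
// target keeps its default.
#include <cstdio>
#include <cstdlib>
#include <string>
#include <type_traits>

template <typename T>
static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
get_env(std::string name, T & target) {
    char * value = std::getenv(name.c_str());
    target = value ? std::string(value) : target;   // keep default when unset
}

template <typename T>
static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
get_env(std::string name, T & target) {
    char * value = std::getenv(name.c_str());
    target = value ? std::stoi(value) : target;     // integral types via stoi
}

template <typename T>
static typename std::enable_if<std::is_floating_point<T>::value, void>::type
get_env(std::string name, T & target) {
    char * value = std::getenv(name.c_str());
    target = value ? std::stof(value) : target;     // floating point via stof
}

template <typename T>
static typename std::enable_if<std::is_same<T, bool>::value, void>::type
get_env(std::string name, T & target) {
    char * value = std::getenv(name.c_str());
    if (value) {
        std::string val(value);
        target = val == "1" || val == "true";       // anything else stays false
    }
}

int main() {
    // hypothetical defaults, overridden only if the variables are set
    std::string model = "default.gguf";
    int   n_ctx = 512;
    float thold = 0.1f;
    bool  embed = false;

    get_env("LLAMA_ARG_MODEL",        model);
    get_env("LLAMA_ARG_CTX_SIZE",     n_ctx);
    get_env("LLAMA_ARG_DEFRAG_THOLD", thold);
    get_env("LLAMA_ARG_EMBEDDINGS",   embed);

    std::printf("model=%s n_ctx=%d thold=%.2f embed=%d\n",
                model.c_str(), n_ctx, thold, (int) embed);
}
```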
```diff
@@ -221,12 +256,6 @@ int32_t cpu_get_num_math() {
 // CLI argument parsing
 //
 
-void gpt_params_handle_hf_token(gpt_params & params) {
-    if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
-        params.hf_token = std::getenv("HF_TOKEN");
-    }
-}
-
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
```
```diff
@@ -274,7 +303,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 
     gpt_params_handle_model_default(params);
 
-    gpt_params_handle_hf_token(params);
+    if (params.hf_token.empty()) {
+        get_env("HF_TOKEN", params.hf_token);
+    }
 
     if (params.escape) {
         string_process_escapes(params.prompt);
```
```diff
@@ -294,6 +325,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     return true;
 }
 
+void gpt_params_parse_from_env(gpt_params & params) {
+    // we only care about server-related params for now
+    get_env("LLAMA_ARG_MODEL", params.model);
+    get_env("LLAMA_ARG_THREADS", params.n_threads);
+    get_env("LLAMA_ARG_CTX_SIZE", params.n_ctx);
+    get_env("LLAMA_ARG_N_PARALLEL", params.n_parallel);
+    get_env("LLAMA_ARG_BATCH", params.n_batch);
+    get_env("LLAMA_ARG_UBATCH", params.n_ubatch);
+    get_env("LLAMA_ARG_N_GPU_LAYERS", params.n_gpu_layers);
+    get_env("LLAMA_ARG_THREADS_HTTP", params.n_threads_http);
+    get_env("LLAMA_ARG_CHAT_TEMPLATE", params.chat_template);
+    get_env("LLAMA_ARG_N_PREDICT", params.n_predict);
+    get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics);
+    get_env("LLAMA_ARG_ENDPOINT_SLOTS", params.endpoint_slots);
+    get_env("LLAMA_ARG_EMBEDDINGS", params.embedding);
+    get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn);
+    get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold);
+}
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     const auto params_org = params; // the example can modify the default params
 
```
```diff
@@ -852,7 +902,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         }
         return true;
     }
-    if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--gpu-layers-draft") {
+    if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--n-gpu-layers-draft") {
        CHECK_ARG
        params.n_gpu_layers_draft = std::stoi(argv[i]);
        if (!llama_supports_gpu_offload()) {
```
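Worth noting in the hunk above: the old condition listed `--gpu-layers-draft` twice, so the intended long alias was unreachable; the second occurrence is corrected to `--n-gpu-layers-draft`.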
```diff
@@ -1812,13 +1862,19 @@ std::string string_get_sortable_timestamp() {
 
 void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
     if (search.empty()) {
-        return; // Avoid infinite loop if 'search' is an empty string
+        return;
     }
+    std::string builder;
+    builder.reserve(s.length());
     size_t pos = 0;
-    while ((pos = s.find(search, pos)) != std::string::npos) {
-        s.replace(pos, search.length(), replace);
-        pos += replace.length();
+    size_t last_pos = 0;
+    while ((pos = s.find(search, last_pos)) != std::string::npos) {
+        builder.append(s, last_pos, pos - last_pos);
+        builder.append(replace);
+        last_pos = pos + search.length();
     }
+    builder.append(s, last_pos, std::string::npos);
+    s = std::move(builder);
 }
 
 void string_process_escapes(std::string & input) {
```
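The rewrite replaces repeated in-place `s.replace()` calls, which shift the tail of the string on every match, with a single left-to-right pass that appends into a pre-reserved buffer. A standalone sketch of the same idea, with a small illustrative `main`:

```cpp
// Builder-based replace, as in the hunk above: copy unchanged spans into a
// separate buffer so the function is one pass over the input instead of
// repeatedly moving the string's tail after each substitution.
#include <cstdio>
#include <string>
#include <utility>

static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) {
        return;                                      // nothing sensible to match
    }
    std::string builder;
    builder.reserve(s.length());
    size_t pos      = 0;
    size_t last_pos = 0;
    while ((pos = s.find(search, last_pos)) != std::string::npos) {
        builder.append(s, last_pos, pos - last_pos); // unchanged prefix
        builder.append(replace);                     // the substitution
        last_pos = pos + search.length();            // continue past the match
    }
    builder.append(s, last_pos, std::string::npos);  // unchanged tail
    s = std::move(builder);
}

int main() {
    std::string s = "a<b<c";
    replace_all(s, "<", "&lt;");
    std::printf("%s\n", s.c_str()); // prints: a&lt;b&lt;c
}
```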
```diff
@@ -291,7 +291,7 @@ struct gpt_params {
     std::string lora_outfile = "ggml-lora-merged-f16.gguf";
 };
 
-void gpt_params_handle_hf_token(gpt_params & params);
+void gpt_params_parse_from_env(gpt_params & params);
 void gpt_params_handle_model_default(gpt_params & params);
 
 bool gpt_params_parse_ex   (int argc, char ** argv, gpt_params & params);
```
common/stb_image.h (2990 changed lines): file diff suppressed because it is too large.
```diff
@@ -63,6 +63,7 @@ class Model:
     model_name: str | None
     metadata_override: Path | None
     dir_model_card: Path
+    is_lora: bool
 
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
```
```diff
@@ -70,7 +71,7 @@ class Model:
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
-                 split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
+                 split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, is_lora: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
 
```
```diff
@@ -92,6 +93,7 @@ class Model:
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
+        self.is_lora = is_lora  # true if model is used inside convert_lora_to_gguf.py
 
         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
```
```diff
@@ -1570,7 +1572,7 @@ class LlamaModel(Model):
         if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
             if rope_scaling.get("rope_type", '').lower() == "llama3":
                 base = self.hparams.get("rope_theta", 10000.0)
-                dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+                dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
                 freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
 
                 factor = rope_scaling.get("factor", 8.0)
```
```diff
@@ -1593,6 +1595,7 @@ class LlamaModel(Model):
                     smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
                     rope_factors.append(1 / ((1 - smooth) / factor + smooth))
 
-                self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
+                if not self.is_lora:
+                    self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
 
         super().prepare_tensors()
```
```diff
@@ -2140,6 +2143,7 @@ class Phi3MiniModel(Model):
         if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
             raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
 
-        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG]  + ".weight", np.array(long_factors, dtype=np.float32))
-        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
+        if not self.is_lora:
+            self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG]  + ".weight", np.array(long_factors, dtype=np.float32))
+            self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
 
```
```diff
@@ -3816,7 +3820,7 @@ class ExaoneModel(Model):
         if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
             if rope_scaling.get("rope_type", '').lower() == "llama3":
                 base = self.hparams.get("rope_theta", 10000.0)
-                dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+                dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
                 freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
 
                 factor = rope_scaling.get("factor", 8.0)
```
```diff
@@ -3839,6 +3843,7 @@ class ExaoneModel(Model):
                     smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
                     rope_factors.append(1 / ((1 - smooth) / factor + smooth))
 
-                self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
+                if not self.is_lora:
+                    self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
 
         super().prepare_tensors()
```
```diff
@@ -386,6 +386,7 @@ if __name__ == '__main__':
             dry_run=args.dry_run,
             dir_lora_model=dir_lora,
             lora_alpha=alpha,
+            is_lora=True,
         )
 
         logger.info("Exporting model...")
```
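With `is_lora=True` passed here, the `prepare_tensors` overrides shown earlier skip their `gguf_writer.add_tensor(...)` calls for the rope-frequency and rope-factor tensors, so those base-model tensors are no longer duplicated into converted LoRA adapters.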
```diff
@@ -219,13 +219,19 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int
 
 static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
     if (search.empty()) {
-        return; // Avoid infinite loop if 'search' is an empty string
+        return;
     }
+    std::string builder;
+    builder.reserve(s.length());
     size_t pos = 0;
-    while ((pos = s.find(search, pos)) != std::string::npos) {
-        s.replace(pos, search.length(), replace);
-        pos += replace.length();
+    size_t last_pos = 0;
+    while ((pos = s.find(search, last_pos)) != std::string::npos) {
+        builder.append(s, last_pos, pos - last_pos);
+        builder.append(replace);
+        last_pos = pos + search.length();
     }
+    builder.append(s, last_pos, std::string::npos);
+    s = std::move(builder);
 }
 
 static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
```
```diff
@@ -105,7 +105,7 @@ static void usage(const char * executable) {
     printf("  --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
     printf("  --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
     printf("  --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
-    printf("  --keep-split: will generate quatized model in the same shards as input");
+    printf("  --keep-split: will generate quantized model in the same shards as input\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
     printf("Note: --include-weights and --exclude-weights cannot be used together\n");
```
````diff
@@ -247,6 +247,25 @@ logging:
   --log-append              Don't truncate the old log file.
 ```
 
+Available environment variables (if specified, these variables will override parameters specified in arguments):
+
+- `LLAMA_CACHE` (cache directory, used by `--hf-repo`)
+- `HF_TOKEN` (Hugging Face access token, used when accessing a gated model with `--hf-repo`)
+- `LLAMA_ARG_MODEL`
+- `LLAMA_ARG_THREADS`
+- `LLAMA_ARG_CTX_SIZE`
+- `LLAMA_ARG_N_PARALLEL`
+- `LLAMA_ARG_BATCH`
+- `LLAMA_ARG_UBATCH`
+- `LLAMA_ARG_N_GPU_LAYERS`
+- `LLAMA_ARG_THREADS_HTTP`
+- `LLAMA_ARG_CHAT_TEMPLATE`
+- `LLAMA_ARG_N_PREDICT`
+- `LLAMA_ARG_ENDPOINT_METRICS`
+- `LLAMA_ARG_ENDPOINT_SLOTS`
+- `LLAMA_ARG_EMBEDDINGS`
+- `LLAMA_ARG_FLASH_ATTN`
+- `LLAMA_ARG_DEFRAG_THOLD`
+
 ## Build
 
````
File diff suppressed because one or more lines are too long
```diff
@@ -2508,6 +2508,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // parse arguments from environment variables
+    gpt_params_parse_from_env(params);
+
     // TODO: not great to use extern vars
     server_log_json = params.log_json;
     server_verbose = params.verbosity > 0;
```
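Since `gpt_params_parse_from_env` runs after the command-line parse here, a set `LLAMA_ARG_*` variable wins over the corresponding flag, matching the "will override parameters specified in arguments" wording added to the README above.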
```diff
@@ -1766,7 +1766,8 @@ extern "C" {
             struct ggml_tensor * v,
             struct ggml_tensor * mask,
             float                scale,
-            float                max_bias);
+            float                max_bias,
+            float                logit_softcap);
 
     GGML_API void ggml_flash_attn_ext_set_prec(
             struct ggml_tensor * a,
```
```diff
@@ -1783,10 +1784,8 @@ extern "C" {
 
     GGML_API struct ggml_tensor * ggml_ssm_conv(
             struct ggml_context * ctx,
-            struct ggml_tensor  * s,
-            struct ggml_tensor  * x,
-            struct ggml_tensor  * c,
-            struct ggml_tensor  * sq);
+            struct ggml_tensor  * sx,
+            struct ggml_tensor  * c);
 
     GGML_API struct ggml_tensor * ggml_ssm_scan(
             struct ggml_context * ctx,
```
```diff
@@ -1795,8 +1794,7 @@ extern "C" {
             struct ggml_tensor  * dt,
             struct ggml_tensor  * A,
             struct ggml_tensor  * B,
-            struct ggml_tensor  * C,
-            struct ggml_tensor  * sq);
+            struct ggml_tensor  * C);
 
     // partition into non-overlapping windows with padding if needed
     // example:
```
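Both state-space ops lose their trailing `sq` tensor argument, and `ggml_ssm_conv` now takes a single combined `sx` input instead of separate `s` and `x`, so callers build a simpler graph; the per-sequence bookkeeping that `sq` carried presumably moved inside the ops' implementations.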
@ -337,34 +337,19 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||||
if (!quant_weights) {
|
UNUSED(quant_weights);
|
||||||
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
|
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
assert(false);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||||
if (!quant_weights) {
|
UNUSED(quant_weights);
|
||||||
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
|
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
assert(false);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||||
if (!quant_weights) {
|
UNUSED(quant_weights);
|
||||||
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
|
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
assert(false);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
||||||
const int qk = QK8_0;
|
const int qk = QK8_0;
|
||||||
|
|
|
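Previously these block-interleaved Q4_0 quantizers hit `assert(false)` whenever an importance matrix was supplied; after the change `quant_weights` is explicitly marked `UNUSED` and the data is quantized without it, so requesting the 4x4/4x8/8x8 layouts with an imatrix degrades gracefully instead of aborting.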
```diff
@@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
```
```diff
@@ -659,9 +660,15 @@ void launch_fattn(
 
     float scale    = 1.0f;
     float max_bias = 0.0f;
+    float logit_softcap = 0.0f;
 
     memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0.0f) {
+        scale /= logit_softcap;
+    }
 
     const uint32_t n_head      = Q->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
```
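The `scale /= logit_softcap` line is the key algebraic trick: softcapping computes `c * tanh(scale * qk / c)` for cap `c`, and folding the `1/c` into the scale means the kernels only have to apply `c * tanhf(x)` to the already-scaled dot product. A host-side sketch of the arithmetic, with illustrative names:

```cpp
// Sketch of the softcap algebra threaded through launch_fattn and the kernels.
#include <cmath>
#include <cstdio>

static float softcapped_logit(float qk, float scale, float logit_softcap) {
    if (logit_softcap == 0.0f) {
        return scale * qk;                   // softcap disabled: plain scaling
    }
    scale /= logit_softcap;                  // done once on the host, per call
    float x = scale * qk;                    // what the dot product produces
    return logit_softcap * std::tanh(x);     // what the kernel applies per logit
}

int main() {
    // with a hypothetical cap of 30, huge logits saturate smoothly near +/-30
    std::printf("%f\n", softcapped_logit(1e6f, 0.125f, 30.0f));
}
```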
```diff
@@ -675,7 +682,7 @@ void launch_fattn(
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale, max_bias, m0, m1, n_head_log2,
+        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
```
```diff
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
```
```diff
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
```
```diff
@@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
```
```diff
@@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
         for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
             const int j_KQ = j_KQ_0 + threadIdx.y;
 
-            half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+            half sum;
+            if (use_logit_softcap) {
+                const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                sum = logit_softcap * tanhf(tmp.x + tmp.y);
+            } else {
+                sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+            }
             sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
             kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
```
```diff
@@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
```
```diff
@@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
 
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
 }
```
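All of these dispatchers use the same pattern: a runtime float read from `op_params` is collapsed onto a compile-time `bool`, so the no-softcap kernel variants contain no `tanhf` at all. (Note also that `precision` moves from `op_params[2]` to `op_params[3]` because `logit_softcap` now occupies slot 2.) A minimal sketch of the runtime-to-`constexpr` dispatch, with hypothetical names:

```cpp
// Runtime flag -> compile-time template parameter, as in the dispatchers above.
#include <cstdio>

template <bool use_logit_softcap>
static void kernel() {
    if (use_logit_softcap) {    // resolved at compile time; dead branch is removed
        std::printf("softcap variant\n");
    } else {
        std::printf("plain variant\n");
    }
}

static void dispatch(float logit_softcap) {
    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        kernel<use_logit_softcap>();
    } else {
        constexpr bool use_logit_softcap = true;
        kernel<use_logit_softcap>();
    }
}

int main() {
    dispatch(0.0f);   // plain variant
    dispatch(30.0f);  // softcap variant
}
```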
```diff
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
```
```diff
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
```
```diff
@@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
```
```diff
@@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
         for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
             const int j_KQ = j_KQ_0 + threadIdx.y;
 
+            if (use_logit_softcap) {
+                sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+            }
+
             sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
 
             kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
```
```diff
@@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
     }
 }
 
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
 void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
```
```diff
@@ -290,23 +301,45 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
     }
 }
 
 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
 }
```
```diff
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
```
|
@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
const float m0,
|
const float m0,
|
||||||
const float m1,
|
const float m1,
|
||||||
const uint32_t n_head_log2,
|
const uint32_t n_head_log2,
|
||||||
|
const float logit_softcap,
|
||||||
const int ne00,
|
const int ne00,
|
||||||
const int ne01,
|
const int ne01,
|
||||||
const int ne02,
|
const int ne02,
|
||||||
|
```diff
@@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
```
```diff
@@ -190,6 +197,11 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
             sum = warp_reduce_sum(sum);
+
+            if (use_logit_softcap) {
+                sum = logit_softcap*tanhf(sum);
+            }
+
             sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
             if (ncols == 1) {
```
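In the vec kernels the cap is applied to the warp-reduced dot product, before the ALiBi slope/mask term is added, which is the same ordering the tile kernels use above.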
```diff
@@ -286,10 +298,10 @@ static __global__ void flash_attn_vec_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
```
```diff
@@ -297,48 +309,81 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
 
 template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * KQV = dst;
-    ggml_tensor * Q = dst->src[0];
-    ggml_tensor * K = dst->src[1];
-    ggml_tensor * V = dst->src[2];
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
 
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] == 1) {
         constexpr int cols_per_block = 1;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] == 2) {
         constexpr int cols_per_block = 2;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 4) {
         constexpr int cols_per_block = 4;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 8) {
         constexpr int cols_per_block = 8;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block = 8;
     constexpr int parallel_blocks = 1;
-    ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
 }
 
 #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
```
```diff
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
```
|
@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||||
const float m0,
|
const float m0,
|
||||||
const float m1,
|
const float m1,
|
||||||
const uint32_t n_head_log2,
|
const uint32_t n_head_log2,
|
||||||
|
const float logit_softcap,
|
||||||
const int ne00,
|
const int ne00,
|
||||||
const int ne01,
|
const int ne01,
|
||||||
const int ne02,
|
const int ne02,
|
||||||
|
```diff
@@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
```
|
@ -180,6 +187,11 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
|
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
|
||||||
sum = warp_reduce_sum(sum);
|
sum = warp_reduce_sum(sum);
|
||||||
|
|
||||||
|
if (use_logit_softcap) {
|
||||||
|
sum = logit_softcap*tanhf(sum);
|
||||||
|
}
|
||||||
|
|
||||||
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||||
|
|
||||||
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
|
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
|
||||||
|
```diff
@@ -267,10 +279,10 @@ static __global__ void flash_attn_vec_ext_f32(
     }
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
```
@ -278,44 +290,78 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
|
||||||
|
|
||||||
template <int D, ggml_type type_K, ggml_type type_V>
|
template <int D, ggml_type type_K, ggml_type type_V>
|
||||||
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * KQV = dst;
|
||||||
ggml_tensor * K = dst->src[1];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
ggml_tensor * V = dst->src[2];
|
const ggml_tensor * K = dst->src[1];
|
||||||
|
     const ggml_tensor * V = dst->src[2];

     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);

+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] == 1) {
         constexpr int cols_per_block  = 1;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     if (Q->ne[1] == 2) {
         constexpr int cols_per_block  = 2;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     if (Q->ne[1] <= 4) {
         constexpr int cols_per_block  = 4;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     if (Q->ne[1] <= 8) {
         constexpr int cols_per_block  = 8;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;
-    ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
 }

 #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
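Each branch above repeats the same idiom: a runtime value (logit_softcap, read from op_params) is turned into a compile-time bool template argument, so every kernel instantiation can dead-code-eliminate the path it does not use. A minimal standalone sketch of the idiom (names are hypothetical, not part of the CUDA code above):

    #include <cstdio>

    // Stand-in for the kernel: the bool is a template parameter, so the
    // "if (use_softcap)" below is resolved at compile time per instantiation.
    template <bool use_softcap>
    static void run_kernel(float s) {
        if (use_softcap) {
            std::printf("softcapped path, s = %f\n", s);
        } else {
            std::printf("plain path, s = %f\n", s);
        }
    }

    // Runtime value -> compile-time template argument, as in the dispatch above.
    static void dispatch(float logit_softcap, float s) {
        if (logit_softcap == 0.0f) {
            run_kernel<false>(s);
        } else {
            run_kernel<true>(s);
        }
    }

    int main() {
        dispatch(0.0f, 1.5f);
        dispatch(30.0f, 1.5f);
    }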
@ -6,7 +6,7 @@
 #endif // FP16_MMA_AVAILABLE

 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@ -22,6 +22,7 @@ static __global__ void flash_attn_ext_f16(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_MMA_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
@ -85,6 +92,8 @@ static __global__ void flash_attn_ext_f16(
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);

+    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
+
     frag_b Q_b[D/16][ncols/frag_n];

     // A single buffer for temporarily holding tiles of KQ and VKQ parts:
@ -194,6 +203,10 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + threadIdx.x;

                 KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+
+                if (use_logit_softcap) {
+                    KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                }
             }

             float KQ_max_new = KQ_max_f[j0/nwarps];
@ -237,6 +250,15 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + threadIdx.x;

                 KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+
+                if (use_logit_softcap) {
+                    // There is no dedicated tangens hyperbolicus function for half2.
+                    KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+                    KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+                                           /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+                    KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                }
             }

             half2 KQ_max_new = KQ_max_h2[j0/nwarps];
@ -427,6 +449,7 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");

 template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];

     constexpr int nwarps = 4;
@ -435,20 +458,50 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (4*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
         launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
         launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
-    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    }
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
 }
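The half2 branch in the hunk above builds tanh out of h2exp because CUDA provides no h2tanh: it uses the identity tanh(x) = (exp(2x) - 1)/(exp(2x) + 1). A small host-side C++ sketch checking that identity against std::tanh (illustration only):

    #include <cmath>
    #include <cstdio>

    // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1), the identity used for half2 above.
    static float tanh_via_exp(float x) {
        const float e2x = std::exp(2.0f*x);
        return (e2x - 1.0f)/(e2x + 1.0f);
    }

    int main() {
        for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f}) {
            std::printf("x=%5.2f  identity=%.6f  std::tanh=%.6f\n",
                        x, tanh_via_exp(x), std::tanh(x));
        }
    }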
@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];

-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];

     if (precision != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];

     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
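Both index bumps follow from the widened op_params layout: float slots 0..2 now hold scale, max_bias and logit_softcap, so the precision flag written by ggml_flash_attn_ext_set_prec moves from slot 2 to slot 3. A rough standalone sketch of that packing, mirroring how ggml stores floats in its int32 op_params array (slot meanings taken from the hunks above, everything else hypothetical):

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    int main() {
        // op_params is a raw int32 array in ggml; floats are stored via memcpy.
        int32_t op_params[4] = {0};

        const float   scale         = 0.125f;
        const float   max_bias      = 0.0f;
        const float   logit_softcap = 30.0f;
        const int32_t prec          = 0; // e.g. GGML_PREC_DEFAULT

        std::memcpy(&op_params[0], &scale,         sizeof(scale));
        std::memcpy(&op_params[1], &max_bias,      sizeof(max_bias));
        std::memcpy(&op_params[2], &logit_softcap, sizeof(logit_softcap));
        op_params[3] = prec; // precision now lives in slot 3

        float out;
        std::memcpy(&out, &op_params[2], sizeof(out));
        std::printf("logit_softcap = %f, precision = %d\n", out, op_params[3]);
    }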
@ -82,6 +82,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
+    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
+    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
@ -542,6 +544,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
@ -803,6 +807,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
                 return false;
             }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+        case GGML_OP_SSM_CONV:
+        case GGML_OP_SSM_SCAN:
+            return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&
@ -1538,6 +1545,121 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     }
                 } break;
+            case GGML_OP_SSM_CONV:
+                {
+                    GGML_ASSERT(src0t == GGML_TYPE_F32);
+                    GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+                    GGML_ASSERT(ggml_is_contiguous(src1));
+
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
+                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
+                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
+                    [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:15];
+                    [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:16];
+                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:17];
+                    [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:18];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_SSM_SCAN:
+                {
+                    struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+                    struct ggml_tensor * src4 = gf->nodes[i]->src[4];
+                    struct ggml_tensor * src5 = gf->nodes[i]->src[5];
+
+                    GGML_ASSERT(src3);
+                    GGML_ASSERT(src4);
+                    GGML_ASSERT(src5);
+
+                    size_t offs_src3 = 0;
+                    size_t offs_src4 = 0;
+                    size_t offs_src5 = 0;
+
+                    id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+                    id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
+                    id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+
+                    const int64_t  ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+                    const int64_t  ne31 = src3->ne[1]; GGML_UNUSED(ne31);
+
+                    const uint64_t nb30 = src3->nb[0];
+                    const uint64_t nb31 = src3->nb[1];
+
+                    const int64_t  ne40 = src4->ne[0]; GGML_UNUSED(ne40);
+                    const int64_t  ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+                    const int64_t  ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+
+                    const uint64_t nb40 = src4->nb[0];
+                    const uint64_t nb41 = src4->nb[1];
+                    const uint64_t nb42 = src4->nb[2];
+
+                    const int64_t  ne50 = src5->ne[0]; GGML_UNUSED(ne50);
+                    const int64_t  ne51 = src5->ne[1]; GGML_UNUSED(ne51);
+                    const int64_t  ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+
+                    const uint64_t nb50 = src5->nb[0];
+                    const uint64_t nb51 = src5->nb[1];
+                    const uint64_t nb52 = src5->nb[2];
+
+                    const int64_t d_state      = ne00;
+                    const int64_t d_inner      = ne01;
+                    const int64_t n_seq_tokens = ne11;
+                    const int64_t n_seqs       = ne02;
+
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                    [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                    [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
+
+                    [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7];
+                    [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8];
+                    [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
+                    [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10];
+
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                    [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+                    [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
+                    [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
+                    [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
+                    [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
+                    [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
+                    [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
+                    [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
+                    [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
+                    [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
+                    [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             case GGML_OP_MUL_MAT:
                 {
                     GGML_ASSERT(ne00 == ne10);
@ -2624,9 +2746,14 @@ static enum ggml_status ggml_metal_graph_compute(

                 float scale;
                 float max_bias;
+                float logit_softcap;
                 memcpy(&scale,         ((int32_t *) dst->op_params) + 0, sizeof(scale));
                 memcpy(&max_bias,      ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+                memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
+
+                if (logit_softcap != 0.0f) {
+                    scale /= logit_softcap;
+                }

                 const uint32_t n_head      = src0->ne[2];
                 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@ -2701,6 +2828,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25];
                 [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26];
                 [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27];
+                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];

                 if (!use_vec_kernel) {
                     // half8x8 kernel
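Dividing scale by logit_softcap before dispatch is what lets the kernels apply a single tanh later: with the pre-divided scale' = scale/c the capped logit comes out as

    c * tanh(scale' * s) = c * tanh(scale * s / c),   where c = logit_softcap,

which is the softcapping the op defines; when c == 0 the division is skipped and the kernels fall back to plain s * scale.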
@ -667,6 +667,127 @@ kernel void kernel_diag_mask_inf_8(
     }
 }

+// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
+// TODO: optimize
+kernel void kernel_ssm_conv_f32(
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i2 = tgpig.y;
+    const int64_t i3 = tgpig.z;
+
+    const int64_t nc  = ne10;
+    const int64_t ncs = ne00;
+    const int64_t nr  = ne01;
+    const int64_t n_t = ne1;
+    const int64_t n_s = ne2;
+
+    device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
+    device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
+    device       float * x = (device       float *) ((device       char *) dst  + ir*nb0  + i2*nb1  + i3*nb2);
+
+    float sumf = 0.0f;
+
+    for (int64_t i0 = 0; i0 < nc; ++i0) {
+        sumf += s[i0] * c[i0];
+    }
+
+    x[0] = sumf;
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// TODO: optimize
+kernel void kernel_ssm_scan_f32(
+        device const  void * src0,
+        device const  void * src1,
+        device const  void * src2,
+        device const  void * src3,
+        device const  void * src4,
+        device const  void * src5,
+        device       float * dst,
+        constant   int64_t & d_state,
+        constant   int64_t & d_inner,
+        constant   int64_t & n_seq_tokens,
+        constant   int64_t & n_seqs,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant  uint64_t & nb20,
+        constant  uint64_t & nb21,
+        constant  uint64_t & nb22,
+        constant  uint64_t & nb30,
+        constant  uint64_t & nb31,
+        constant  uint64_t & nb40,
+        constant  uint64_t & nb41,
+        constant  uint64_t & nb42,
+        constant  uint64_t & nb50,
+        constant  uint64_t & nb51,
+        constant  uint64_t & nb52,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i3 = tgpig.y;
+
+    const int64_t nc  = d_state;
+    const int64_t nr  = d_inner;
+    const int64_t n_t = n_seq_tokens;
+    const int64_t n_s = n_seqs;
+
+    for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
+        device const float * x  = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31);
+        device const float * B  = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
+        device const float * C  = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
+        device       float * y  = (device       float *) ((device       char *) dst  + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
+        device       float * s  = (device       float *) ((device       char *) dst  + ir*nb01 + i3*nb02 + nb13);
+
+        if (i2 > 0) {
+            s0 = s;
+        }
+
+        // i1 == 0
+        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        float x_dt = x[0] * dt_soft_plus;
+        float sumf = 0.0f;
+
+        for (int64_t i0 = 0; i0 < nc; ++i0) {
+            int64_t i = i0;
+            float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            sumf += state * C[i0];
+            s[i] = state;
+        }
+
+        y[0] = sumf;
+    }
+}
+
 kernel void kernel_norm(
         device const void * src0,
         device       float * dst,
@ -1976,6 +2097,7 @@ typedef void (flash_attn_ext_f16_t)(
         constant     float & m0,
         constant     float & m1,
         constant  uint32_t & n_head_log2,
+        constant     float & logit_softcap,
         threadgroup   half * shared,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@ -2014,6 +2136,7 @@ kernel void kernel_flash_attn_ext_f16(
         constant     float & m0,
         constant     float & m1,
         constant  uint32_t & n_head_log2,
+        constant     float & logit_softcap,
         threadgroup   half * shared [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@ -2138,19 +2261,6 @@ kernel void kernel_flash_attn_ext_f16(
             }

             simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
-
-            const short tx = tiisg%4;
-            const short ty = tiisg/4;
-
-            if (mask != q) {
-                // mqk = mqk*scale + mask*slope
-                ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
-                ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
-            } else {
-                // mqk = mqk*scale
-                ss[8*cc + ty*TF + 2*tx + 0] *= scale;
-                ss[8*cc + ty*TF + 2*tx + 1] *= scale;
-            }
         }
     }

@ -2162,10 +2272,19 @@ kernel void kernel_flash_attn_ext_f16(
         float ms[Q];

         for (short j = 0; j < Q; ++j) {
-            const short p = tiisg;
-
             const float m = M[j];
-            const float s = ss[j*TF + p];
+
+            // scale and apply the logitcap / mask
+            float s = ss[j*TF + tiisg]*scale;
+
+            if (logit_softcap != 0.0f) {
+                s = logit_softcap*precise::tanh(s);
+            }
+
+            if (mask != q) {
+                // mqk = mqk + mask*slope
+                s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
+            }

             smax = simd_max(max(smax, s));
             M[j] = simd_max(max(M[j], s));
@ -2176,7 +2295,7 @@ kernel void kernel_flash_attn_ext_f16(
             S[j] = S[j]*ms[j] + simd_sum(vs);

             // the P matrix from the paper (Q rows, C columns)
-            ss[j*TF + p] = vs;
+            ss[j*TF + tiisg] = vs;
         }

         // create a QxQ diagonal matrix for rescaling the output
@ -2345,6 +2464,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
         constant     float & m0,
         constant     float & m1,
         constant  uint32_t & n_head_log2,
+        constant     float & logit_softcap,
         threadgroup   half * shared [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@ -2479,7 +2599,13 @@ kernel void kernel_flash_attn_ext_vec_f16(

                 // mqk = mqk*scale + mask*slope
                 if (tiisg == 0) {
-                    mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
+                    mqk *= scale;
+
+                    if (logit_softcap != 0.0f) {
+                        mqk = logit_softcap*precise::tanh(mqk);
+                    }
+
+                    mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;

                     ss4[cc] = mqk;
                 }
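kernel_ssm_scan_f32 above walks the tokens of one (row, sequence) pair in order, carrying the recurrent state in dst: h = h*exp(softplus(dt)*A) + B*x*softplus(dt), then y = sum(h*C). A toy host-side C++ sketch of that step for a single channel (the sizes and the B = C = 1 weights are made up purely for illustration):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int d_state = 4;                 // hypothetical state size
        std::vector<float> h(d_state, 0.0f);   // recurrent state, carried across tokens
        std::vector<float> A(d_state, -1.0f);  // per-state decay (normally learned)

        const float x[3]  = {0.5f, -1.0f, 2.0f}; // a toy 3-token sequence
        const float dt[3] = {0.1f,  0.2f, 0.3f};

        for (int t = 0; t < 3; ++t) {
            // softplus keeps the step size positive; large inputs pass through as-is
            const float dt_sp = dt[t] <= 20.0f ? std::log(1.0f + std::exp(dt[t])) : dt[t];
            const float x_dt  = x[t]*dt_sp;

            float y = 0.0f;
            for (int i = 0; i < d_state; ++i) {
                // B and C are taken as 1.0f here for illustration only
                h[i] = h[i]*std::exp(dt_sp*A[i]) + 1.0f*x_dt;
                y   += h[i]*1.0f;
            }
            std::printf("y[%d] = %f\n", t, y);
        }
    }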
@ -38,6 +38,7 @@

 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/presets.hpp"
+#include "ggml-sycl/gemm.hpp"

 bool   ggml_sycl_loaded(void);
 void   ggml_sycl_free_data(struct ggml_tensor * tensor);
@ -2482,6 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl(

         const sycl::half alpha_f16 = 1.0f;
         const sycl::half beta_f16  = 0.0f;
+#if !GGML_SYCL_DNNL
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
             *stream, oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
@ -2491,6 +2493,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
             dpct::library_data_t::real_half)));
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+#else
+        auto dnnl_stream = ctx.stream_dnnl(stream);
+        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#endif
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@ -2513,13 +2522,18 @@ inline void ggml_sycl_op_mul_mat_sycl(

         const float alpha = 1.0f;
         const float beta  = 0.0f;
+#if !GGML_SYCL_DNNL
         SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
             *stream, oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
             dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
             src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
             dst_dd_i, ldc)));
+#else
+        auto dnnl_stream = ctx.stream_dnnl(stream);
+        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
+            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+#endif
     }
     (void) dst;
     (void) src1_ddq_i;
@ -19,6 +19,10 @@
 #include "dpct/helper.hpp"
 #include "ggml-sycl.h"
 #include "presets.hpp"
+#if GGML_SYCL_DNNL
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+#endif

 #define GGML_COMMON_DECL_SYCL
 #define GGML_COMMON_IMPL_SYCL
@ -277,6 +281,52 @@ struct ggml_backend_sycl_context {
         return stream(device, 0);
     }

+#if GGML_SYCL_DNNL
+    dnnl::engine make_engine(sycl::queue* q) {
+        // Get the device associated with the queue
+        sycl::device dev = q->get_device();
+        // Get the context associated with the queue
+        sycl::context ctx = q->get_context();
+        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+        return eng;
+    }
+
+    std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
+    std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
+    dnnl::stream stream_dnnl(int device, int _stream) {
+        auto q = stream(device, _stream);
+        return stream_dnnl(q);
+    }
+    dnnl::engine engine_dnnl(sycl::queue* qptr) {
+        auto it = engine_map.find(qptr);
+        if (it == engine_map.end()) {
+            auto eng = make_engine(qptr);
+            engine_map[qptr] = eng;
+            return eng;
+        }
+        else
+        {
+            return it->second;
+        }
+    }
+    dnnl::stream stream_dnnl(sycl::queue* qptr) {
+        auto it = stream_map.find(qptr);
+        if (it == stream_map.end()) {
+            auto eng = engine_dnnl(qptr);
+            auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
+            stream_map[qptr] = stream;
+            return stream;
+        }
+        else
+        {
+            return it->second;
+        }
+    }
+    dnnl::stream stream_dnnl() {
+        return stream_dnnl(device, 0);
+    }
+#endif
+
     // pool
     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
101  ggml/src/ggml-sycl/gemm.hpp  Normal file
@ -0,0 +1,101 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_GEMM_HPP
+#define GGML_SYCL_GEMM_HPP
+
+#include <fstream>
+#include <iostream>
+
+#include "ggml-sycl.h"
+
+#if GGML_SYCL_DNNL
+
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+
+class DnnlGemmWrapper {
+public:
+    using dt  = dnnl::memory::data_type;
+    using tag = dnnl::memory::format_tag;
+
+    template<typename T>
+    static constexpr dt to_dt() {
+        if constexpr (std::is_same_v<T, float>) return dt::f32;
+        else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
+        else static_assert(0);
+    }
+
+    static inline void row_gemm(sycl::queue& q, bool a_trans,
+        bool b_trans, int m, int n, int k,
+        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
+    {
+        // Get the device associated with the queue
+        sycl::device dev = q.get_device();
+        // Get the context associated with the queue
+        sycl::context ctx = q.get_context();
+        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+        const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
+        dnnl::memory::dims a_dims = { m, k };
+        dnnl::memory::dims b_dims = { k, n };
+        dnnl::memory::dims c_dims = { m, n };
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
+        const auto c_md    = dnnl::memory::desc(c_dims, ct, tag::ab);
+        auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
+        auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+        // Create the primitive.
+        auto matmul_prim = dnnl::matmul(matmul_pd);
+        // Primitive arguments.
+        std::unordered_map<int, dnnl::memory> matmul_args;
+        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+
+        matmul_prim.execute(stream, matmul_args);
+    }
+
+
+    static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
+        bool b_trans, int m, int n, int k,
+        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
+    {
+        auto const eng = stream.get_engine();
+        dnnl::memory::dims a_dims = { m, k };
+        dnnl::memory::dims b_dims = { k, n };
+        dnnl::memory::dims c_dims = { m, n };
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
+        const auto c_md    = dnnl::memory::desc(c_dims, ct, tag::ab);
+        auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
+        auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+        // Create the primitive.
+        auto matmul_prim = dnnl::matmul(matmul_pd);
+        // Primitive arguments.
+        std::unordered_map<int, dnnl::memory> matmul_args;
+        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+
+        matmul_prim.execute(stream, matmul_args);
+    }
+};
+
+#endif
+
+#endif // GGML_SYCL_GEMM_HPP
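The stream-taking row_gemm overload is the one the SYCL mul_mat path calls through ctx.stream_dnnl(stream), so the oneDNN engine and stream are built once per queue and reused instead of being recreated per call. A hedged usage fragment, not a standalone program; the pointers and sizes are the ones visible at the fp32 call site above:

    // C (m x n) = A * B^T with row-major descriptors, as the mul_mat path uses it.
    auto dnnl_stream = ctx.stream_dnnl(stream);
    DnnlGemmWrapper::row_gemm(dnnl_stream,
        /*a_trans=*/false, /*b_trans=*/true,
        /*m=*/src1_ncols, /*n=*/row_diff, /*k=*/ne10,
        src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
        src0_ddf_i,  DnnlGemmWrapper::to_dt<float>(),
        dst_dd_i,    DnnlGemmWrapper::to_dt<float>());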
236  ggml/src/ggml.c
@ -7119,7 +7119,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
         float                 scale,
-        float                 max_bias) {
+        float                 max_bias,
+        float                 logit_softcap) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
@ -7146,7 +7147,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

-    float params[] = { scale, max_bias };
+    float params[] = { scale, max_bias, logit_softcap };
     ggml_set_op_params(result, params, sizeof(params));

     result->op = GGML_OP_FLASH_ATTN_EXT;
@ -7166,7 +7167,7 @@ void ggml_flash_attn_ext_set_prec(

     const int32_t prec_i32 = (int32_t) prec;

-    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }

 // ggml_flash_attn_back
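Callers of ggml_flash_attn_ext now pass a third float; 0.0f keeps the previous behavior and any non-zero value enables softcapping. A hedged call fragment (the tensor and dimension names are hypothetical, the signature is the one in the hunk above):

    // Fragment: assumes a ggml_context `ctx0` and tensors q, k, v, kq_mask.
    struct ggml_tensor * cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask,
            1.0f/sqrtf((float) n_embd_head), // scale
            0.0f,                            // max_bias
            0.0f);                           // logit_softcap: 0.0f = disabled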
@ -7253,43 +7254,34 @@ struct ggml_tensor * ggml_flash_attn_back(

 struct ggml_tensor * ggml_ssm_conv(
         struct ggml_context * ctx,
-        struct ggml_tensor  * s,
-        struct ggml_tensor  * x,
-        struct ggml_tensor  * c,
-        struct ggml_tensor  * sq) {
-    GGML_ASSERT(ggml_is_3d(s));
-    GGML_ASSERT(ggml_is_matrix(x));
+        struct ggml_tensor  * sx,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(ggml_is_3d(sx));
     GGML_ASSERT(ggml_is_matrix(c));
-    GGML_ASSERT(ggml_is_matrix(sq));
-    GGML_ASSERT(sq->type == GGML_TYPE_I32);

     const int64_t d_conv  = c->ne[0];
     const int64_t d_inner = c->ne[1];
-    const int64_t n_tokens = x->ne[1];
-    const int64_t n_kv     = s->ne[2];
+    const int64_t n_t     = sx->ne[0] - d_conv + 1; // tokens per sequence
+    const int64_t n_s     = sx->ne[2];

-    GGML_ASSERT( s->ne[0] == d_conv - 1);
-    GGML_ASSERT( s->ne[1] == d_inner);
-    GGML_ASSERT( x->ne[0] == d_inner);
-    GGML_ASSERT(sq->ne[0] == n_kv);
-    GGML_ASSERT(sq->ne[1] == n_tokens);
+    // TODO: maybe support other strides than 1?
+    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
+    GGML_ASSERT(sx->ne[1] == d_inner);
+    GGML_ASSERT(n_t >= 0);

     bool is_node = false;

-    if (s->grad || x->grad || c->grad || sq->grad) {
+    if (sx->grad || c->grad) {
         GGML_ABORT("fatal error"); // TODO: implement
         is_node = true;
     }

-    // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);

     result->op   = GGML_OP_SSM_CONV;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = s;
-    result->src[1] = x;
-    result->src[2] = c;
-    result->src[3] = sq;
+    result->src[0] = sx;
+    result->src[1] = c;

     return result;
 }
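The reworked sx layout removes the separate conv-state tensor: the d_conv - 1 carried-over columns and the n_t fresh tokens share one buffer, so each output token reads a full d_conv-wide window at stride one (hence n_t = sx->ne[0] - d_conv + 1). A toy standalone C++ illustration of the indexing:

    #include <cstdio>

    int main() {
        const int d_conv = 4;
        const int n_t    = 3;                 // tokens per sequence
        // sx holds d_conv - 1 carried-over values followed by n_t new ones
        const float sx[d_conv - 1 + n_t] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
        const float c[d_conv]            = {0.25f, 0.25f, 0.25f, 0.25f};

        for (int i2 = 0; i2 < n_t; ++i2) {
            float sumf = 0.0f;
            for (int i0 = 0; i0 < d_conv; ++i0) {
                sumf += sx[i2 + i0]*c[i0];    // window slides by one per token
            }
            std::printf("x[%d] = %f\n", i2, sumf);
        }
    }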
@ -7303,39 +7295,42 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor  * dt,
         struct ggml_tensor  * A,
         struct ggml_tensor  * B,
-        struct ggml_tensor  * C,
-        struct ggml_tensor  * sq) {
+        struct ggml_tensor  * C) {
     GGML_ASSERT(ggml_is_contiguous(s));
     GGML_ASSERT(ggml_is_contiguous(x));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
-    GGML_ASSERT(sq->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_matrix(A));
+    GGML_ASSERT(ggml_is_3d(B));
+    GGML_ASSERT(ggml_is_3d(s));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
     GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
     GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(ggml_are_same_shape(B, C));

     {
         const int64_t d_state      = s->ne[0];
         const int64_t d_inner      = s->ne[1];
-        const int64_t n_tokens     = x->ne[1];
+        const int64_t n_seq_tokens = x->ne[1];
+        const int64_t n_seqs       = x->ne[2];

+        GGML_ASSERT(s->ne[2] == n_seqs);
         GGML_ASSERT(x->ne[0] == d_inner);
         GGML_ASSERT(A->ne[0] == d_state);
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_tokens);
-        GGML_ASSERT(C->ne[0] == d_state);
-        GGML_ASSERT(C->ne[1] == n_tokens);
+        GGML_ASSERT(B->ne[1] == n_seq_tokens);
+        GGML_ASSERT(B->ne[2] == n_seqs);
     }

     bool is_node = false;

-    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
+    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) {
         GGML_ABORT("fatal error"); // TODO: implement
         is_node = true;
     }

-    // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+    // concatenated y + ssm_states
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));

     result->op = GGML_OP_SSM_SCAN;
@ -7346,7 +7341,6 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[3] = A;
     result->src[4] = B;
     result->src[5] = C;
-    result->src[6] = sq;

     return result;
 }
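The scan result is still a single flat tensor: the per-token outputs y come first, followed by the updated states. A hedged sketch of how a caller could split it with standard ggml views (dimension names as above; whether a given caller splits exactly this way is an assumption):

    // Fragment: assumes `ctx0`, the ssm_scan `result`, and the dims used above.
    struct ggml_tensor * y = ggml_view_1d(ctx0, result,
            d_inner*n_seq_tokens*n_seqs, 0);
    struct ggml_tensor * states = ggml_view_1d(ctx0, result,
            d_state*d_inner*n_seqs,
            d_inner*n_seq_tokens*n_seqs*ggml_element_size(result));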
@ -11041,11 +11035,6 @@ static void ggml_compute_forward_concat_f32(

     GGML_TENSOR_BINARY_OP_LOCALS

-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-
     const int32_t dim = ggml_get_op_params_i32(dst, 0);

     GGML_ASSERT(dim >= 0 && dim < 4);
@ -15339,9 +15328,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(

     float scale         = 1.0f;
     float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;

     memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }

     const uint32_t n_head      = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
@ -15405,7 +15400,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
                 kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);

-                s = s*scale + mv; // scale KQ value and apply mask
+                s = s*scale; // scale KQ value
+
+                if (logit_softcap != 0.0f) {
+                    s = logit_softcap*tanhf(s);
+                }
+
+                s += mv; // apply mask

                 const float Mold = M;
@ -15481,7 +15482,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[2]) {
+    switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
@ -15836,27 +15837,22 @@ static void ggml_compute_forward_flash_attn_back(
|
||||||
static void ggml_compute_forward_ssm_conv_f32(
|
static void ggml_compute_forward_ssm_conv_f32(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
const struct ggml_tensor * src0 = dst->src[0]; // conv_state
|
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
|
||||||
const struct ggml_tensor * src1 = dst->src[1]; // x
|
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
|
||||||
const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
|
|
||||||
const struct ggml_tensor * src3 = dst->src[3]; // state_seq
|
|
||||||
|
|
||||||
const int ith = params->ith;
|
const int ith = params->ith;
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
|
|
||||||
const int nc = src2->ne[0]; // d_conv
|
const int nc = src1->ne[0]; // d_conv
|
||||||
|
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
|
||||||
const int nr = src0->ne[1]; // d_inner
|
const int nr = src0->ne[1]; // d_inner
|
||||||
const int n_t = src1->ne[1]; // n_tokens
|
const int n_t = dst->ne[1]; // tokens per sequence
|
||||||
-    const int n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int n_s  =  dst->ne[2]; // number of sequences in the batch

-    GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+    GGML_ASSERT( dst->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src2->nb[0] == sizeof(float));
-    GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // for use with the destination state offset between sequences
-    GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));

     // rows per thread
     const int dr = (nr + nth - 1)/nth;

@@ -15866,77 +15862,30 @@ static void ggml_compute_forward_ssm_conv_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;

-    if (n_kv > 1) {
-        // multiple sequences means it's hard to know when it's the first time a state is read,
-        // so copy them all over to the destination, just to be sure.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
-            float * s  = (float *) ((char *)  dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
-            // can't use memcpy because of d_conv vs d_conv - 1
-            for (int i1 = 0; i1 < ir; ++i1) {
-                for (int i0 = 0; i0 < nc - 1; ++i0) {
-                    // copy s0 to last (d_conv - 1) columns of s
-                    s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
-                }
-            }
-        }
-    }
-
-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
-        float *    x = (float *) ((char *)  dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
-        float *    s = (float *) ((char *)  dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
-        float *   s0; // {d_conv - 1, d_inner, n_kv}
-        float *   x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *    c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
-        int ne0s0;
-
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
-            ne0s0 = src0->ne[0];
-        } else {
-            // the source is the last (d_conv - 1) columns of the destination
-            s0 = s + 1;
-            ne0s0 = nc;
-        }
-
-        // d_inner
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // shift state left
-            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
-            }
-            // insert x on the last column
-            s[(nc - 1) + i1*nc] = x0[i1];
-        }
-
-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
-                memcpy(s1, s, nc*ir*sizeof(float));
-            } else {
-                // stop at negative or too big seq_ids
-                break;
-            }
-        }
-
-        // it seems a little faster when this is separate from the state shift
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // rowwise dot product
-            float sumf = 0.0f;
-            for (int i0 = 0; i0 < nc; ++i0) {
-                int i = i0 + i1*nc;
-                sumf += s[i] * c[i];
-            }
-            x[i1] = sumf;
-        }
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            // {d_conv - 1 + n_t, d_inner, n_seqs}
+            // sliding window
+            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
+
+            // TODO: transpose the output for smaller strides for big batches?
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // rowwise dot product
+                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+                float sumf = 0.0f;
+
+                // d_conv
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
+                }
+                x[i1] = sumf;
+            }
+        }
     }
 }

 static void ggml_compute_forward_ssm_conv(
         const struct ggml_compute_params * params,
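
The refactor works because the graph builder now hands ssm_conv, per sequence, a buffer that already contains the (d_conv - 1) carried-over state columns followed by the n_t new token columns. Each output element then reduces to a plain dot product over a sliding window, and all of the seq_id bookkeeping above disappears. A minimal standalone sketch of that inner computation, using illustrative names and a contiguous layout rather than the strided ggml tensors:

    #include <vector>

    // Sliding-window convolution as ssm_conv now computes it (sketch only).
    // conv_x: d_inner rows of (d_conv - 1 + n_t) columns, state columns first.
    // weights: d_inner rows of d_conv taps. out: n_t rows of d_inner values.
    static void ssm_conv_sketch(
            const std::vector<float> & conv_x,
            const std::vector<float> & weights,
            std::vector<float> & out,
            int d_inner, int d_conv, int n_t) {
        const int ncs = d_conv - 1 + n_t; // columns per row of conv_x
        for (int t = 0; t < n_t; ++t) {
            for (int r = 0; r < d_inner; ++r) {
                float sum = 0.0f;
                // the window slides right by one column per token
                for (int k = 0; k < d_conv; ++k) {
                    sum += conv_x[r*ncs + t + k] * weights[r*d_conv + k];
                }
                out[t*d_inner + r] = sum;
            }
        }
    }
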
@@ -15964,15 +15913,14 @@ static void ggml_compute_forward_ssm_scan_f32(
     const struct ggml_tensor * src3 = dst->src[3]; // A
     const struct ggml_tensor * src4 = dst->src[4]; // B
     const struct ggml_tensor * src5 = dst->src[5]; // C
-    const struct ggml_tensor * src6 = dst->src[6]; // sq

     const int ith = params->ith;
     const int nth = params->nth;

     const int64_t nc  = src0->ne[0]; // d_state
     const int64_t nr  = src0->ne[1]; // d_inner
-    const int64_t n_t  = src1->ne[1]; // number of tokens in the batch
-    const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
+    const int64_t n_s = src0->ne[2]; // number of sequences in the batch

     GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -15981,12 +15929,12 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C, and when copying the states
+    // required for the dot product between s and C
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
     // required for per-sequence offsets for states
     GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[2])
-    GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));

     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -15996,36 +15944,19 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;

-    if (n_kv > 1) {
-        // it's hard to know if the source states have already been copied
-        // when there are multiple, so copy them already.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
-            float * s  = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
-            memcpy(s, s0, nc*ir*sizeof(float));
-        }
-    }
-
-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
-        float *    y = (float *) ((char *)  dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *    s = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
-        float *   s0;
-        float *    x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *   dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
-        float *    A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float *    B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
-        float *    C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
-
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
-        } else {
-            // otherwise the source is the same as the destination
-            s0 = s;
-        }
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+                  float * y  = (      float *) ((      char *)  dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                  float * s  = (      float *) ((      char *)  dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +  src1->nb[3]);  // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }

             // d_inner
             for (int i1 = 0; i1 < ir; ++i1) {
@@ -16044,17 +15975,6 @@ static void ggml_compute_forward_ssm_scan_f32(
                 }
                 y[i1] = sumf;
             }
-
-            // handle copies when there are multiple output states
-            for (int i3 = 1; i3 < n_kv; ++i3) {
-                int32_t seq = sq[i3];
-                if (0 <= seq && seq < n_kv) {
-                    float * s1 = s + (seq - sq[0])*nc*nr;
-                    memcpy(s1, s, nc*ir*sizeof(float));
-                } else {
-                    // stop at negative or too big seq_ids
-                    break;
-                }
-            }
+        }
     }
 }
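
The token-wise inner loop (unchanged by this hunk, so elided between the headers above) implements the Mamba selective-scan step: dt is passed through a softplus, the state decays by exp(dt*A) while accumulating B*(dt*x), and the output is the dot product of the new state with C. A rough per-row sketch under those assumptions, on plain arrays instead of strided tensors (the dt threshold mirrors the existing kernel):

    #include <cmath>

    // One (token, inner-row) step of the selective scan (sketch only).
    // h is the d_state-sized running state for this row, updated in place.
    static float ssm_scan_row_sketch(
            float * h,
            const float * A,  // per-row decay parameters
            const float * B,  // input projection for this token
            const float * C,  // output projection for this token
            float x, float dt, int d_state) {
        const float dt_sp = dt <= 20.0f ? std::log1p(std::exp(dt)) : dt; // softplus
        const float x_dt  = x * dt_sp;
        float y = 0.0f;
        for (int i = 0; i < d_state; ++i) {
            h[i] = h[i] * std::exp(dt_sp * A[i]) + B[i] * x_dt;
            y   += h[i] * C[i];
        }
        return y;
    }
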
@@ -511,6 +511,9 @@ extern "C" {
     // to the decoder to start generating output sequence. For other models, it returns -1.
     LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);

+    // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
+    LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
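
The new predicate lets callers branch on whether a model keeps a fixed-size recurrent state (Mamba, RWKV) instead of a growing KV cache. A hypothetical caller-side sketch (the reporting function and its use are illustrative, not part of the API):

    #include "llama.h"
    #include <cstdio>

    static void report_model_kind(const struct llama_model * model) {
        if (llama_model_is_recurrent(model)) {
            // recurrent models keep a fixed-size state per sequence
            printf("recurrent model: state size does not grow with context\n");
        } else {
            printf("attention model: KV cache grows with context\n");
        }
    }
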
@@ -2008,7 +2008,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                         ggml_element_size(kv_pad.v)*n_state_head,
                         0);

-            cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f);
+            cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);

             cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
         } else {
@@ -2471,7 +2471,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         ggml_element_size(kv_self.v)*n_state_head,
                         ggml_element_size(kv_self.v)*n_state*n_ctx*il);

-            cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f);
+            cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);

             cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
         } else {
@@ -2553,7 +2553,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         ggml_element_size(wstate.kv_cross.v)*n_state_head,
                         ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);

-            cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f);
+            cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);

             cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
         } else {
@@ -31,11 +31,17 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *

 static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
     if (search.empty()) {
-        return; // Avoid infinite loop if 'search' is an empty string
+        return;
     }
+    std::string builder;
+    builder.reserve(s.length());
     size_t pos = 0;
-    while ((pos = s.find(search, pos)) != std::string::npos) {
-        s.replace(pos, search.length(), replace);
-        pos += replace.length();
+    size_t last_pos = 0;
+    while ((pos = s.find(search, last_pos)) != std::string::npos) {
+        builder.append(s, last_pos, pos - last_pos);
+        builder.append(replace);
+        last_pos = pos + search.length();
     }
+    builder.append(s, last_pos, std::string::npos);
+    s = std::move(builder);
 }
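
The rewrite builds the result in a separate pre-reserved buffer, appending each untouched span and each replacement exactly once. The old version called std::string::replace in a loop, which shifts the entire tail of the string on every match, so long strings with many matches degraded quadratically; the builder variant is a single linear pass. A self-contained check of the new behavior:

    #include <iostream>
    #include <string>

    // replace_all exactly as in the hunk above, shown standalone.
    static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
        if (search.empty()) {
            return;
        }
        std::string builder;
        builder.reserve(s.length());
        size_t pos = 0;
        size_t last_pos = 0;
        while ((pos = s.find(search, last_pos)) != std::string::npos) {
            builder.append(s, last_pos, pos - last_pos);
            builder.append(replace);
            last_pos = pos + search.length();
        }
        builder.append(s, last_pos, std::string::npos);
        s = std::move(builder);
    }

    int main() {
        std::string s = "the cat sat on the mat";
        replace_all(s, "at", "og");
        std::cout << s << '\n'; // the cog sog on the mog
        return 0;
    }
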
src/llama.cpp: 1466 lines changed (diff suppressed because it is too large)
@@ -14,7 +14,7 @@ MODELS_REPO_URL=https://huggingface.co/ggml-org/$MODELS_REPO
 # Clone the Hugging Face repository if the directory does not exist
 if [ ! -d "$MODELS_REPO" ]; then
     echo "Cloning the Hugging Face repository..."
-    git clone $MODELS_REPO_URL
+    git clone $MODELS_REPO_URL --depth 1
 else
     echo "Repository already exists. Skipping clone."
 fi
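
Note: --depth 1 performs a shallow clone that fetches only the latest commit, which is all this test script needs and avoids downloading the full history of the model repository.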