Alone in the darkness

They're coming for you
I know they will try to catch me too
Alone in the darkness
They're calling for you
There's nowhere to run for cover
commit 94a5a27b85
Concedo 2024-10-24 22:29:20 +08:00
44 changed files with 6803 additions and 2143 deletions

View file

@@ -1098,7 +1098,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            }
    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
    add_opt(common_arg(
-       {"--attention"}, "{causal,non,causal}",
+       {"--attention"}, "{causal,non-causal}",
        "attention type for embeddings, use model default if unspecified",
        [](common_params & params, const std::string & value) {
            /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
@@ -1696,7 +1696,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ).set_examples({LLAMA_EXAMPLE_BENCH}));
    add_opt(common_arg(
        {"--embd-normalize"}, "N",
-       string_format("normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize),
+       string_format("normalisation for embeddings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize),
        [](common_params & params, int value) {
            params.embd_normalize = value;
        }
@@ -1710,7 +1710,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
    add_opt(common_arg(
        {"--embd-separator"}, "STRING",
-       "separator of embendings (default \\n) for example \"<#sep#>\"",
+       "separator of embeddings (default \\n) for example \"<#sep#>\"",
        [](common_params & params, const std::string & value) {
            params.embd_sep = value;
        }

View file

@@ -957,7 +957,7 @@ struct common_init_result common_init_from_params(common_params & params) {
        }
        if (llama_model_has_encoder(model)) {
-           llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
+           llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
            llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
            if (decoder_start_token_id == -1) {
                decoder_start_token_id = bos;
@@ -966,7 +966,7 @@ struct common_init_result common_init_from_params(common_params & params) {
            tmp.push_back(decoder_start_token_id);
        }
        if (llama_model_has_decoder(model)) {
-           llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+           llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
        }
        llama_kv_cache_clear(lctx);
        llama_synchronize(lctx);
@@ -1037,7 +1037,7 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
        return GGML_TYPE_Q5_1;
    }
-   throw std::runtime_error("Invalid cache type: " + s);
+   throw std::runtime_error("Unsupported cache type: " + s);
}

struct llama_context_params common_context_params_to_llama(const common_params & params) {
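
The recurring change across these hunks is the llama_batch_get_one() signature: the trailing start-position and sequence-id arguments are gone, and only the token pointer and count remain. A minimal migration sketch, assuming an already-initialized llama_context named ctx and a tokenized prompt (the helper name eval_prompt is illustrative, not part of the patch):

#include "llama.h"
#include <cstdio>
#include <vector>

// sketch: evaluate a tokenized prompt with the two-argument llama_batch_get_one()
static bool eval_prompt(llama_context * ctx, std::vector<llama_token> & tokens) {
    // previously: llama_batch_get_one(tokens.data(), tokens.size(), /*pos_0=*/ 0, /*seq_id=*/ 0)
    llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return false;
    }
    return true;
}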

View file

@@ -270,9 +270,9 @@ struct common_params {
    // embedding
    bool embedding = false; // get only sentence embedding
-   int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
+   int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
    std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
-   std::string embd_sep = "\n"; // separator of embendings
+   std::string embd_sep = "\n"; // separator of embeddings
    bool reranking = false; // enable reranking support on server

    // server params

View file

@@ -171,7 +171,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
            params.penalize_nl,
            params.ignore_eos));

-   if (params.temp > 0.0f) {
    if (params.mirostat == 0) {
        for (const auto & cnstr : params.samplers) {
            switch (cnstr) {
@@ -203,7 +202,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                    GGML_ASSERT(false && "unknown sampler type");
            }
        }
-       llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
    } else if (params.mirostat == 1) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
@@ -214,18 +212,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
    } else {
        GGML_ASSERT(false && "unknown mirostat version");
    }
-   } else {
-       if (params.n_probs > 0) {
-           // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
-           // ref: https://github.com/ggerganov/llama.cpp/pull/9605
-           //
-           // the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
-           // it is much faster, since we avoid sorting all tokens and should give a good approximation
-           llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
-           llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
-       }
-       llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
-   }

    return result;
}
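
With the temp <= 0 greedy branch and the explicit softmax links removed, the chain built here always ends in the dist sampler. A minimal sketch of a caller-built chain under the new scheme, using only chain functions that appear elsewhere in this diff (the temperature and seed values are placeholders):

// sketch: temperature scaling followed by seeded sampling from the distribution;
// no llama_sampler_init_softmax() link is added to the chain anymore
llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
llama_sampler * chain = llama_sampler_chain_init(sparams);

llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
llama_sampler_chain_add(chain, llama_sampler_init_dist(1234)); // seed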

View file

@@ -2864,6 +2864,9 @@ class Rwkv6Model(Model):
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)
        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+       special_vocab.chat_template = "rwkv-world"
+       # hack: Add '\n\n' as the EOT token to make it chat normally
+       special_vocab._set_special_token("eot", 261)
        special_vocab.add_to_gguf(self.gguf_writer)

    def set_gguf_parameters(self):

View file

@@ -348,6 +348,9 @@ if __name__ == '__main__':
                    if ".base_layer.weight" in name:
                        continue
                    logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
+                   if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
+                       logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
+                       logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()")
                    sys.exit(1)

                if base_name in tensor_map:

View file

@@ -74,7 +74,6 @@ int main(int argc, char ** argv) {
                batch.n_seq_id + i,
                batch.seq_id + i,
                batch.logits + i,
-               0, 0, 0, // unused
            };

            const int ret = llama_decode(ctx, batch_view);

View file

@@ -339,7 +339,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {

static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
    llama_kv_cache_clear(ctx);
-   if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
+   if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return false;
    }

View file

@@ -131,7 +131,7 @@ static bool run(llama_context * ctx, const common_params & params) {
    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);

-   if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
+   if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
        LOG_ERR("%s : failed to eval\n", __func__);
        return false;
    }

View file

@@ -283,9 +283,6 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
        nullptr,
        nullptr,
        nullptr,
-       0,
-       0,
-       0,
    };

    if (embd) {

View file

@@ -46,7 +46,6 @@ actor LlamaContext {
        let sparams = llama_sampler_chain_default_params()
        self.sampling = llama_sampler_chain_init(sparams)
        llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
-       llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
        llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
    }

examples/llama.vim (new file, 783 lines)

@@ -0,0 +1,783 @@
" LLM-based text completion using llama.cpp
"
" requires:
"
" - neovim or vim
" - curl
" - llama.cpp server instance
" - FIM-compatible model
"
" sample config:
"
" - Tab - accept the current suggestion
" - Shift+Tab - accept just the first line of the suggestion
" - Ctrl+F - toggle FIM completion manually
"
" make symlink or copy this file to ~/.config/nvim/autoload/llama.vim
"
" start the llama.cpp server with a FIM-compatible model. for example:
"
" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa -dt 0.1 --ubatch-size 512 --batch-size 1024 --cache-reuse 256
"
" --batch-size [512, model max context]
"
" adjust the batch size to control how much of the provided local context will be used during the inference
" lower values will use smaller part of the context around the cursor, which will result in faster processing
"
" --ubatch-size [64, 2048]
"
" chunks the batch into smaller chunks for faster processing
" depends on the specific hardware. use llama-bench to profile and determine the best size
"
" --cache-reuse (ge:llama_config.n_predict, 1024]
"
" this should be either 0 (disabled) or strictly larger than g:llama_config.n_predict
" using non-zero value enables context reuse on the server side which dramatically improves the performance at
" large contexts. a value of 256 should be good for all cases
"
" run this once to initialise llama.vim:
"
" :call llama#init()
"
" more info: https://github.com/ggerganov/llama.cpp/pull/9787
"
" colors (adjust to your liking)
highlight llama_hl_hint guifg=#ff772f ctermfg=202
highlight llama_hl_info guifg=#77ff2f ctermfg=119
" general parameters:
"
" endpoint: llama.cpp server endpoint
" n_prefix: number of lines before the cursor location to include in the local prefix
" n_suffix: number of lines after the cursor location to include in the local suffix
" n_predict: max number of tokens to predict
" t_max_prompt_ms: max alloted time for the prompt processing (TODO: not yet supported)
" t_max_predict_ms: max alloted time for the prediction
" show_info: show extra info about the inference (0 - disabled, 1 - statusline, 2 - inline)
" auto_fim: trigger FIM completion automatically on cursor movement
" max_line_suffix: do not auto-trigger FIM completion if there are more than this number of characters to the right of the cursor
"
" ring buffer of chunks, accumulated with time upon:
"
" - completion request
" - yank
" - entering a buffer
" - leaving a buffer
" - writing a file
"
" parameters for the ring-buffer with extra context:
"
" ring_n_chunks: max number of chunks to pass as extra context to the server (0 to disable)
" ring_chunk_size: max size of the chunks (in number of lines)
" note: adjust these numbers so that you don't overrun your context
" at ring_n_chunks = 64 and ring_chunk_size = 64 you need ~32k context
" ring_scope: the range around the cursor position (in number of lines) for gathering chunks after FIM
" ring_update_ms: how often to process queued chunks in normal mode
"
let s:default_config = {
\ 'endpoint': 'http://127.0.0.1:8012/infill',
\ 'n_prefix': 256,
\ 'n_suffix': 64,
\ 'n_predict': 128,
\ 't_max_prompt_ms': 500,
\ 't_max_predict_ms': 3000,
\ 'show_info': 2,
\ 'auto_fim': v:true,
\ 'max_line_suffix': 8,
\ 'ring_n_chunks': 64,
\ 'ring_chunk_size': 64,
\ 'ring_scope': 1024,
\ 'ring_update_ms': 1000,
\ }
let g:llama_config = get(g:, 'llama_config', s:default_config)
function! s:get_indent(str)
let l:count = 0
for i in range(len(a:str))
if a:str[i] == "\t"
let l:count += &tabstop - 1
else
break
endif
endfor
return l:count
endfunction
function! s:rand(i0, i1) abort
return a:i0 + rand() % (a:i1 - a:i0 + 1)
endfunction
function! llama#init()
if !executable('curl')
echohl WarningMsg
echo 'llama.vim requires the "curl" command to be available'
echohl None
return
endif
let s:pos_x = 0 " cursor position upon start of completion
let s:pos_y = 0
let s:line_cur = ''
let s:line_cur_prefix = ''
let s:line_cur_suffix = ''
let s:ring_chunks = [] " current set of chunks used as extra context
let s:ring_queued = [] " chunks that are queued to be sent for processing
let s:ring_n_evict = 0
let s:hint_shown = v:false
let s:pos_y_pick = -9999 " last y where we picked a chunk
let s:pos_dx = 0
let s:content = []
let s:can_accept = v:false
let s:timer_fim = -1
let s:t_fim_start = reltime() " used to measure total FIM time
let s:t_last_move = reltime() " last time the cursor moved
let s:current_job = v:null
let s:ghost_text_nvim = exists('*nvim_buf_get_mark')
let s:ghost_text_vim = has('textprop')
if s:ghost_text_vim
let s:hlgroup_hint = 'llama_hl_hint'
let s:hlgroup_info = 'llama_hl_info'
if empty(prop_type_get(s:hlgroup_hint))
call prop_type_add(s:hlgroup_hint, {'highlight': s:hlgroup_hint})
endif
if empty(prop_type_get(s:hlgroup_info))
call prop_type_add(s:hlgroup_info, {'highlight': s:hlgroup_info})
endif
endif
augroup llama
autocmd!
autocmd InsertEnter * inoremap <expr> <silent> <C-F> llama#fim_inline(v:false)
autocmd InsertLeavePre * call llama#fim_cancel()
autocmd CursorMoved * call s:on_move()
autocmd CursorMovedI * call s:on_move()
autocmd CompleteChanged * call llama#fim_cancel()
if g:llama_config.auto_fim
autocmd CursorMovedI * call llama#fim(v:true)
endif
" gather chunks upon yanking
autocmd TextYankPost * if v:event.operator ==# 'y' | call s:pick_chunk(v:event.regcontents, v:false, v:true) | endif
" gather chunks upon entering/leaving a buffer
autocmd BufEnter * call timer_start(100, {-> s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)})
autocmd BufLeave * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)
" gather chunk upon saving the file
autocmd BufWritePost * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)
augroup END
silent! call llama#fim_cancel()
" init background update of the ring buffer
if g:llama_config.ring_n_chunks > 0
call s:ring_update()
endif
endfunction
" compute how similar two chunks of text are
" 0 - no similarity, 1 - high similarity
" TODO: figure out something better
function! s:chunk_sim(c0, c1)
let l:lines0 = len(a:c0)
let l:lines1 = len(a:c1)
let l:common = 0
for l:line0 in a:c0
for l:line1 in a:c1
if l:line0 == l:line1
let l:common += 1
break
endif
endfor
endfor
return 2.0 * l:common / (l:lines0 + l:lines1)
endfunction
" pick a random chunk of size g:llama_config.ring_chunk_size from the provided text and queue it for processing
"
" no_mod - do not pick chunks from buffers with pending changes
" do_evict - evict chunks that are very similar to the new one
"
function! s:pick_chunk(text, no_mod, do_evict)
" do not pick chunks from buffers with pending changes or buffers that are not files
if a:no_mod && (getbufvar(bufnr('%'), '&modified') || !buflisted(bufnr('%')) || !filereadable(expand('%')))
return
endif
" if the extra context option is disabled - do nothing
if g:llama_config.ring_n_chunks <= 0
return
endif
" don't pick very small chunks
if len(a:text) < 3
return
endif
if len(a:text) + 1 < g:llama_config.ring_chunk_size
let l:chunk = a:text
else
let l:l0 = s:rand(0, max([0, len(a:text) - g:llama_config.ring_chunk_size/2]))
let l:l1 = min([l:l0 + g:llama_config.ring_chunk_size/2, len(a:text)])
let l:chunk = a:text[l:l0:l:l1]
endif
let l:chunk_str = join(l:chunk, "\n") . "\n"
" check if this chunk is already added
let l:exist = v:false
for i in range(len(s:ring_chunks))
if s:ring_chunks[i].data == l:chunk
let l:exist = v:true
break
endif
endfor
for i in range(len(s:ring_queued))
if s:ring_queued[i].data == l:chunk
let l:exist = v:true
break
endif
endfor
if l:exist
return
endif
" evict queued chunks that are very similar to the new one
for i in range(len(s:ring_queued) - 1, 0, -1)
if s:chunk_sim(s:ring_queued[i].data, l:chunk) > 0.9
if a:do_evict
call remove(s:ring_queued, i)
let s:ring_n_evict += 1
else
return
endif
endif
endfor
" also from s:ring_chunks
for i in range(len(s:ring_chunks) - 1, 0, -1)
if s:chunk_sim(s:ring_chunks[i].data, l:chunk) > 0.9
if a:do_evict
call remove(s:ring_chunks, i)
let s:ring_n_evict += 1
else
return
endif
endif
endfor
" TODO: become parameter ?
if len(s:ring_queued) == 16
call remove(s:ring_queued, 0)
endif
call add(s:ring_queued, {'data': l:chunk, 'str': l:chunk_str, 'time': reltime(), 'filename': expand('%')})
"let &statusline = 'extra context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued)
endfunction
" picks a queued chunk, sends it for processing and adds it to s:ring_chunks
" called every g:llama_config.ring_update_ms
function! s:ring_update()
call timer_start(g:llama_config.ring_update_ms, {-> s:ring_update()})
" update only if in normal mode or if the cursor hasn't moved for a while
if mode() !=# 'n' && reltimefloat(reltime(s:t_last_move)) < 3.0
return
endif
if len(s:ring_queued) == 0
return
endif
" move the first queued chunk to the ring buffer
if len(s:ring_chunks) == g:llama_config.ring_n_chunks
call remove(s:ring_chunks, 0)
endif
call add(s:ring_chunks, remove(s:ring_queued, 0))
"let &statusline = 'updated context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued)
" send asynchronous job with the new extra context so that it is ready for the next FIM
let l:extra_context = []
for l:chunk in s:ring_chunks
call add(l:extra_context, {
\ 'text': l:chunk.str,
\ 'time': l:chunk.time,
\ 'filename': l:chunk.filename
\ })
endfor
" no samplers needed here
let l:request = json_encode({
\ 'input_prefix': "",
\ 'input_suffix': "",
\ 'input_extra': l:extra_context,
\ 'prompt': "",
\ 'n_predict': 1,
\ 'temperature': 0.0,
\ 'stream': v:false,
\ 'samplers': ["temperature"],
\ 'cache_prompt': v:true,
\ 't_max_prompt_ms': 1,
\ 't_max_predict_ms': 1
\ })
let l:curl_command = [
\ "curl",
\ "--silent",
\ "--no-buffer",
\ "--request", "POST",
\ "--url", g:llama_config.endpoint,
\ "--header", "Content-Type: application/json",
\ "--data", l:request
\ ]
" no callbacks because we don't need to process the response
if s:ghost_text_nvim
call jobstart(l:curl_command, {})
elseif s:ghost_text_vim
call job_start(l:curl_command, {})
endif
endfunction
" necessary for 'inoremap <expr>'
function! llama#fim_inline(is_auto) abort
call llama#fim(a:is_auto)
return ''
endfunction
" the main FIM call
" takes local context around the cursor and sends it together with the extra context to the server for completion
function! llama#fim(is_auto) abort
" we already have a suggestion for the current cursor position
if s:hint_shown && !a:is_auto
call llama#fim_cancel()
return
endif
call llama#fim_cancel()
" avoid sending repeated requests too fast
if reltimefloat(reltime(s:t_fim_start)) < 0.6
if s:timer_fim != -1
call timer_stop(s:timer_fim)
let s:timer_fim = -1
endif
let s:t_fim_start = reltime()
let s:timer_fim = timer_start(600, {-> llama#fim(v:true)})
return
endif
let s:t_fim_start = reltime()
let s:content = []
let s:can_accept = v:false
let s:pos_x = col('.') - 1
let s:pos_y = line('.')
let l:max_y = line('$')
let l:lines_prefix = getline(max([1, s:pos_y - g:llama_config.n_prefix]), s:pos_y - 1)
let l:lines_suffix = getline(s:pos_y + 1, min([l:max_y, s:pos_y + g:llama_config.n_suffix]))
let s:line_cur = getline('.')
let s:line_cur_prefix = strpart(s:line_cur, 0, s:pos_x)
let s:line_cur_suffix = strpart(s:line_cur, s:pos_x)
if a:is_auto && len(s:line_cur_suffix) > g:llama_config.max_line_suffix
return
endif
let l:prefix = ""
\ . join(l:lines_prefix, "\n")
\ . "\n"
let l:prompt = ""
\ . s:line_cur_prefix
let l:suffix = ""
\ . s:line_cur_suffix
\ . "\n"
\ . join(l:lines_suffix, "\n")
\ . "\n"
" prepare the extra context data
let l:extra_context = []
for l:chunk in s:ring_chunks
call add(l:extra_context, {
\ 'text': l:chunk.str,
\ 'time': l:chunk.time,
\ 'filename': l:chunk.filename
\ })
endfor
" the indentation of the current line
let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*'))
let l:request = json_encode({
\ 'input_prefix': l:prefix,
\ 'input_suffix': l:suffix,
\ 'input_extra': l:extra_context,
\ 'prompt': l:prompt,
\ 'n_predict': g:llama_config.n_predict,
\ 'n_indent': l:indent,
\ 'top_k': 40,
\ 'top_p': 0.99,
\ 'stream': v:false,
\ 'samplers': ["top_k", "top_p", "infill"],
\ 'cache_prompt': v:true,
\ 't_max_prompt_ms': g:llama_config.t_max_prompt_ms,
\ 't_max_predict_ms': g:llama_config.t_max_predict_ms
\ })
let l:curl_command = [
\ "curl",
\ "--silent",
\ "--no-buffer",
\ "--request", "POST",
\ "--url", g:llama_config.endpoint,
\ "--header", "Content-Type: application/json",
\ "--data", l:request
\ ]
if s:current_job != v:null
if s:ghost_text_nvim
call jobstop(s:current_job)
elseif s:ghost_text_vim
call job_stop(s:current_job)
endif
endif
" send the request asynchronously
if s:ghost_text_nvim
let s:current_job = jobstart(l:curl_command, {
\ 'on_stdout': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]),
\ 'on_exit': function('s:fim_on_exit'),
\ 'stdout_buffered': v:true
\ })
elseif s:ghost_text_vim
let s:current_job = job_start(l:curl_command, {
\ 'out_cb': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]),
\ 'exit_cb': function('s:fim_on_exit')
\ })
endif
" TODO: per-file location
let l:delta_y = abs(s:pos_y - s:pos_y_pick)
" gather some extra context nearby and process it in the background
" only gather chunks if the cursor has moved a lot
" TODO: something more clever? reranking?
if a:is_auto && l:delta_y > 32
" expand the prefix even further
call s:pick_chunk(getline(max([1, s:pos_y - g:llama_config.ring_scope]), max([1, s:pos_y - g:llama_config.n_prefix])), v:false, v:false)
" pick a suffix chunk
call s:pick_chunk(getline(min([l:max_y, s:pos_y + g:llama_config.n_suffix]), min([l:max_y, s:pos_y + g:llama_config.n_suffix + g:llama_config.ring_chunk_size])), v:false, v:false)
let s:pos_y_pick = s:pos_y
endif
endfunction
" if first_line == v:true accept only the first line of the response
function! llama#fim_accept(first_line)
" insert the suggestion at the cursor location
if s:can_accept && len(s:content) > 0
call setline(s:pos_y, s:line_cur[:(s:pos_x - 1)] . s:content[0])
if len(s:content) > 1
if !a:first_line
call append(s:pos_y, s:content[1:-1])
endif
endif
" move the cursor to the end of the accepted text
if !a:first_line && len(s:content) > 1
call cursor(s:pos_y + len(s:content) - 1, s:pos_x + s:pos_dx + 1)
else
call cursor(s:pos_y, s:pos_x + len(s:content[0]))
endif
endif
call llama#fim_cancel()
endfunction
function! llama#fim_cancel()
let s:hint_shown = v:false
" clear the virtual text
let l:bufnr = bufnr('%')
if s:ghost_text_nvim
let l:id_vt_fim = nvim_create_namespace('vt_fim')
call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1)
elseif s:ghost_text_vim
call prop_remove({'type': s:hlgroup_hint, 'all': v:true})
call prop_remove({'type': s:hlgroup_info, 'all': v:true})
endif
" remove the mappings
silent! iunmap <buffer> <Tab>
silent! iunmap <buffer> <S-Tab>
silent! iunmap <buffer> <Esc>
endfunction
function! s:on_move()
let s:t_last_move = reltime()
call llama#fim_cancel()
endfunction
" callback that processes the FIM result from the server and displays the suggestion
function! s:fim_on_stdout(pos_x, pos_y, is_auto, job_id, data, event = v:null)
if s:ghost_text_nvim
let l:raw = join(a:data, "\n")
elseif s:ghost_text_vim
let l:raw = a:data
endif
if len(l:raw) == 0
return
endif
if a:pos_x != col('.') - 1 || a:pos_y != line('.')
return
endif
" show the suggestion only in insert mode
if mode() !=# 'i'
return
endif
let s:pos_x = a:pos_x
let s:pos_y = a:pos_y
let s:can_accept = v:true
let l:has_info = v:false
if s:can_accept && v:shell_error
if !a:is_auto
call add(s:content, "<| curl error: is the server on? |>")
endif
let s:can_accept = v:false
endif
let l:n_prompt = 0
let l:t_prompt_ms = 1.0
let l:s_prompt = 0
let l:n_predict = 0
let l:t_predict_ms = 1.0
let l:s_predict = 0
" get the generated suggestion
if s:can_accept
let l:response = json_decode(l:raw)
for l:part in split(get(l:response, 'content', ''), "\n", 1)
call add(s:content, l:part)
endfor
" remove trailing new lines
while len(s:content) > 0 && s:content[-1] == ""
call remove(s:content, -1)
endwhile
let l:generation_settings = get(l:response, 'generation_settings', {})
let l:n_ctx = get(l:generation_settings, 'n_ctx', 0)
let l:n_cached = get(l:response, 'tokens_cached', 0)
let l:truncated = get(l:response, 'truncated', v:false)
" if response.timings is available
if len(get(l:response, 'timings', {})) > 0
let l:has_info = v:true
let l:timings = get(l:response, 'timings', {})
let l:n_prompt = get(l:timings, 'prompt_n', 0)
let l:t_prompt_ms = get(l:timings, 'prompt_ms', 1)
let l:s_prompt = get(l:timings, 'prompt_per_second', 0)
let l:n_predict = get(l:timings, 'predicted_n', 0)
let l:t_predict_ms = get(l:timings, 'predicted_ms', 1)
let l:s_predict = get(l:timings, 'predicted_per_second', 0)
endif
endif
if len(s:content) == 0
call add(s:content, "")
let s:can_accept = v:false
endif
if len(s:content) == 0
return
endif
" NOTE: the following is logic for discarding predictions that repeat existing text
" the code is quite ugly and there is very likely a simpler and more canonical way to implement this
"
" still, I wonder if there is some better way that avoids having to do these special hacks?
" on one hand, the LLM 'sees' the contents of the file before we start editing, so it is normal that it would
" start generating whatever we have given it via the extra context. but on the other hand, it's not very
" helpful to re-generate the same code that is already there
" truncate the suggestion if the first line is empty
if len(s:content) == 1 && s:content[0] == ""
let s:content = [""]
endif
" ... and the next lines are repeated
if len(s:content) > 1 && s:content[0] == "" && s:content[1:] == getline(s:pos_y + 1, s:pos_y + len(s:content) - 1)
let s:content = [""]
endif
" truncate the suggestion if it repeats the suffix
if len(s:content) == 1 && s:content[0] == s:line_cur_suffix
let s:content = [""]
endif
" find the first non-empty line (strip whitespace)
let l:cmp_y = s:pos_y + 1
while l:cmp_y < line('$') && getline(l:cmp_y) =~? '^\s*$'
let l:cmp_y += 1
endwhile
if (s:line_cur_prefix . s:content[0]) == getline(l:cmp_y)
" truncate the suggestion if it repeats the next line
if len(s:content) == 1
let s:content = [""]
endif
" ... or if the second line of the suggestion is the prefix of line l:cmp_y + 1
if len(s:content) == 2 && s:content[-1] == getline(l:cmp_y + 1)[:len(s:content[-1]) - 1]
let s:content = [""]
endif
" ... or if the middle chunk of lines of the suggestion is the same as [l:cmp_y + 1, l:cmp_y + len(s:content) - 1)
if len(s:content) > 2 && join(s:content[1:-1], "\n") == join(getline(l:cmp_y + 1, l:cmp_y + len(s:content) - 1), "\n")
let s:content = [""]
endif
endif
" keep only lines that have the same or larger whitespace prefix as s:line_cur_prefix
"let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*'))
"for i in range(1, len(s:content) - 1)
" if strlen(matchstr(s:content[i], '^\s*')) < l:indent
" let s:content = s:content[:i - 1]
" break
" endif
"endfor
let s:pos_dx = len(s:content[-1])
let s:content[-1] .= s:line_cur_suffix
call llama#fim_cancel()
" display virtual text with the suggestion
let l:bufnr = bufnr('%')
if s:ghost_text_nvim
let l:id_vt_fim = nvim_create_namespace('vt_fim')
endif
" construct the info message
if g:llama_config.show_info > 0 && l:has_info
let l:prefix = ' '
if l:truncated
let l:info = printf("%s | WARNING: the context is full: %d / %d, increase the server context size or reduce g:llama_config.ring_n_chunks",
\ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim',
\ l:n_cached, l:n_ctx
\ )
else
let l:info = printf("%s | c: %d / %d, r: %d / %d, e: %d, q: %d / 16 | p: %d (%.2f ms, %.2f t/s) | g: %d (%.2f ms, %.2f t/s) | t: %.2f ms",
\ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim',
\ l:n_cached, l:n_ctx, len(s:ring_chunks), g:llama_config.ring_n_chunks, s:ring_n_evict, len(s:ring_queued),
\ l:n_prompt, l:t_prompt_ms, l:s_prompt,
\ l:n_predict, l:t_predict_ms, l:s_predict,
\ 1000.0 * reltimefloat(reltime(s:t_fim_start))
\ )
endif
if g:llama_config.show_info == 1
" display the info in the statusline
let &statusline = l:info
let l:info = ''
endif
endif
" display the suggestion and append the info to the end of the first line
if s:ghost_text_nvim
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, {
\ 'virt_text': [[s:content[0], 'llama_hl_hint'], [l:info, 'llama_hl_info']],
\ 'virt_text_win_col': virtcol('.') - 1
\ })
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, 0, {
\ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hl_hint']]}),
\ 'virt_text_win_col': virtcol('.')
\ })
elseif s:ghost_text_vim
let l:new_suffix = s:content[0]
if !empty(l:new_suffix)
call prop_add(s:pos_y, s:pos_x + 1, {
\ 'type': s:hlgroup_hint,
\ 'text': l:new_suffix
\ })
endif
for line in s:content[1:]
call prop_add(s:pos_y, 0, {
\ 'type': s:hlgroup_hint,
\ 'text': line,
\ 'text_padding_left': s:get_indent(line),
\ 'text_align': 'below'
\ })
endfor
if !empty(l:info)
call prop_add(s:pos_y, 0, {
\ 'type': s:hlgroup_info,
\ 'text': l:info,
\ 'text_padding_left': col('$'),
\ 'text_wrap': 'truncate'
\ })
endif
endif
" setup accept shortcuts
inoremap <buffer> <Tab> <C-O>:call llama#fim_accept(v:false)<CR>
inoremap <buffer> <S-Tab> <C-O>:call llama#fim_accept(v:true)<CR>
let s:hint_shown = v:true
endfunction
function! s:fim_on_exit(job_id, exit_code, event = v:null)
if a:exit_code != 0
echom "Job failed with exit code: " . a:exit_code
endif
let s:current_job = v:null
endfunction

View file

@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
        if (n_eval > n_batch) {
            n_eval = n_batch;
        }
-       if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
+       if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
            LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
            return false;
        }

View file

@@ -401,6 +401,39 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
    return true;
}

+struct llava_embd_batch {
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_batch batch;
+    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+        pos     .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_id_0.resize(1);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens =*/ n_tokens,
+            /*tokens   =*/ nullptr,
+            /*embd     =*/ embd,
+            /*pos      =*/ pos.data(),
+            /*n_seq_id =*/ n_seq_id.data(),
+            /*seq_id   =*/ seq_ids.data(),
+            /*logits   =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+};
+
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
    int n_embd = llama_n_embd(llama_get_model(ctx_llama));
@@ -409,8 +442,9 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
        if (n_eval > n_batch) {
            n_eval = n_batch;
        }
-       llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
-       if (llama_decode(ctx_llama, batch)) {
+       float * embd = image_embed->embed+i*n_embd;
+       llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
+       if (llama_decode(ctx_llama, llava_batch.batch)) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return false;
        }

View file

@@ -97,7 +97,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
        if (n_eval > n_batch) {
            n_eval = n_batch;
        }
-       if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
+       if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
            LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
            return false;
        }

View file

@@ -89,8 +89,8 @@ int main(int argc, char ** argv) {
    const auto t_enc_start = ggml_time_us();

    // eval the prompt
-   llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
-   llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
+   llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
+   llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));

    for (int s = 1; s < W + G + 1; ++s) {
        llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);

View file

@@ -89,8 +89,8 @@ int main(int argc, char ** argv){
    const auto t_enc_start = ggml_time_us();

-   llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
-   llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
+   llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
+   llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));

    const auto t_enc_end = ggml_time_us();

View file

@@ -529,7 +529,7 @@ int main(int argc, char ** argv) {
            int enc_input_size = embd_inp.size();
            llama_token * enc_input_buf = embd_inp.data();

-           if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
+           if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
                LOG_ERR("%s : failed to eval\n", __func__);
                return 1;
            }
@@ -649,7 +649,7 @@ int main(int argc, char ** argv) {
            LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

-           if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
+           if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
                LOG_ERR("%s : failed to eval\n", __func__);
                return 1;
            }

View file

@@ -132,6 +132,7 @@ struct slot_params {
    int32_t n_keep    =  0; // number of tokens to keep from initial prompt
    int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
    int32_t n_predict = -1; // new tokens to predict
+   int32_t n_indent  =  0; // mininum line indentation for the generated text in number of whitespace characters

    int64_t t_max_prompt_ms  = -1; // TODO: implement
    int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -174,6 +175,8 @@ struct server_slot {
    std::vector<llama_token> prompt_tokens;
    std::vector<llama_token> extra_tokens;

+   size_t last_nl_pos = 0;
+
    std::string generated_text;
    std::vector<llama_token> cache_tokens;
    std::vector<completion_token_output> generated_token_probs;
@@ -216,6 +219,7 @@ struct server_slot {
        SLT_DBG(*this, "%s", "\n");

        n_prompt_tokens = 0;
+       last_nl_pos     = 0;
        generated_text  = "";
        has_new_line    = false;
        truncated       = false;
@@ -861,6 +865,7 @@ struct server_context {
        slot.params.stream        = json_value(data, "stream", false);
        slot.params.cache_prompt  = json_value(data, "cache_prompt", false);
        slot.params.n_predict     = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
+       slot.params.n_indent      = json_value(data, "n_indent", default_params.n_indent);
        slot.sparams.top_k        = json_value(data, "top_k", default_sparams.top_k);
        slot.sparams.top_p        = json_value(data, "top_p", default_sparams.top_p);
        slot.sparams.min_p        = json_value(data, "min_p", default_sparams.min_p);
@@ -879,7 +884,7 @@ struct server_context {
        slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
        slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
        slot.sparams.penalize_nl  = json_value(data, "penalize_nl", default_sparams.penalize_nl);
-       slot.params.n_keep        = json_value(data, "n_keep", slot.params.n_keep);
+       slot.params.n_keep        = json_value(data, "n_keep", default_params.n_keep);
        slot.params.n_discard     = json_value(data, "n_discard", default_params.n_discard);
        slot.sparams.seed         = json_value(data, "seed", default_sparams.seed);
        slot.sparams.n_probs      = json_value(data, "n_probs", default_sparams.n_probs);
@@ -1130,15 +1135,50 @@ struct server_context {
            SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
        }

+       if (slot.has_new_line) {
            // if we have already seen a new line, we stop after a certain time limit
-           if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
-               (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+           if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
                slot.stopped_limit  = true;
                slot.has_next_token = false;

                SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
            }

+           // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
+           if (slot.params.n_indent > 0) {
+               // check the current indentation
+               // TODO: improve by not doing it more than once for each new line
+               if (slot.last_nl_pos > 0) {
+                   size_t pos = slot.last_nl_pos;
+
+                   int n_indent = 0;
+                   while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
+                       n_indent++;
+                       pos++;
+                   }
+
+                   if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
+                       slot.stopped_limit  = true;
+                       slot.has_next_token = false;
+
+                       // cut the last line
+                       slot.generated_text.erase(pos, std::string::npos);
+
+                       SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
+                   }
+               }
+
+               // find the next new line
+               {
+                   const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
+
+                   if (pos != std::string::npos) {
+                       slot.last_nl_pos = pos + 1;
+                   }
+               }
+           }
+       }

        // check if there is a new line in the generated text
        if (result.text_to_send.find('\n') != std::string::npos) {
            slot.has_new_line = true;
@@ -2124,17 +2164,10 @@ struct server_context {
                            GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                        }

-                       common_sampler_reset(slot.smpl);
                        if (slot.params.cache_prompt) {
                            // reuse any previously computed tokens that are common with the new prompt
                            slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);

-                           // push the prompt into the sampling context (do not apply grammar)
-                           for (int i = 0; i < slot.n_past; ++i) {
-                               common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
-                           }
                            // reuse chunks from the cached prompt by shifting their KV cache in the new position
                            if (params.n_cache_reuse > 0) {
                                size_t head_c = slot.n_past; // cache
@@ -2167,8 +2200,6 @@ struct server_context {
                                    for (size_t i = 0; i < n_match; i++) {
                                        slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
-                                       common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
                                        slot.n_past++;
                                    }
@@ -2220,8 +2251,6 @@ struct server_context {
                        // there is no common part left
                        slot.n_past = 0;
-
-                       common_sampler_reset(slot.smpl);
                    }

                    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
@@ -2249,6 +2278,13 @@ struct server_context {

                    GGML_ASSERT(batch.n_tokens > 0);

+                   common_sampler_reset(slot.smpl);
+
+                   // Process all prompt tokens through sampler system
+                   for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+                       common_sampler_accept(slot.smpl, prompt_tokens[i], false);
+                   }
+
                    // extract the logits only for the last token
                    batch.logits[batch.n_tokens - 1] = true;
@@ -2287,7 +2323,6 @@ struct server_context {
                batch.n_seq_id + i,
                batch.seq_id + i,
                batch.logits + i,
-               0, 0, 0, // unused
            };

            const int ret = llama_decode(ctx, batch_view);
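
The dropped `0, 0, 0, // unused` initializers reflect that llama_batch no longer carries the trailing whole-batch position/sequence fields (the ones the removed llava initializer above still set). A view over part of an existing batch is now built from seven members only; a sketch, with the field comments following the llava_embd_batch initializer earlier in this commit and `batch`/`i` assumed from the surrounding loop:

// sketch: a single-token view into an existing llama_batch
llama_batch batch_view = {
    /*n_tokens =*/ 1,
    /*tokens   =*/ batch.token    + i,
    /*embd     =*/ nullptr,
    /*pos      =*/ batch.pos      + i,
    /*n_seq_id =*/ batch.n_seq_id + i,
    /*seq_id   =*/ batch.seq_id   + i,
    /*logits   =*/ batch.logits   + i,
};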

View file

@@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
    // prepare a batch for the prompt

-   llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
+   llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());

    // main loop
@@ -175,7 +175,7 @@ int main(int argc, char ** argv) {
        fflush(stdout);

        // prepare the next batch with the sampled token
-       batch = llama_batch_get_one(&new_token_id, 1, n_pos, 0);
+       batch = llama_batch_get_one(&new_token_id, 1);

        n_decode += 1;
    }

ggml/include/ggml-amx.h (new file, 25 lines)

@@ -0,0 +1,25 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
#endif
// buffer_type API
GGML_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
GGML_API bool ggml_backend_is_amx(ggml_backend_t backend);
// backend API
GGML_API ggml_backend_t ggml_backend_amx_init(void);
GGML_API void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads);
GGML_API ggml_backend_reg_t ggml_backend_amx_reg(void);
#ifdef __cplusplus
}
#endif
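
The header gives the AMX backend the same small surface as other ggml backends. A minimal bring-up sketch, assuming only the declarations above plus ggml_backend_free() and ggml_backend_graph_compute() from ggml-backend.h (thread count and usage are placeholders):

#include "ggml-amx.h"

int main(void) {
    ggml_backend_t backend = ggml_backend_amx_init();
    if (backend == NULL || !ggml_backend_is_amx(backend)) {
        return 1; // backend could not be created
    }
    ggml_backend_amx_set_n_threads(backend, 4);
    // ... allocate weights in ggml_backend_amx_buffer_type(), build a graph,
    //     then run it with ggml_backend_graph_compute(backend, graph) ...
    ggml_backend_free(backend);
    return 0;
}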

View file

@@ -34,6 +34,8 @@ extern "C" {
 */
#define GGML_CANN_MAX_DEVICES 16

+GGML_API ggml_backend_reg_t ggml_backend_cann_reg(void);
+
/**
 * @brief Initializes the CANN backend for a specified device.
 *

View file

@@ -19,6 +19,8 @@ extern "C" {
// backend API
GGML_API ggml_backend_t ggml_backend_sycl_init(int device);

+GGML_API bool ggml_backend_is_sycl(ggml_backend_t backend);
+
// devide buffer
GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
@@ -29,14 +31,19 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const fl
GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);

GGML_API void ggml_backend_sycl_print_sycl_devices(void);
-GGML_API void ggml_sycl_get_gpu_list(int *id_list, int max_len);
-GGML_API void ggml_sycl_get_device_description(int device, char *description, size_t description_size);
+GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len);
+GGML_API void ggml_backend_sycl_get_device_description(int device,
+                                                       char *description,
+                                                       size_t description_size);
GGML_API int  ggml_backend_sycl_get_device_count();
GGML_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);

// SYCL doesn't support registering host memory, keep here for reference
// GGML_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
// GGML_API void ggml_backend_sycl_unregister_host_buffer(void * buffer);

+GGML_API ggml_backend_reg_t ggml_backend_sycl_reg(void);
+
#ifdef __cplusplus
}
#endif

View file

@@ -2494,6 +2494,7 @@ extern "C" {
    GGML_API int ggml_cpu_has_avx512_vbmi(void);
    GGML_API int ggml_cpu_has_avx512_vnni(void);
    GGML_API int ggml_cpu_has_avx512_bf16(void);
+   GGML_API int ggml_cpu_has_amx_int8   (void);
    GGML_API int ggml_cpu_has_fma        (void);
    GGML_API int ggml_cpu_has_neon       (void);
    GGML_API int ggml_cpu_has_sve        (void);
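
The new entry follows the existing ggml_cpu_has_* feature-detection pattern, so callers can gate AMX code paths at runtime; a short sketch (the surrounding dispatch logic is illustrative, not part of the patch):

if (ggml_cpu_has_amx_int8()) {
    // e.g. prefer the AMX backend from ggml-amx.h on this machine
}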

ggml/src/ggml-amx.cpp (new file, 453 lines)

@@ -0,0 +1,453 @@
#include "ggml-amx.h"
#include "ggml-amx/common.h"
#include "ggml-amx/mmq.h"
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#if defined(__gnu_linux__)
#include <sys/syscall.h>
#include <unistd.h>
#endif
#include <cstdlib>
#include <cstring>
#include <memory>
#if defined(__AMX_INT8__)
// AMX buffer interface
static const char * ggml_backend_amx_buffer_get_name(ggml_backend_buffer_t buffer) {
return "AMX";
GGML_UNUSED(buffer);
}
static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
free(buffer->context);
}
static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
return (void *)(buffer->context);
}
static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
memset((char *)tensor->data + offset, value, size);
GGML_UNUSED(buffer);
}
static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
if (qtype_has_amx_kernels(tensor->type)) {
ggml_backend_amx_convert_weight(tensor, data, offset, size);
} else {
memcpy((char *)tensor->data + offset, data, size);
}
GGML_UNUSED(buffer);
}
static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
memcpy(data, (const char *)tensor->data + offset, size);
GGML_UNUSED(buffer);
}
static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
if (ggml_backend_buffer_is_host(src->buffer)) {
if (qtype_has_amx_kernels(src->type)) {
ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst));
} else {
memcpy(dst->data, src->data, ggml_nbytes(src));
}
return true;
}
return false;
GGML_UNUSED(buffer);
}
static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
memset(buffer->context, value, buffer->size);
}
static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
/* .get_name = */ ggml_backend_amx_buffer_get_name,
/* .free_buffer = */ ggml_backend_amx_buffer_free_buffer,
/* .get_base = */ ggml_backend_amx_buffer_get_base,
/* .init_tensor = */ NULL, // no initialization required
/* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_amx_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_amx_buffer_cpy_tensor,
/* .clear = */ ggml_backend_amx_buffer_clear,
/* .reset = */ NULL,
};
static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "AMX";
GGML_UNUSED(buft);
}
static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
if (data == NULL) {
fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
return NULL;
}
return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
}
static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return TENSOR_ALIGNMENT;
GGML_UNUSED(buft);
}
static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
return ggml_backend_amx_get_alloc_size(tensor);
GGML_UNUSED(buft);
}
static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
return false;
GGML_UNUSED(buft);
}
ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
/* .iface = */ {
/* .get_name = */ ggml_backend_amx_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size,
/* .is_host = */ ggml_backend_amx_buffer_type_is_host,
},
/* .device = */ NULL,
/* .context = */ NULL,
};
return &ggml_backend_buffer_type_amx;
}
// backend interface
static const char * ggml_backend_amx_name(ggml_backend_t backend) {
return "AMX";
GGML_UNUSED(backend);
}
static void ggml_backend_amx_free(ggml_backend_t backend) {
ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
delete ctx;
delete backend;
}
static ggml_backend_buffer_type_t ggml_backend_amx_get_default_buffer_type(ggml_backend_t backend) {
return ggml_backend_amx_buffer_type();
GGML_UNUSED(backend);
}
static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];
switch (node->op) {
case GGML_OP_MUL_MAT:
ggml_backend_amx_mul_mat(ctx, node);
break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
break;
default:
fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
GGML_ASSERT(false);
}
}
return GGML_STATUS_SUCCESS;
GGML_UNUSED(backend);
}
static struct ggml_backend_i ggml_backend_amx_i = {
/* .get_name = */ ggml_backend_amx_name,
/* .free = */ ggml_backend_amx_free,
/* .get_default_buffer_type = */ ggml_backend_amx_get_default_buffer_type,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_amx_graph_compute,
/* .supports_op = */ NULL,
/* .supports_buft = */ NULL,
/* .offload_op = */ NULL,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
};
static ggml_guid_t ggml_backend_amx_guid() {
static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e };
return &guid;
}
#define ARCH_GET_XCOMP_PERM 0x1022
#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILECFG 17
#define XFEATURE_XTILEDATA 18
static bool ggml_amx_init() {
#if defined(__gnu_linux__)
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
fprintf(stderr, "AMX is not ready to be used!\n");
return false;
}
return true;
#elif defined(_WIN32)
return true;
#endif
}
ggml_backend_t ggml_backend_amx_init() {
// invoke a Linux system call to request access to AMX features
ggml_amx_init();
// backend context
ggml_backend_amx_context * ctx = new ggml_backend_amx_context;
// ggml amx backend
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_amx_guid(),
/* .interface = */ ggml_backend_amx_i,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
/* .context = */ ctx,
};
return backend;
}
bool ggml_backend_is_amx(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid());
}
void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
GGML_ASSERT(ggml_backend_is_amx(backend_amx));
ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context;
ctx->n_threads = n_threads;
}
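// Editorial sketch, not part of the patch: how the public entry points above fit
// together in an application. The thread count is an arbitrary example, and
// ggml_backend_graph_compute()/ggml_backend_free() are the generic ggml-backend
// calls; the graph may only contain GGML_OP_MUL_MAT plus the layout ops accepted
// by ggml_backend_amx_graph_compute() above.
static enum ggml_status run_graph_on_amx(struct ggml_cgraph * graph) {
    ggml_backend_t backend = ggml_backend_amx_init();
    if (ggml_backend_is_amx(backend)) {
        ggml_backend_amx_set_n_threads(backend, 8); // assumption: 8 worker threads
    }
    enum ggml_status status = ggml_backend_graph_compute(backend, graph);
    ggml_backend_free(backend);
    return status;
}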
// device interface
static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
return "AMX";
GGML_UNUSED(dev);
}
static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
return "Intel Advanced Matrix Extensions";
GGML_UNUSED(dev);
}
static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
// TODO
*free = 0;
*total = 0;
GGML_UNUSED(dev);
}
static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
return GGML_BACKEND_DEVICE_TYPE_CPU;
GGML_UNUSED(dev);
}
static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_amx_device_get_name(dev);
props->description = ggml_backend_amx_device_get_description(dev);
props->type = ggml_backend_amx_device_get_type(dev);
ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);
// `buffer_from_host_ptr` is intended for mmap-backed buffers, where the memory layout is unchanged
props->caps = {
/* .async = */ false,
/* .host_buffer = */ false,
/* .buffer_from_host_ptr = */ false,
/* .events = */ false,
};
}
static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
return ggml_backend_amx_init();
GGML_UNUSED(dev);
GGML_UNUSED(params);
}
static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
return ggml_backend_amx_buffer_type();
GGML_UNUSED(dev);
}
static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
// handle only 2d gemm for now
auto is_contiguous_2d = [](const struct ggml_tensor * t) {
return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
};
switch (op->op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
return true;
case GGML_OP_MUL_MAT: {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * src1 = op->src[1];
const enum ggml_type type = src0->type;
const int64_t ne0 = op->ne[0];
bool is_training = src0->grad || src1->grad;
// AMX kernels are enabled for Q4_0, Q4_1, Q8_0 and F16
// Q4_K, Q5_K, Q6_K and IQ4_XS are enabled when QK_K == 256
bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
bool can_use_amx =
is_contiguous_2d(src0) && // src0 must be contiguous
is_contiguous_2d(src1) && // src1 must be contiguous
!is_training && // inference only
src1->type == GGML_TYPE_F32 && // src1 must be float32
has_amx_kernels && // with amx kernel impls
ne0 % (TILE_N * 2) == 0; // out_features must be a multiple of 32 (TILE_N * 2)
return can_use_amx;
}
default:
return false;
}
GGML_UNUSED(dev);
}
static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
GGML_UNUSED(dev);
}
static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
/* .get_name = */ ggml_backend_amx_device_get_name,
/* .get_description = */ ggml_backend_amx_device_get_description,
/* .get_memory = */ ggml_backend_amx_device_get_memory,
/* .get_type = */ ggml_backend_amx_device_get_type,
/* .get_props = */ ggml_backend_amx_device_get_props,
/* .init_backend = */ ggml_backend_amx_device_init,
/* .get_buffer_type = */ ggml_backend_amx_device_get_buffer_type,
/* .get_host_buffer_type = */ NULL,
/* .buffer_from_host_ptr = */ NULL,
/* .supports_op = */ ggml_backend_amx_device_supports_op,
/* .supports_buft = */ ggml_backend_amx_device_supports_buft,
/* .offload_op = */ NULL,
/* .event_new = */ NULL,
/* .event_free = */ NULL,
/* .event_synchronize = */ NULL,
};
// backend reg interface
static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
return "AMX";
GGML_UNUSED(reg);
}
static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
return 1;
GGML_UNUSED(reg);
}
static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(index == 0);
static ggml_backend_device ggml_backend_amx_device = {
/* .iface = */ ggml_backend_amx_device_i,
/* .reg = */ reg,
/* .context = */ nullptr,
};
return &ggml_backend_amx_device;
GGML_UNUSED(reg);
GGML_UNUSED(index);
}
static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
return (void *)ggml_backend_amx_set_n_threads;
}
return NULL;
GGML_UNUSED(reg);
GGML_UNUSED(name);
}
static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
/* .get_name = */ ggml_backend_amx_reg_get_name,
/* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
/* .get_device = */ ggml_backend_amx_reg_get_device,
/* .get_proc_address = */ ggml_backend_amx_get_proc_address,
};
ggml_backend_reg_t ggml_backend_amx_reg(void) {
static struct ggml_backend_reg ggml_backend_amx_reg = {
/* .iface = */ ggml_backend_amx_reg_i,
/* .context = */ NULL,
};
return &ggml_backend_amx_reg;
}
#else // if defined(__AMX_INT8__)
ggml_backend_t ggml_backend_amx_init(void) {
fprintf(stderr, "GGML is not compiled with AMX support!\n");
return ggml_backend_t{};
}
void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
fprintf(stderr, "GGML is not compiled with AMX support!\n");
GGML_UNUSED(backend_amx);
GGML_UNUSED(n_threads);
}
#endif

View file

@ -0,0 +1,93 @@
#pragma once
#include "ggml.h"
#include "ggml-cpu-impl.h" // <immintrin.h>
#include <algorithm>
#include <memory>
#include <type_traits>
#if defined(_OPENMP)
#include <omp.h>
#endif
#define TILE_M 16
#define TILE_N 16
#define TILE_K 32
#define VNNI_BLK 4
#define AMX_BLK_SIZE 32
#define TMM0 0
#define TMM1 1
#define TMM2 2
#define TMM3 3
#define TMM4 4
#define TMM5 5
#define TMM6 6
#define TMM7 7
// parallel routines
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y) { return (x + y - 1) / y; }
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
#if 0
// onednn partition pattern
T& n_my = n_end;
if (nth <= 1 || n == 0) {
n_start = 0;
n_my = n;
} else {
T n1 = div_up(n, nth);
T n2 = n1 - 1;
T T1 = n - n2 * nth;
n_my = ith < T1 ? n1 : n2;
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
}
n_end += n_start;
#else
// pytorch aten partition pattern
T n_my = div_up(n, nth);
n_start = ith * n_my;
n_end = std::min(n_start + n_my, n);
#endif
}
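// Editorial sketch, not part of the patch: with the active (aten-style) branch,
// n = 10 work items and nth = 4 threads give n_my = div_up(10, 4) = 3, so the
// per-thread ranges are [0,3), [3,6), [6,9) and [9,10).
static inline void balance211_example() {
    int begin = 0, end = 0;
    balance211(10, 4, 2, begin, end); // thread 2 -> begin == 6, end == 9
    GGML_UNUSED(begin);
    GGML_UNUSED(end);
}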
template <typename func_t>
inline void parallel_for(int nth, int n, const func_t& f) {
#if defined(_OPENMP)
#pragma omp parallel num_threads(nth)
{
//int nth = omp_get_num_threads();
int ith = omp_get_thread_num();
int tbegin, tend;
balance211(n, nth, ith, tbegin, tend);
f(tbegin, tend);
}
#else
f(0, n);
GGML_UNUSED(nth);
#endif
}
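// Editorial sketch, not part of the patch: a minimal use of parallel_for. Each
// invocation of the lambda receives a disjoint half-open range, so the writes
// below cannot race; without OpenMP the whole range runs in a single serial call.
static inline void scale_buffer_example(float * dst, const float * src, int n, int nth) {
    parallel_for(nth, n, [&](int begin, int end) {
        for (int i = begin; i < end; i++) {
            dst[i] = 2.0f * src[i];
        }
    });
}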
// quantized types that have AMX support
inline bool qtype_has_amx_kernels(const enum ggml_type type) {
// TODO: fix padding for vnni format
return (type == GGML_TYPE_Q4_0) ||
(type == GGML_TYPE_Q4_1);
//(type == GGML_TYPE_Q8_0) ||
//(type == GGML_TYPE_Q4_K) ||
//(type == GGML_TYPE_Q5_K) ||
//(type == GGML_TYPE_Q6_K) ||
//(type == GGML_TYPE_IQ4_XS);
}
// ggml backend context
struct ggml_backend_amx_context {
int n_threads = GGML_DEFAULT_N_THREADS;
std::unique_ptr<char[]> work_data;
size_t work_size = 0;
};

2509
ggml/src/ggml-amx/mmq.cpp Normal file

File diff suppressed because it is too large

17
ggml/src/ggml-amx/mmq.h Normal file
View file

@ -0,0 +1,17 @@
#pragma once
#include "common.h"
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor);
void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst);
#ifdef __cplusplus
}
#endif
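The three entry points above are consumed by the AMX buffer and graph-compute code earlier in this commit. A hedged sketch of the intended call order (the pointer and size arguments are illustrative placeholders, not the real call sites):

    size_t packed_size = ggml_backend_amx_get_alloc_size(tensor);     // VNNI-packed size reserved by the buffer type
    ggml_backend_amx_convert_weight(tensor, raw_data, 0, raw_size);   // repack the incoming weight bytes at set-tensor time
    ggml_backend_amx_mul_mat(ctx, mul_mat_node);                      // compute a GGML_OP_MUL_MAT node using the packed weights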

View file

@ -329,7 +329,6 @@ bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type
if (backend->device) { if (backend->device) {
return ggml_backend_dev_supports_buft(backend->device, buft); return ggml_backend_dev_supports_buft(backend->device, buft);
} }
return backend->iface.supports_buft(backend, buft); return backend->iface.supports_buft(backend, buft);
} }
@ -538,6 +537,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
#include "ggml-metal.h" #include "ggml-metal.h"
#endif #endif
#ifdef GGML_USE_SYCL
#include "ggml-sycl.h"
#endif
#ifdef GGML_USE_VULKAN #ifdef GGML_USE_VULKAN
#include "ggml-vulkan.h" #include "ggml-vulkan.h"
#endif #endif
@ -550,6 +553,18 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
#include "ggml-rpc.h" #include "ggml-rpc.h"
#endif #endif
#ifndef __AMX_INT8__
#undef GGML_USE_AMX
#endif
#ifdef GGML_USE_AMX
# include "ggml-amx.h"
#endif
#ifdef GGML_USE_CANN
#include "ggml-cann.h"
#endif
struct ggml_backend_registry { struct ggml_backend_registry {
std::vector<ggml_backend_reg_t> backends; std::vector<ggml_backend_reg_t> backends;
std::vector<ggml_backend_dev_t> devices; std::vector<ggml_backend_dev_t> devices;
@ -561,6 +576,9 @@ struct ggml_backend_registry {
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
register_backend(ggml_backend_metal_reg()); register_backend(ggml_backend_metal_reg());
#endif #endif
#ifdef GGML_USE_SYCL
register_backend(ggml_backend_sycl_reg());
#endif
#ifdef GGML_USE_VULKAN #ifdef GGML_USE_VULKAN
register_backend(ggml_backend_vk_reg()); register_backend(ggml_backend_vk_reg());
#endif #endif
@ -570,8 +588,14 @@ struct ggml_backend_registry {
#ifdef GGML_USE_RPC #ifdef GGML_USE_RPC
register_backend(ggml_backend_rpc_reg()); register_backend(ggml_backend_rpc_reg());
#endif #endif
#ifdef GGML_USE_AMX
register_backend(ggml_backend_amx_reg());
#endif
#ifdef GGML_USE_CANN
register_backend(ggml_backend_cann_reg());
#endif
// TODO: sycl, kompute, cann // TODO: kompute
register_backend(ggml_backend_cpu_reg()); register_backend(ggml_backend_cpu_reg());
} }
@ -2250,6 +2274,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
sched->backends[b] = backends[b]; sched->backends[b] = backends[b];
sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]); sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b])); GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b]));
if (sched->n_copies > 1) { if (sched->n_copies > 1) {
for (int c = 0; c < sched->n_copies; c++) { for (int c = 0; c < sched->n_copies; c++) {
sched->events[b][c] = ggml_backend_event_new(backends[b]->device); sched->events[b][c] = ggml_backend_event_new(backends[b]->device);

View file

@ -39,6 +39,8 @@
#include "ggml-common.h" #include "ggml-common.h"
#define GGML_CANN_NAME "CANN"
/** /**
* @brief Handles CANN errors by printing an error message and aborting. * @brief Handles CANN errors by printing an error message and aborting.
* *
@ -851,13 +853,6 @@ static void ggml_backend_cann_buffer_set_tensor(
void *transform_buffer = malloc(size); void *transform_buffer = malloc(size);
ggml_backend_cann_transform(tensor, data, transform_buffer); ggml_backend_cann_transform(tensor, data, transform_buffer);
#ifndef NDEBUG
void *check_buffer = malloc(size);
ggml_backend_cann_transform_back(tensor, transform_buffer,
check_buffer);
GGML_ASSERT(memcmp(data, check_buffer, size) == 0);
free(check_buffer);
#endif
ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
transform_buffer, size, transform_buffer, size,
ACL_MEMCPY_HOST_TO_DEVICE)); ACL_MEMCPY_HOST_TO_DEVICE));
@ -969,7 +964,7 @@ static void ggml_backend_cann_buffer_clear(
* This structure defines function pointers to operations that can be performed * This structure defines function pointers to operations that can be performed
* on a CANN buffer within the backend. * on a CANN buffer within the backend.
*/ */
static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
/* .get_name = */ ggml_backend_cann_buffer_get_name, /* .get_name = */ ggml_backend_cann_buffer_get_name,
/* .free_buffer = */ ggml_backend_cann_buffer_free_buffer, /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
/* .get_base = */ ggml_backend_cann_buffer_get_base, /* .get_base = */ ggml_backend_cann_buffer_get_base,
@ -1105,19 +1100,25 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
GGML_UNUSED(buft); GGML_UNUSED(buft);
} }
static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
return false;
GGML_UNUSED(buft);
}
/** /**
* @brief Interface for managing CANN buffer types in the GGML backend. * @brief Interface for managing CANN buffer types in the GGML backend.
* *
* Provides function pointers for allocating, querying properties, and managing * Provides function pointers for allocating, querying properties, and managing
* memory for CANN buffer types in the GGML backend. * memory for CANN buffer types in the GGML backend.
*/ */
static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = { static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
/* .get_name = */ ggml_backend_cann_buffer_type_name, /* .get_name = */ ggml_backend_cann_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer, /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment, /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size, /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
/* .is_host = */ NULL, /* .is_host = */ ggml_backend_cann_buffer_type_is_host,
}; };
/** /**
@ -1148,7 +1149,7 @@ ggml_backend_cann_buffer_type(int32_t device) {
for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) { for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
ggml_backend_cann_buffer_types[i] = { ggml_backend_cann_buffer_types[i] = {
/* .iface = */ ggml_backend_cann_buffer_type_interface, /* .iface = */ ggml_backend_cann_buffer_type_interface,
/* .device = */ nullptr, /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
/* .context = */ /* .context = */
new ggml_backend_cann_buffer_type_context{ new ggml_backend_cann_buffer_type_context{
i, "CANN" + std::to_string(i)}, i, "CANN" + std::to_string(i)},
@ -1264,7 +1265,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
/* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
}, },
/* .device = */ nullptr, /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
/* .context = */ nullptr, /* .context = */ nullptr,
}; };
@ -1511,13 +1512,6 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
void *transform_buffer = malloc(size); void *transform_buffer = malloc(size);
ggml_backend_cann_transform(tensor, data, transform_buffer); ggml_backend_cann_transform(tensor, data, transform_buffer);
#ifndef NDEBUG
void *check_buffer = malloc(size);
ggml_backend_cann_transform_back(tensor, transform_buffer,
check_buffer);
GGML_ASSERT(memcmp(data, check_buffer, size));
free(check_buffer);
#endif
ACL_CHECK(aclrtMemcpyAsync( ACL_CHECK(aclrtMemcpyAsync(
(char *)tensor->data + offset, size, transform_buffer, size, (char *)tensor->data + offset, size, transform_buffer, size,
ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream())); ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
@ -1692,7 +1686,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
* @return bool Returns true if the operation is supported by the backend, * @return bool Returns true if the operation is supported by the backend,
* otherwise false. * otherwise false.
*/ */
static bool ggml_backend_cann_supports_op(ggml_backend_t backend, static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
const ggml_tensor* op) { const ggml_tensor* op) {
switch (op->op) { switch (op->op) {
case GGML_OP_UNARY: case GGML_OP_UNARY:
@ -1783,7 +1777,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
return false; return false;
} }
GGML_UNUSED(backend); GGML_UNUSED(dev);
} }
/** /**
@ -1801,31 +1795,6 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_cann_buffer_type_name; return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
} }
/**
* @brief Checks if the CANN backend supports a specific backend buffer type.
*
* This function determines whether the CANN backend supports the given backend
* buffer type by comparing the device context of the backend and buffer type.
* It returns true if the devices are same between the backend context and
* buffer type context.
*
* @param backend Pointer to the CANN backend.
* @param buft Pointer to the backend buffer type to check.
* @return bool Returns true if the CANN backend supports the buffer type,
* otherwise false.
*/
static bool ggml_backend_cann_supports_buft(
ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
if (ggml_backend_buft_is_cann(buft)) {
ggml_backend_cann_context * cann_ctx =
(ggml_backend_cann_context *)backend->context;
ggml_backend_cann_buffer_type_context * buft_ctx =
(ggml_backend_cann_buffer_type_context *)buft->context;
return buft_ctx->device == cann_ctx->device;
}
return false;
}
/** /**
* @brief Determines if a tensor operation should be offloaded to the CANN * @brief Determines if a tensor operation should be offloaded to the CANN
* backend. * backend.
@ -1840,54 +1809,14 @@ static bool ggml_backend_cann_supports_buft(
* @return bool Returns true if the operation should be offloaded, otherwise * @return bool Returns true if the operation should be offloaded, otherwise
* false. * false.
*/ */
static bool ggml_backend_cann_offload_op(ggml_backend_t backend, static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
const ggml_tensor* op) { const ggml_tensor* op) {
const int min_batch_size = 32; const int min_batch_size = 32;
GGML_UNUSED(backend); GGML_UNUSED(dev);
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
} }
/**
* @brief Creates a new event for the CANN backend.
*
* This function initializes a new event for the CANN backend by setting the
* device and creating an ACL runtime event. The created event is then wrapped
* in a ggml_backend_event structure and returned.
*
* @param backend Pointer to the CANN backend.
* @return ggml_backend_event_t Returns a pointer to the new event structure.
*/
static ggml_backend_event_t ggml_backend_cann_event_new(
ggml_backend_t backend) {
ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)backend->context;
ggml_cann_set_device(cann_ctx->device);
aclrtEvent event;
ACL_CHECK(aclrtCreateEvent(&event));
return new ggml_backend_event{
/* .device = */ nullptr,
/* .context = */ event,
};
}
/**
* @brief Frees a CANN backend event.
*
* This function destroys the ACL runtime event associated with the given CANN
* backend event and then deletes the event structure itself.
*
* @param event Pointer to the event structure to be freed.
*/
static void ggml_backend_cann_event_free(ggml_backend_event_t event) {
ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
delete event;
}
/** /**
* @brief Records an event on the CANN backend stream. * @brief Records an event on the CANN backend stream.
* *
@ -1924,17 +1853,6 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
} }
} }
/**
* @brief Synchronizes the given event on the CANN backend.
*
* This function waits for the specified event to complete on the ACL runtime.
*
* @param event Pointer to the event structure to be synchronized.
*/
static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
}
/** /**
* @brief Structure defining the interface for the CANN backend. * @brief Structure defining the interface for the CANN backend.
* *
@ -1942,7 +1860,7 @@ static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
* supported by the CANN backend, including name retrieval, memory * supported by the CANN backend, including name retrieval, memory
* management, tensor operations, synchronization, and event handling. * management, tensor operations, synchronization, and event handling.
*/ */
static ggml_backend_i ggml_backend_cann_interface = { static const ggml_backend_i ggml_backend_cann_interface = {
/* .get_name = */ ggml_backend_cann_name, /* .get_name = */ ggml_backend_cann_name,
/* .free = */ ggml_backend_cann_free, /* .free = */ ggml_backend_cann_free,
/* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type, /* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type,
@ -1955,9 +1873,9 @@ static ggml_backend_i ggml_backend_cann_interface = {
/* .graph_plan_update = */ NULL, /* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL, /* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_cann_graph_compute, /* .graph_compute = */ ggml_backend_cann_graph_compute,
/* .supports_op = */ ggml_backend_cann_supports_op, /* .supports_op = */ NULL, // moved to device
/* .supports_buft = */ ggml_backend_cann_supports_buft, /* .supports_buft = */ NULL, // moved to device
/* .offload_op = */ ggml_backend_cann_offload_op, /* .offload_op = */ NULL, // moved to device
/* .event_record = */ ggml_backend_cann_event_record, /* .event_record = */ ggml_backend_cann_event_record,
/* .event_wait = */ ggml_backend_cann_event_wait, /* .event_wait = */ ggml_backend_cann_event_wait,
}; };
@ -1976,6 +1894,234 @@ static ggml_guid_t ggml_backend_cann_guid() {
return &guid; return &guid;
} }
// backend device
struct ggml_backend_cann_device_context {
int device;
std::string name;
std::string description;
};
static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
return ctx->name.c_str();
}
static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
return ctx->description.c_str();
}
static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
ggml_backend_cann_get_device_memory(ctx->device, free, total);
}
static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
GGML_UNUSED(dev);
return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
}
static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_cann_device_get_name(dev);
props->description = ggml_backend_cann_device_get_description(dev);
props->type = ggml_backend_cann_device_get_type(dev);
ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
props->caps = {
/* .async = */ false,
/* .host_buffer = */ host_buffer,
/* .buffer_from_host_ptr = */ false,
/* .events = */ true,
};
}
static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
return ggml_backend_cann_init(ctx->device);
}
/**
* @brief Checks if the CANN backend supports a specific backend buffer type.
*
* This function determines whether the CANN backend supports the given backend
* buffer type by comparing the device context of the backend and buffer type.
* It returns true if the device is the same for the backend device context and
* the buffer type context.
*
* @param dev Pointer to the CANN backend device.
* @param buft Pointer to the backend buffer type to check.
* @return bool Returns true if the CANN backend supports the buffer type,
* otherwise false.
*/
static bool ggml_backend_cann_supports_buft(
ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
if (ggml_backend_buft_is_cann(buft)) {
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
ggml_backend_cann_buffer_type_context * buft_ctx =
(ggml_backend_cann_buffer_type_context *)buft->context;
return buft_ctx->device == dev_ctx->device;
}
return false;
}
static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
return ggml_backend_cann_buffer_type(ctx->device);
}
static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
GGML_UNUSED(dev);
return ggml_backend_cann_host_buffer_type();
}
/**
* @brief Creates a new event for the CANN backend device.
*
* This function initializes a new event for the CANN backend by setting the
* device and creating an ACL runtime event. The created event is then wrapped
* in a ggml_backend_event structure and returned.
*
* @param dev Pointer to the CANN backend device.
* @return ggml_backend_event_t Returns a pointer to the new event structure.
*/
static ggml_backend_event_t ggml_backend_cann_device_event_new(
ggml_backend_dev_t dev) {
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
ggml_cann_set_device(dev_ctx->device);
aclrtEvent event;
ACL_CHECK(aclrtCreateEvent(&event));
return new ggml_backend_event{
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
/* .context = */ event,
};
}
/**
* @brief Frees a CANN backend event.
*
* This function destroys the ACL runtime event associated with the given CANN
* backend event and then deletes the event structure itself.
*
* @param event Pointer to the event structure to be freed.
*/
static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
delete event;
GGML_UNUSED(dev);
}
/**
* @brief Synchronizes the given event on the CANN backend.
*
* This function waits for the specified event to complete on the ACL runtime.
*
* @param event Pointer to the event structure to be synchronized.
*/
static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
GGML_UNUSED(dev);
}
static const ggml_backend_device_i ggml_backend_cann_device_interface = {
/* .get_name = */ ggml_backend_cann_device_get_name,
/* .get_description = */ ggml_backend_cann_device_get_description,
/* .get_memory = */ ggml_backend_cann_device_get_memory,
/* .get_type = */ ggml_backend_cann_device_get_type,
/* .get_props = */ ggml_backend_cann_device_get_props,
/* .init_backend = */ ggml_backend_cann_device_init, // called for every card
/* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type,
/* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
/* .buffer_from_host_ptr = */ NULL, // not supported for CANN
/* .supports_op = */ ggml_backend_cann_supports_op,
/* .supports_buft = */ ggml_backend_cann_supports_buft,
/* .offload_op = */ ggml_backend_cann_offload_op,
/* .event_new = */ ggml_backend_cann_device_event_new,
/* .event_free = */ ggml_backend_cann_device_event_free,
/* .event_synchronize = */ ggml_backend_cann_device_event_synchronize,
};
// backend reg
struct ggml_backend_cann_reg_context {
std::vector<ggml_backend_dev_t> devices;
};
static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
GGML_UNUSED(reg);
return GGML_CANN_NAME;
}
static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
return ctx->devices.size();
}
static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
GGML_ASSERT(index < ctx->devices.size());
return ctx->devices[index];
}
static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
GGML_UNUSED(reg);
GGML_UNUSED(name);
// reserved for future use
return nullptr;
}
static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
/* .get_name = */ ggml_backend_cann_reg_get_name,
/* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
/* .get_device = */ ggml_backend_cann_reg_get_device,
/* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
};
// backend registry; the CANN initialization below runs only once
ggml_backend_reg_t ggml_backend_cann_reg() {
static ggml_backend_reg reg;
static bool initialized = false;
{
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
aclInit(nullptr);
ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
for (int i = 0; i < ggml_cann_info().device_count; i++) {
ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
dev_ctx->description = aclrtGetSocName();
dev_ctx->device = i;
dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
ggml_cann_set_device(i);
ggml_backend_dev_t dev = new ggml_backend_device {
/* .interface = */ ggml_backend_cann_device_interface,
/* .reg = */ &reg,
/* .context = */ dev_ctx
};
ctx->devices.push_back(dev);
}
reg = ggml_backend_reg {
/* .interface = */ ggml_backend_cann_reg_interface,
/* .context = */ ctx
};
}
initialized = true;
}
return &reg;
}
ggml_backend_t ggml_backend_cann_init(int32_t device) { ggml_backend_t ggml_backend_cann_init(int32_t device) {
aclInit(nullptr); aclInit(nullptr);
if (device < 0 || device >= ggml_backend_cann_get_device_count()) { if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
@ -1992,7 +2138,7 @@ ggml_backend_t ggml_backend_cann_init(int32_t device) {
ggml_backend_t cann_backend = ggml_backend_t cann_backend =
new ggml_backend{/* .guid = */ ggml_backend_cann_guid(), new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
/* .interface = */ ggml_backend_cann_interface, /* .interface = */ ggml_backend_cann_interface,
/* .device = */ nullptr, /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
/* .context = */ ctx}; /* .context = */ ctx};
return cann_backend; return cann_backend;

View file

@ -1151,7 +1151,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer)); GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
char * src_ptr = (char *) src->data; const char * src_ptr = (const char *) src->data;
char * dst_ptr = (char *) dst; char * dst_ptr = (char *) dst;
const int64_t ne0 = src->ne[0]; const int64_t ne0 = src->ne[0];
@ -1162,7 +1162,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
const enum ggml_type type = src->type; const enum ggml_type type = src->type;
const int64_t ts = ggml_type_size(type); const int64_t ts = ggml_type_size(type);
const int64_t bs = ggml_blck_size(type); const int64_t bs = ggml_blck_size(type);
int64_t i1_diff = i1_high - i1_low; const int64_t i1_diff = i1_high - i1_low;
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
if (nb0 == ts && nb1 == ts*ne0/bs) { if (nb0 == ts && nb1 == ts*ne0/bs) {
@ -1479,13 +1479,18 @@ static void ggml_cuda_op_mul_mat(
if (src0_is_contiguous) { if (src0_is_contiguous) {
dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data; dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
} else { } else {
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0)); // If src0 is not contiguous it will be copied to a temporary buffer.
// This buffer needs to be cleared entirely because multiple regions will function as padding.
const size_t nbytes_data = ggml_nbytes(src0);
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
} }
// If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared: // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) { if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream)); CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
} }
@ -3145,7 +3150,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ROPE: case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]); return ggml_is_contiguous(op->src[0]);
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
case GGML_OP_SUM: case GGML_OP_SUM:
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:

View file

@ -92,8 +92,8 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t OW = dst->ne[1]; const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[3]; const int64_t batch = src1->ne[is_2D ? 3 : 2];
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
if(dst->type == GGML_TYPE_F16) { if(dst->type == GGML_TYPE_F16) {
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);

View file

@ -8,8 +8,6 @@ void ggml_cuda_op_mul_mat_q(
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t nb01 = src0->nb[1];
const int64_t ne10 = src1->ne[0]; const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1]; const int64_t ne11 = src1->ne[1];
GGML_ASSERT(ne10 % QK8_1 == 0); GGML_ASSERT(ne10 % QK8_1 == 0);
@ -17,7 +15,7 @@ void ggml_cuda_op_mul_mat_q(
const int64_t ne0 = dst->ne[0]; const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low; const int64_t row_diff = row_high - row_low;
const int64_t stride00 = nb01 / ggml_type_size(src0->type); const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
int id = ggml_cuda_get_device(); int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc; const int compute_capability = ggml_cuda_info().devices[id].cc;

View file

@ -241,6 +241,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32, GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_ARANGE_F32, GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@ -272,6 +274,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_COUNT GGML_METAL_KERNEL_TYPE_COUNT
}; };
@ -685,6 +689,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
@ -716,6 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
} }
[metal_library release]; [metal_library release];
@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16; return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_1D: case GGML_OP_POOL_1D:
case GGML_OP_POOL_2D:
return false; return false;
case GGML_OP_POOL_2D:
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_PAD: case GGML_OP_PAD:
case GGML_OP_ARANGE: case GGML_OP_ARANGE:
@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
} break; } break;
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
{ {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
switch (dst->type) { switch (dst->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; case GGML_TYPE_F32: {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; pipeline = (is_gt_mttpt ?
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
:
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
} break;
case GGML_TYPE_F16: {
pipeline = (is_gt_mttpt ?
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
:
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
} break;
default: GGML_ABORT("fatal error"); default: GGML_ABORT("fatal error");
}; };
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
if (is_gt_mttpt) {
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} else {
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
}
} break; } break;
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
{ {
@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
case GGML_OP_POOL_2D:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
const int32_t * opts = dst->op_params;
enum ggml_op_pool op = opts[0];
id<MTLComputePipelineState> pipeline = nil;
switch (src0t) {
case GGML_TYPE_F32: {
switch(op) {
case GGML_OP_POOL_AVG:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
case GGML_OP_POOL_MAX:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
}
} break;
default: GGML_ASSERT(false && "not implemented");
}
const int32_t k0 = opts[1];
const int32_t k1 = opts[2];
const int32_t s0 = opts[3];
const int32_t s1 = opts[4];
const int32_t p0 = opts[5];
const int32_t p1 = opts[6];
const int64_t IH = src0->ne[1];
const int64_t IW = src0->ne[0];
const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];
const int64_t parallel_elements = N * OC * OH * OW;
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
[encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} break;
default: default:
{ {
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));

View file

@ -1933,6 +1933,85 @@ kernel void kernel_im2col(
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>; template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>; template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
typedef void (im2col_ext_t)(
device const float * x,
device char * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
constant int32_t & N,
constant int32_t & KH,
constant int32_t & KW,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col_ext(
device const float * x,
device char * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
constant int32_t & N,
constant int32_t & KH,
constant int32_t & KW,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
const int32_t d = tgpig[0] / CHW;
const int32_t chw = tgpig[0] % CHW;
const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
const int32_t HW = tgpig[0] % KHW;
const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
if (tpitg_0 >= N) {
return;
}
const int32_t tpitg_1 = HW / KW;
const int32_t tpitg_2 = HW % KW;
const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
const int32_t offset_dst =
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
pdst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
kernel void kernel_upscale_f32( kernel void kernel_upscale_f32(
device const char * src0, device const char * src0,
device char * dst, device char * dst,
@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
kernel void kernel_pool_2d_max_f32(
device const float * src0,
device float * dst,
constant int32_t & k0,
constant int32_t & k1,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int64_t & IH,
constant int64_t & IW,
constant int64_t & OH,
constant int64_t & OW,
constant int64_t & parallel_elements,
uint gid[[thread_position_in_grid]]) {
if (gid >= parallel_elements) {
return;
}
const int idx = gid;
const int I_HW = IH * IW;
const int O_HW = OH * OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / OW;
const int cur_ow = idx % O_HW % OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * s1 - p1;
const int bh = MAX(0, start_h);
const int eh = MIN(IH, start_h + k1);
const int start_w = cur_ow * s0 - p0;
const int bw = MAX(0, start_w);
const int ew = MIN(IW, start_w + k0);
float res = -INFINITY;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
res = MAX(res, i_ptr[i * IW + j]);
}
}
o_ptr[cur_oh * OW + cur_ow] = res;
}
kernel void kernel_pool_2d_avg_f32(
device const float * src0,
device float * dst,
constant int32_t & k0,
constant int32_t & k1,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int64_t & IH,
constant int64_t & IW,
constant int64_t & OH,
constant int64_t & OW,
constant int64_t & parallel_elements,
uint gid[[thread_position_in_grid]]) {
if (gid >= parallel_elements) {
return;
}
const int idx = gid;
const int I_HW = IH * IW;
const int O_HW = OH * OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / OW;
const int cur_ow = idx % O_HW % OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * s1 - p1;
const int bh = MAX(0, start_h);
const int eh = MIN(IH, start_h + k1);
const int start_w = cur_ow * s0 - p0;
const int bw = MAX(0, start_w);
const int ew = MIN(IW, start_w + k0);
// const float scale = 1. / ((eh - bh) * (ew - bw));
const float scale = 1. / (k0 * k1);
float res = 0;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
float cur = i_ptr[i * IW + j];
res += cur * scale;
}
}
o_ptr[cur_oh * OW + cur_ow] = res;
}
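For reference, an editorial sketch (not part of the patch) of the same average pooling as scalar host-side C++, spelling out the flat-index decomposition shared by both kernels above; it assumes an NCHW layout and, like the kernel, divides by the full k0*k1 window even when the window is clipped at the borders.

    #include <algorithm>
    #include <cstdint>

    static void pool_2d_avg_ref(const float * src, float * dst,
                                int64_t N, int64_t C, int64_t IH, int64_t IW,
                                int64_t OH, int64_t OW,
                                int k0, int k1, int s0, int s1, int p0, int p1) {
        const float scale = 1.0f / (k0 * k1);
        for (int64_t idx = 0; idx < N * C * OH * OW; ++idx) {
            const int64_t nc     = idx / (OH * OW);      // fused batch*channel index
            const int64_t cur_oh = idx % (OH * OW) / OW; // output row
            const int64_t cur_ow = idx % (OH * OW) % OW; // output column
            const float * i_ptr = src + nc * IH * IW;
            const int64_t bh = std::max<int64_t>(0,  cur_oh * s1 - p1);
            const int64_t eh = std::min<int64_t>(IH, cur_oh * s1 - p1 + k1);
            const int64_t bw = std::max<int64_t>(0,  cur_ow * s0 - p0);
            const int64_t ew = std::min<int64_t>(IW, cur_ow * s0 - p0 + k0);
            float res = 0.0f;
            for (int64_t i = bh; i < eh; ++i) {
                for (int64_t j = bw; j < ew; ++j) {
                    res += i_ptr[i * IW + j] * scale;
                }
            }
            dst[nc * OH * OW + cur_oh * OW + cur_ow] = res;
        }
    }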

View file

@ -57,8 +57,9 @@ struct socket_t {
} }
}; };
// ggml_tensor is serialized into rpc_tensor // all RPC structures must be packed
#pragma pack(push, 1) #pragma pack(push, 1)
// ggml_tensor is serialized into rpc_tensor
struct rpc_tensor { struct rpc_tensor {
uint64_t id; uint64_t id;
uint32_t type; uint32_t type;
@ -76,7 +77,6 @@ struct rpc_tensor {
char padding[4]; char padding[4];
}; };
#pragma pack(pop)
static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8"); static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
@ -96,6 +96,65 @@ enum rpc_cmd {
RPC_CMD_COUNT, RPC_CMD_COUNT,
}; };
struct rpc_msg_alloc_buffer_req {
uint64_t size;
};
struct rpc_msg_alloc_buffer_rsp {
uint64_t remote_ptr;
uint64_t remote_size;
};
struct rpc_msg_get_alignment_rsp {
uint64_t alignment;
};
struct rpc_msg_get_max_size_rsp {
uint64_t max_size;
};
struct rpc_msg_buffer_get_base_req {
uint64_t remote_ptr;
};
struct rpc_msg_buffer_get_base_rsp {
uint64_t base_ptr;
};
struct rpc_msg_free_buffer_req {
uint64_t remote_ptr;
};
struct rpc_msg_buffer_clear_req {
uint64_t remote_ptr;
uint8_t value;
};
struct rpc_msg_get_tensor_req {
rpc_tensor tensor;
uint64_t offset;
uint64_t size;
};
struct rpc_msg_copy_tensor_req {
rpc_tensor src;
rpc_tensor dst;
};
struct rpc_msg_copy_tensor_rsp {
uint8_t result;
};
struct rpc_msg_graph_compute_rsp {
uint8_t result;
};
struct rpc_msg_get_device_memory_rsp {
uint64_t free_mem;
uint64_t total_mem;
};
#pragma pack(pop)
// RPC data structures // RPC data structures
static ggml_guid_t ggml_backend_rpc_guid() { static ggml_guid_t ggml_backend_rpc_guid() {
@ -240,6 +299,38 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
return true; return true;
} }
static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
return false;
}
return send_data(sockfd, msg, msg_size);
}
static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
uint64_t size;
if (!recv_data(sockfd, &size, sizeof(size))) {
return false;
}
if (size != msg_size) {
return false;
}
return recv_data(sockfd, msg, msg_size);
}
static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
uint64_t size;
if (!recv_data(sockfd, &size, sizeof(size))) {
return false;
}
try {
input.resize(size);
} catch (const std::bad_alloc & e) {
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
return false;
}
return recv_data(sockfd, input.data(), size);
}
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
size_t pos = endpoint.find(':'); size_t pos = endpoint.find(':');
if (pos == std::string::npos) { if (pos == std::string::npos) {
@ -252,28 +343,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) | // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
uint8_t cmd_byte = cmd; uint8_t cmd_byte = cmd;
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
return false; return false;
} }
uint64_t input_size = input.size();
if (!send_data(sock->fd, &input_size, sizeof(input_size))) { if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
return false; return false;
} }
if (!send_data(sock->fd, input.data(), input.size())) { if (!send_data(sock->fd, input, input_size)) {
return false; return false;
} }
uint64_t output_size; // TODO: currently the output_size is always known, do we need support for commands with variable output size?
if (!recv_data(sock->fd, &output_size, sizeof(output_size))) { // even if we do, we can skip sending output_size from the server for commands with known output size
uint64_t out_size;
if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
return false; return false;
} }
if (output_size == 0) { if (out_size != output_size) {
output.clear(); return false;
return true;
} }
output.resize(output_size); if (!recv_data(sock->fd, output, output_size)) {
if (!recv_data(sock->fd, output.data(), output_size)) {
return false; return false;
} }
return true; return true;
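
To make the framing above concrete, a hedged, in-memory sketch of one client round trip on the wire (the encode_request helper is illustrative, not part of ggml-rpc): the client writes | cmd (1 byte) | request_size (8 bytes) | request bytes |, and the server answers with | response_size (8 bytes) | response bytes |, where both sizes are now known in advance for every fixed-size command.

#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative encoder: build the byte stream send_rpc_cmd() would emit for a
// fixed-size request (command byte, then an 8-byte size, then the payload).
static std::vector<uint8_t> encode_request(uint8_t cmd, const void * payload, uint64_t payload_size) {
    std::vector<uint8_t> wire;
    wire.push_back(cmd);
    wire.resize(1 + sizeof(payload_size) + payload_size);
    std::memcpy(wire.data() + 1, &payload_size, sizeof(payload_size));
    if (payload_size > 0) {
        std::memcpy(wire.data() + 1 + sizeof(payload_size), payload, payload_size);
    }
    return wire;
}

// Example: a GET_ALIGNMENT-style request carries no payload, so the stream is
// just 9 bytes: the command byte followed by a zero length.
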
@ -326,14 +416,9 @@ static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffe
static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
// input serialization format: | remote_ptr (8 bytes) | rpc_msg_free_buffer_req request = {ctx->remote_ptr};
std::vector<uint8_t> input(sizeof(uint64_t), 0); bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
uint64_t remote_ptr = ctx->remote_ptr;
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
GGML_ASSERT(status); GGML_ASSERT(status);
GGML_ASSERT(output.empty());
delete ctx; delete ctx;
} }
@ -342,20 +427,13 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) { if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
return ctx->base_cache[buffer]; return ctx->base_cache[buffer];
} }
// input serialization format: | remote_ptr (8 bytes) | rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
std::vector<uint8_t> input(sizeof(uint64_t), 0); rpc_msg_buffer_get_base_rsp response;
uint64_t remote_ptr = ctx->remote_ptr; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
GGML_ASSERT(status); GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t)); void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
// output serialization format: | base_ptr (8 bytes) | ctx->base_cache[buffer] = base_ptr;
uint64_t base_ptr; return base_ptr;
memcpy(&base_ptr, output.data(), sizeof(base_ptr));
void * base = reinterpret_cast<void *>(base_ptr);
ctx->base_cache[buffer] = base;
return base;
} }
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@ -405,26 +483,18 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
std::vector<uint8_t> output; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
GGML_ASSERT(status); GGML_ASSERT(status);
} }
static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
// input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) | rpc_msg_get_tensor_req request;
int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t); request.tensor = serialize_tensor(tensor);
std::vector<uint8_t> input(input_size, 0); request.offset = offset;
rpc_tensor rpc_tensor = serialize_tensor(tensor); request.size = size;
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
GGML_ASSERT(status); GGML_ASSERT(status);
GGML_ASSERT(output.size() == size);
// output serialization format: | data (size bytes) |
memcpy(data, output.data(), size);
} }
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@ -437,30 +507,19 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
return false; return false;
} }
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
// input serialization format: | rpc_tensor src | rpc_tensor dst | rpc_msg_copy_tensor_req request;
int input_size = 2*sizeof(rpc_tensor); request.src = serialize_tensor(src);
std::vector<uint8_t> input(input_size, 0); request.dst = serialize_tensor(dst);
rpc_tensor rpc_src = serialize_tensor(src); rpc_msg_copy_tensor_rsp response;
rpc_tensor rpc_dst = serialize_tensor(dst); bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
GGML_ASSERT(status); GGML_ASSERT(status);
// output serialization format: | result (1 byte) | return response.result;
GGML_ASSERT(output.size() == 1);
return output[0];
} }
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
// serialization format: | bufptr (8 bytes) | value (1 byte) | rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
int input_size = sizeof(uint64_t) + sizeof(uint8_t); bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
std::vector<uint8_t> input(input_size, 0);
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
GGML_ASSERT(status); GGML_ASSERT(status);
} }
@ -484,25 +543,16 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
// input serialization format: | size (8 bytes) | rpc_msg_alloc_buffer_req request = {size};
int input_size = sizeof(uint64_t); rpc_msg_alloc_buffer_rsp response;
std::vector<uint8_t> input(input_size, 0);
memcpy(input.data(), &size, sizeof(size));
std::vector<uint8_t> output;
auto sock = get_socket(buft_ctx->endpoint); auto sock = get_socket(buft_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output); bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
GGML_ASSERT(status); GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); if (response.remote_ptr != 0) {
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
uint64_t remote_ptr;
memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
size_t remote_size;
memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
if (remote_ptr != 0) {
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface, ggml_backend_rpc_buffer_interface,
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"}, new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
remote_size); response.remote_size);
return buffer; return buffer;
} else { } else {
return nullptr; return nullptr;
@ -510,16 +560,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
} }
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) { static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
// input serialization format: | 0 bytes | rpc_msg_get_alignment_rsp response;
std::vector<uint8_t> input; bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
GGML_ASSERT(status); GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t)); return response.alignment;
// output serialization format: | alignment (8 bytes) |
uint64_t alignment;
memcpy(&alignment, output.data(), sizeof(alignment));
return alignment;
} }
static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
@ -528,16 +572,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
} }
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) { static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
// input serialization format: | 0 bytes | rpc_msg_get_max_size_rsp response;
std::vector<uint8_t> input; bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
GGML_ASSERT(status); GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t)); return response.max_size;
// output serialization format: | max_size (8 bytes) |
uint64_t max_size;
memcpy(&max_size, output.data(), sizeof(max_size));
return max_size;
} }
static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
@ -622,12 +660,11 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
std::vector<uint8_t> input; std::vector<uint8_t> input;
serialize_graph(cgraph, input); serialize_graph(cgraph, input);
std::vector<uint8_t> output; rpc_msg_graph_compute_rsp response;
auto sock = get_socket(rpc_ctx->endpoint); auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output); bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
GGML_ASSERT(status); GGML_ASSERT(status);
GGML_ASSERT(output.size() == 1); return (enum ggml_status)response.result;
return (enum ggml_status)output[0];
} }
static ggml_backend_i ggml_backend_rpc_interface = { static ggml_backend_i ggml_backend_rpc_interface = {
@ -702,19 +739,11 @@ GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
} }
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) { static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
// input serialization format: | 0 bytes | rpc_msg_get_device_memory_rsp response;
std::vector<uint8_t> input; bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
GGML_ASSERT(status); GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); *free = response.free_mem;
// output serialization format: | free (8 bytes) | total (8 bytes) | *total = response.total_mem;
uint64_t free_mem;
memcpy(&free_mem, output.data(), sizeof(free_mem));
uint64_t total_mem;
memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
*free = free_mem;
*total = total_mem;
} }
GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
@ -734,16 +763,16 @@ public:
rpc_server(ggml_backend_t backend) : backend(backend) {} rpc_server(ggml_backend_t backend) : backend(backend) {}
~rpc_server(); ~rpc_server();
bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output); void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
void get_alignment(std::vector<uint8_t> & output); void get_alignment(rpc_msg_get_alignment_rsp & response);
void get_max_size(std::vector<uint8_t> & output); void get_max_size(rpc_msg_get_max_size_rsp & response);
bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output); bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
bool free_buffer(const std::vector<uint8_t> & input); bool free_buffer(const rpc_msg_free_buffer_req & request);
bool buffer_clear(const std::vector<uint8_t> & input); bool buffer_clear(const rpc_msg_buffer_clear_req & request);
bool set_tensor(const std::vector<uint8_t> & input); bool set_tensor(const std::vector<uint8_t> & input);
bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output); bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output); bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output); bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
private: private:
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@ -757,80 +786,50 @@ private:
std::unordered_set<ggml_backend_buffer_t> buffers; std::unordered_set<ggml_backend_buffer_t> buffers;
}; };
bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
// input serialization format: | size (8 bytes) |
if (input.size() != sizeof(uint64_t)) {
return false;
}
uint64_t size;
memcpy(&size, input.data(), sizeof(size));
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size); ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
uint64_t remote_ptr = 0; response.remote_ptr = 0;
uint64_t remote_size = 0; response.remote_size = 0;
if (buffer != nullptr) { if (buffer != nullptr) {
remote_ptr = reinterpret_cast<uint64_t>(buffer); response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
remote_size = buffer->size; response.remote_size = buffer->size;
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size); GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
buffers.insert(buffer); buffers.insert(buffer);
} else { } else {
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size); GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
} }
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
output.resize(2*sizeof(uint64_t), 0);
memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
return true;
} }
void rpc_server::get_alignment(std::vector<uint8_t> & output) { void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
size_t alignment = ggml_backend_buft_get_alignment(buft); size_t alignment = ggml_backend_buft_get_alignment(buft);
GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment); GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
// output serialization format: | alignment (8 bytes) | response.alignment = alignment;
output.resize(sizeof(uint64_t), 0);
memcpy(output.data(), &alignment, sizeof(alignment));
} }
void rpc_server::get_max_size(std::vector<uint8_t> & output) { void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
size_t max_size = ggml_backend_buft_get_max_size(buft); size_t max_size = ggml_backend_buft_get_max_size(buft);
GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size); GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
// output serialization format: | max_size (8 bytes) | response.max_size = max_size;
output.resize(sizeof(uint64_t), 0);
memcpy(output.data(), &max_size, sizeof(max_size));
} }
bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
// input serialization format: | remote_ptr (8 bytes) | GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
if (input.size() != sizeof(uint64_t)) { ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
return false;
}
uint64_t remote_ptr;
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
if (buffers.find(buffer) == buffers.end()) { if (buffers.find(buffer) == buffers.end()) {
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
return false; return false;
} }
void * base = ggml_backend_buffer_get_base(buffer); void * base = ggml_backend_buffer_get_base(buffer);
// output serialization format: | base_ptr (8 bytes) | response.base_ptr = reinterpret_cast<uint64_t>(base);
uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
output.resize(sizeof(uint64_t), 0);
memcpy(output.data(), &base_ptr, sizeof(base_ptr));
return true; return true;
} }
bool rpc_server::free_buffer(const std::vector<uint8_t> & input) { bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
// input serialization format: | remote_ptr (8 bytes) | GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
if (input.size() != sizeof(uint64_t)) { ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
return false;
}
uint64_t remote_ptr;
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
if (buffers.find(buffer) == buffers.end()) { if (buffers.find(buffer) == buffers.end()) {
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
return false; return false;
@ -840,22 +839,14 @@ bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
return true; return true;
} }
bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) { bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
// input serialization format: | remote_ptr (8 bytes) | value (1 byte) | GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) { ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
return false;
}
uint64_t remote_ptr;
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
uint8_t value;
memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
if (buffers.find(buffer) == buffers.end()) { if (buffers.find(buffer) == buffers.end()) {
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
return false; return false;
} }
ggml_backend_buffer_clear(buffer, value); ggml_backend_buffer_clear(buffer, request.value);
return true; return true;
} }
@ -930,74 +921,55 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
return true; return true;
} }
bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
// serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
return false;
}
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
uint64_t offset;
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
uint64_t size;
memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));
struct ggml_init_params params { struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(), /*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL, /*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true, /*.no_alloc =*/ true,
}; };
struct ggml_context * ctx = ggml_init(params); struct ggml_context * ctx = ggml_init(params);
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
if (tensor == nullptr) { if (tensor == nullptr) {
GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__); GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
ggml_free(ctx); ggml_free(ctx);
return false; return false;
} }
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
// sanitize tensor->data // sanitize tensor->data
{ {
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer); const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) { if (request.tensor.data + request.offset < p0 ||
request.tensor.data + request.offset >= p1 ||
request.size > (p1 - request.tensor.data - request.offset)) {
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__); GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
} }
} }
// output serialization format: | data (size bytes) | response.resize(request.size, 0);
output.resize(size, 0); ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
ggml_backend_tensor_get(tensor, output.data(), offset, size);
ggml_free(ctx); ggml_free(ctx);
return true; return true;
} }
bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
// serialization format: | rpc_tensor src | rpc_tensor dst |
if (input.size() != 2*sizeof(rpc_tensor)) {
return false;
}
const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
struct ggml_init_params params { struct ggml_init_params params {
/*.mem_size =*/ 2*ggml_tensor_overhead(), /*.mem_size =*/ 2*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL, /*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true, /*.no_alloc =*/ true,
}; };
struct ggml_context * ctx = ggml_init(params); struct ggml_context * ctx = ggml_init(params);
ggml_tensor * src = deserialize_tensor(ctx, rpc_src); ggml_tensor * src = deserialize_tensor(ctx, &request.src);
ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst); ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
if (src == nullptr || dst == nullptr) { if (src == nullptr || dst == nullptr) {
GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__); GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
ggml_free(ctx); ggml_free(ctx);
return false; return false;
} }
GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer); GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
bool result = ggml_backend_buffer_copy_tensor(src, dst); response.result = ggml_backend_buffer_copy_tensor(src, dst);
// output serialization format: | result (1 byte) |
output.resize(1, 0);
output[0] = result;
ggml_free(ctx); ggml_free(ctx);
return true; return true;
} }
@ -1026,7 +998,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
return result; return result;
} }
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
// serialization format: // serialization format:
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < sizeof(uint32_t)) { if (input.size() < sizeof(uint32_t)) {
@ -1066,9 +1038,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map); graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
} }
ggml_status status = ggml_backend_graph_compute(backend, graph); ggml_status status = ggml_backend_graph_compute(backend, graph);
// output serialization format: | status (1 byte) | response.result = status;
output.resize(1, 0);
output[0] = status;
ggml_free(ctx); ggml_free(ctx);
return true; return true;
} }
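
For reference, a hedged sketch of the client-side layout that graph_compute() parses, | n_nodes (4 bytes) | node ids (n_nodes * 8 bytes) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |, using a hypothetical fake_tensor struct in place of the real rpc_tensor:

#include <cstdint>
#include <cstring>
#include <vector>

struct fake_tensor { uint64_t id; uint8_t payload[248]; };  // fixed-size placeholder, not the real rpc_tensor

// Serialize node ids plus their tensor descriptors in the order the server expects.
static std::vector<uint8_t> serialize_demo_graph(const std::vector<uint64_t> & node_ids,
                                                 const std::vector<fake_tensor> & tensors) {
    const uint32_t n_nodes   = (uint32_t) node_ids.size();
    const uint32_t n_tensors = (uint32_t) tensors.size();
    std::vector<uint8_t> out(sizeof(n_nodes) + n_nodes * sizeof(uint64_t) +
                             sizeof(n_tensors) + n_tensors * sizeof(fake_tensor));
    uint8_t * p = out.data();
    std::memcpy(p, &n_nodes, sizeof(n_nodes));                    p += sizeof(n_nodes);
    std::memcpy(p, node_ids.data(), n_nodes * sizeof(uint64_t));  p += n_nodes * sizeof(uint64_t);
    std::memcpy(p, &n_tensors, sizeof(n_tensors));                p += sizeof(n_tensors);
    std::memcpy(p, tensors.data(), n_tensors * sizeof(fake_tensor));
    return out;
}
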
@ -1091,85 +1061,153 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
fprintf(stderr, "Unknown command: %d\n", cmd); fprintf(stderr, "Unknown command: %d\n", cmd);
break; break;
} }
std::vector<uint8_t> input;
std::vector<uint8_t> output;
uint64_t input_size;
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
break;
}
try {
input.resize(input_size);
} catch (const std::bad_alloc & e) {
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
break;
}
if (!recv_data(sockfd, input.data(), input_size)) {
break;
}
bool ok = true;
switch (cmd) { switch (cmd) {
case RPC_CMD_ALLOC_BUFFER: { case RPC_CMD_ALLOC_BUFFER: {
ok = server.alloc_buffer(input, output); rpc_msg_alloc_buffer_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
rpc_msg_alloc_buffer_rsp response;
server.alloc_buffer(request, response);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break; break;
} }
case RPC_CMD_GET_ALIGNMENT: { case RPC_CMD_GET_ALIGNMENT: {
server.get_alignment(output); if (!recv_msg(sockfd, nullptr, 0)) {
return;
}
rpc_msg_get_alignment_rsp response;
server.get_alignment(response);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break; break;
} }
case RPC_CMD_GET_MAX_SIZE: { case RPC_CMD_GET_MAX_SIZE: {
server.get_max_size(output); if (!recv_msg(sockfd, nullptr, 0)) {
return;
}
rpc_msg_get_max_size_rsp response;
server.get_max_size(response);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break; break;
} }
case RPC_CMD_BUFFER_GET_BASE: { case RPC_CMD_BUFFER_GET_BASE: {
ok = server.buffer_get_base(input, output); rpc_msg_buffer_get_base_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
rpc_msg_buffer_get_base_rsp response;
if (!server.buffer_get_base(request, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break; break;
} }
case RPC_CMD_FREE_BUFFER: { case RPC_CMD_FREE_BUFFER: {
ok = server.free_buffer(input); rpc_msg_free_buffer_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
if (!server.free_buffer(request)) {
return;
}
if (!send_msg(sockfd, nullptr, 0)) {
return;
}
break; break;
} }
case RPC_CMD_BUFFER_CLEAR: { case RPC_CMD_BUFFER_CLEAR: {
ok = server.buffer_clear(input); rpc_msg_buffer_clear_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
if (!server.buffer_clear(request)) {
return;
}
if (!send_msg(sockfd, nullptr, 0)) {
return;
}
break; break;
} }
case RPC_CMD_SET_TENSOR: { case RPC_CMD_SET_TENSOR: {
ok = server.set_tensor(input); std::vector<uint8_t> input;
if (!recv_msg(sockfd, input)) {
return;
}
if (!server.set_tensor(input)) {
return;
}
if (!send_msg(sockfd, nullptr, 0)) {
return;
}
break; break;
} }
case RPC_CMD_GET_TENSOR: { case RPC_CMD_GET_TENSOR: {
ok = server.get_tensor(input, output); rpc_msg_get_tensor_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
std::vector<uint8_t> response;
if (!server.get_tensor(request, response)) {
return;
}
if (!send_msg(sockfd, response.data(), response.size())) {
return;
}
break; break;
} }
case RPC_CMD_COPY_TENSOR: { case RPC_CMD_COPY_TENSOR: {
ok = server.copy_tensor(input, output); rpc_msg_copy_tensor_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
rpc_msg_copy_tensor_rsp response;
if (!server.copy_tensor(request, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break; break;
} }
case RPC_CMD_GRAPH_COMPUTE: { case RPC_CMD_GRAPH_COMPUTE: {
ok = server.graph_compute(input, output); std::vector<uint8_t> input;
if (!recv_msg(sockfd, input)) {
return;
}
rpc_msg_graph_compute_rsp response;
if (!server.graph_compute(input, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break; break;
} }
case RPC_CMD_GET_DEVICE_MEMORY: { case RPC_CMD_GET_DEVICE_MEMORY: {
// output serialization format: | free (8 bytes) | total (8 bytes) | if (!recv_msg(sockfd, nullptr, 0)) {
output.resize(2*sizeof(uint64_t), 0); return;
memcpy(output.data(), &free_mem, sizeof(free_mem)); }
memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem)); rpc_msg_get_device_memory_rsp response;
response.free_mem = free_mem;
response.total_mem = total_mem;
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break; break;
} }
default: { default: {
fprintf(stderr, "Unknown command: %d\n", cmd); fprintf(stderr, "Unknown command: %d\n", cmd);
ok = false; return;
} }
} }
if (!ok) {
break;
}
uint64_t output_size = output.size();
if (!send_data(sockfd, &output_size, sizeof(output_size))) {
break;
}
if (!send_data(sockfd, output.data(), output_size)) {
break;
}
} }
} }

File diff suppressed because it is too large

View file

@ -1,6 +1,6 @@
#include "mmvq.hpp" #include "mmvq.hpp"
#include "vecdotq.hpp" #include "vecdotq.hpp"
#include <cassert>
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl> template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
@ -13,7 +13,8 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
} }
const int blocks_per_row = ncols / qk; const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi; const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread // partial sum for each thread
float tmp = 0.0f; float tmp = 0.0f;
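
Why the new assert matters (hypothetical numbers, only to illustrate the integer division): if qi exceeds vdr times the sub-group size, blocks_per_warp truncates to zero and the accumulation loop below would silently do no work.

// Illustrative values, not taken from the ggml-sycl headers:
constexpr int vdr_demo = 1, sub_group_demo = 16, qi_demo = 32;
constexpr int blocks_per_warp_demo = vdr_demo * sub_group_demo / qi_demo;  // integer division -> 0
static_assert(blocks_per_warp_demo == 0, "with these values the loop would never advance");
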
@ -37,7 +38,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
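
Background on the reduction pattern used above (illustrative, not from the patch): permute_sub_group_by_xor exchanges each lane's partial sum with the lane whose index differs in one bit, so halving the mask log2(sub-group size) times leaves every lane holding the full sum. A plain C++ model of the same butterfly over a hypothetical power-of-two lane count:

#include <array>
#include <cstddef>

// Simulate an XOR butterfly reduction across a sub-group of N lanes (N must be a power of two).
template <std::size_t N>
std::array<float, N> butterfly_sum(std::array<float, N> lanes) {
    for (std::size_t mask = N / 2; mask > 0; mask >>= 1) {
        std::array<float, N> shuffled = lanes;  // snapshot of the current partials
        for (std::size_t lane = 0; lane < N; ++lane) {
            // each lane adds the value held by lane ^ mask, like permute_sub_group_by_xor
            lanes[lane] += shuffled[lane ^ mask];
        }
    }
    return lanes;  // every lane now holds the sum of all N partials
}
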
@ -61,7 +62,8 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
} }
const int blocks_per_row = ncols / qk; const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi; const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread // partial sum for each thread
float tmp = 0.0f; float tmp = 0.0f;
@ -85,7 +87,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -109,8 +111,8 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
} }
const int blocks_per_row = ncols / qk; const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi; const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread // partial sum for each thread
float tmp = 0.0f; float tmp = 0.0f;
@ -133,7 +135,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -157,8 +159,8 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
} }
const int blocks_per_row = ncols / qk; const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi; const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread // partial sum for each thread
float tmp = 0.0f; float tmp = 0.0f;
@ -181,7 +183,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -205,8 +207,8 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
} }
const int blocks_per_row = ncols / qk; const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi; const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread // partial sum for each thread
float tmp = 0.0f; float tmp = 0.0f;
@ -229,7 +231,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -253,8 +255,8 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
} }
const int blocks_per_row = ncols / qk; const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi; const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread // partial sum for each thread
float tmp = 0.0f; float tmp = 0.0f;
@ -277,7 +279,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -301,8 +303,8 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
} }
const int blocks_per_row = ncols / qk; const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi; const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread // partial sum for each thread
float tmp = 0.0f; float tmp = 0.0f;
@ -325,7 +327,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -349,8 +351,8 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
} }
const int blocks_per_row = ncols / qk; const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi; const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread // partial sum for each thread
float tmp = 0.0f; float tmp = 0.0f;
@ -373,7 +375,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -397,8 +399,8 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
} }
const int blocks_per_row = ncols / qk; const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi; const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread // partial sum for each thread
float tmp = 0.0f; float tmp = 0.0f;
@ -421,7 +423,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -446,8 +448,8 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
} }
const int blocks_per_row = ncols / qk; const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi; const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
assert(blocks_per_warp>0);
// partial sum for each thread // partial sum for each thread
float tmp = 0.0f; float tmp = 0.0f;
@ -470,7 +472,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }
@ -487,7 +489,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK4_0 == 0); GGML_ASSERT(ncols % QK4_0 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { stream->submit([&](sycl::handler &cgh) {
@ -495,7 +497,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>( VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -511,7 +513,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK4_1 == 0); GGML_ASSERT(ncols % QK4_1 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { stream->submit([&](sycl::handler &cgh) {
@ -519,7 +521,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>( VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -535,7 +537,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK5_0 == 0); GGML_ASSERT(ncols % QK5_0 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { stream->submit([&](sycl::handler &cgh) {
@ -543,7 +545,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>( VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -559,7 +561,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK5_1 == 0); GGML_ASSERT(ncols % QK5_1 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { stream->submit([&](sycl::handler &cgh) {
@ -567,7 +569,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>( VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -583,7 +585,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK8_0 == 0); GGML_ASSERT(ncols % QK8_0 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { stream->submit([&](sycl::handler &cgh) {
@ -591,7 +593,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>( VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -607,7 +609,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { stream->submit([&](sycl::handler &cgh) {
@ -615,7 +617,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>( VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -631,7 +633,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { stream->submit([&](sycl::handler &cgh) {
@ -639,7 +641,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>( VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -655,7 +657,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { stream->submit([&](sycl::handler &cgh) {
@ -663,7 +665,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for( cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(WARP_SIZE)]] { [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>( VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -679,7 +681,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
@@ -687,7 +689,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
                                       VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -703,7 +705,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
@@ -711,7 +713,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
                                       VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -728,13 +730,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -749,7 +751,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
@@ -759,7 +761,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -774,7 +776,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
@@ -784,7 +786,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -799,7 +801,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
@@ -809,7 +811,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -824,7 +826,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
@@ -833,7 +835,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -848,7 +850,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
@@ -858,7 +860,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -873,13 +875,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -894,14 +896,14 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK4_NL == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -916,14 +918,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });

View file

@@ -324,8 +324,9 @@ struct ggml_logger_state {
 static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
 
 static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
-    if (format == NULL)
+    if (format == NULL) {
         return;
+    }
     va_list args_copy;
     va_copy(args_copy, args);
     char buffer[128];
@@ -3483,7 +3484,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     size_t nbytes;
-    size_t blck_size = ggml_blck_size(tensor->type);
+    const size_t blck_size = ggml_blck_size(tensor->type);
     if (blck_size == 1) {
         nbytes = ggml_type_size(tensor->type);
         for (int i = 0; i < GGML_MAX_DIMS; ++i) {
@@ -3872,10 +3873,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             },
         };
 
-        for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
-            g_state.contexts[i].used = false;
-        }
-
         const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
         GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
@@ -15778,6 +15775,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t  const kq_vec_dot = type_traits[k->type].vec_dot;
     ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
 
+    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
+    GGML_ASSERT(v_to_float  && "fattn: unsupported V-type");
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -23356,6 +23356,14 @@ int ggml_cpu_has_avx512_bf16(void) {
 #endif
 }
 
+int ggml_cpu_has_amx_int8(void) {
+#if defined(__AMX_INT8__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
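
The new ggml_cpu_has_amx_int8() probe follows the same compile-time pattern as the existing feature checks: it returns 1 only when the build defines __AMX_INT8__. A small illustrative check, assuming the matching declaration is exposed from ggml.h in this commit:

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        // both probes are resolved at build time from compiler defines
        printf("AMX_INT8: %d\n", ggml_cpu_has_amx_int8());
        printf("FMA:      %d\n", ggml_cpu_has_fma());
        return 0;
    }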

View file

@@ -1489,7 +1489,7 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
+        llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr,};
         if (llama_decode(ctx_llama, batch)) {
             fprintf(stderr, "\n%s : failed to eval image\n", __func__);
             return false;
@@ -2004,7 +2004,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
             //determine mem per token
             std::vector<int> tmp = {1, 2, 3, 4};
             llama_kv_cache_clear(llama_ctx_v4);
-            auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
+            auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size()));
            if(er!=0)
            {
                printf("\nLLAMA EVAL returned nonzero: %d\n",er);
@@ -3061,7 +3061,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
            }
            else if(file_format == FileFormat::GGUF_GENERIC)
            {
-               evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize, n_past, 0))==0);
+               evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize))==0);
            }
            else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
            {
@@ -3432,7 +3432,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                if(i>0 && sepsize>0)
                {
                    //add a separator between each image
-                   auto evr = llama_decode(llama_ctx_v4, llama_batch_get_one(llava_sep.data(), sepsize, n_past, 0));
+                   auto evr = llama_decode(llama_ctx_v4, llama_batch_get_one(llava_sep.data(), sepsize));
                    if(evr!=0)
                    {
                        printf("\nError when appending llava separator: %d\n",evr);

View file

@@ -217,6 +217,7 @@ extern "C" {
     typedef struct llama_token_data_array {
         // TODO: consider SoA
+        // NOTE: this pointer can be modified by the samplers
         llama_token_data * data;
         size_t size;
         int64_t selected; // this is the index in the data array (i.e. not the token id)
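
The added NOTE makes it explicit that samplers may rewrite the candidate list through this pointer. As a rough sketch (not part of this commit; mask_token and the banned id are made up for illustration), a custom filter can mask candidates in place the same way the built-in samplers do:

    #include <math.h> // INFINITY

    static void mask_token(llama_token_data_array * cur_p, llama_token banned) {
        for (size_t i = 0; i < cur_p->size; ++i) {
            if (cur_p->data[i].id == banned) {
                cur_p->data[i].logit = -INFINITY; // edits the array the samplers share
            }
        }
        cur_p->sorted = false; // ordering by logit may no longer hold
    }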
@@ -232,8 +233,11 @@ extern "C" {
     //  - token  : the token ids of the input (used when embd is NULL)
     //  - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     //  - pos    : the positions of the respective token in the sequence
+    //             (if set to NULL, the token position will be tracked automatically by llama_decode)
     //  - seq_id : the sequence to which the respective token belongs
+    //             (if set to NULL, the sequence ID will be assumed to be 0)
     //  - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
+    //             (if set to NULL, only the logits for last token will be returned)
     //
     typedef struct llama_batch {
         int32_t n_tokens;
@@ -244,15 +248,6 @@ extern "C" {
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
         int8_t       *  logits; // TODO: rename this to "output"
-
-        // NOTE: helpers for smooth API transition - can be deprecated in the future
-        //       for future-proof code, use the above fields instead and ignore everything below
-        //
-        // pos[i] = all_pos_0 + i*all_pos_1
-        //
-        llama_pos    all_pos_0;  // used if pos == NULL
-        llama_pos    all_pos_1;  // used if pos == NULL
-        llama_seq_id all_seq_id; // used if seq_id == NULL
     } llama_batch;
 
     enum llama_model_kv_override_type {
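
With the all_pos_0/all_pos_1/all_seq_id helpers removed, a hand-built batch now relies on the NULL defaults documented above. A minimal sketch (illustrative, not taken from this commit; ctx is assumed to be an initialized llama_context *):

    std::vector<llama_token> toks = { 1, 2, 3, 4 };
    llama_batch batch = {
        /*n_tokens =*/ (int32_t) toks.size(),
        /*token    =*/ toks.data(),
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,  // positions tracked automatically by llama_decode
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,  // defaults to sequence 0
        /*logits   =*/ nullptr,  // logits only for the last token
    };
    llama_decode(ctx, batch);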
@@ -778,15 +773,15 @@ extern "C" {
     // Decoding
     //
 
-    // Return batch for single sequence of tokens starting at pos_0
+    // Return batch for single sequence of tokens
+    // The sequence ID will be fixed to 0
+    // The position of the tokens will be tracked automatically by llama_decode
     //
     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
     //
     LLAMA_API struct llama_batch llama_batch_get_one(
                   llama_token * tokens,
-                      int32_t   n_tokens,
-                    llama_pos   pos_0,
-                 llama_seq_id   seq_id);
+                      int32_t   n_tokens);
 
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
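
For the common single-sequence case the simplified helper covers the same ground; positions and the sequence id are now handled internally. A hedged usage sketch, again assuming an initialized llama_context * ctx (and <cstdio> for the error message):

    std::vector<llama_token> tmp = { 1, 2, 3, 4 };
    if (llama_decode(ctx, llama_batch_get_one(tmp.data(), (int32_t) tmp.size())) != 0) {
        fprintf(stderr, "llama_decode failed\n");
    }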
@@ -955,12 +950,6 @@ extern "C" {
                            int32_t   lstrip,
                               bool   special);
 
-    // check if token0 is contained as a prefix in token1
-    LLAMA_API bool llama_token_is_prefix(
-              const struct llama_model * model,
-                           llama_token   token0,
-                           llama_token   token1);
-
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
    /// @return Returns the number of chars/bytes on success, no more than text_len_max.
@@ -1083,12 +1072,13 @@ extern "C" {
     // available samplers:
 
-    LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);
+    LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
     LLAMA_API struct llama_sampler * llama_sampler_init_dist  (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
-    LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
+    DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
+        "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
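
Because llama_sampler_init_softmax is now deprecated and, in the same commit, llama_sampler_dist applies the softmax itself before drawing (see the llama-sampling.cpp hunk further down), a sampler chain no longer needs an explicit softmax stage. An illustrative sketch:

    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // softmax is applied internally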
@@ -1104,6 +1094,8 @@ extern "C" {
     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
     LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
 
+    /// @details Updates the logits l_i' = l_i/t. When t <= 0.0f, the maximum logit is kept at its original value, the rest are set to -inf
     LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
 
     /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
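
The new comment pins down the t <= 0.0f case: rather than dividing by a non-positive temperature, the sampler keeps only the maximum logit and sets the rest to -inf, so a zero temperature degenerates to greedy selection. Continuing the chain sketch above (illustrative):

    // temp = 0.0f -> only the argmax token keeps a finite logit; the dist stage then picks it
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.0f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));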

View file

@@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
 }
 */
 
+static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
+    if (temp <= 0.0f) {
+        // find the token with the highest logit and set the rest to -inf
+        size_t max_i = 0;
+        float  max_l = cur_p->data[0].logit;
+
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit > max_l) {
+                cur_p->data[max_i].logit = -INFINITY;
+                max_i = i;
+                max_l = cur_p->data[i].logit;
+            } else {
+                cur_p->data[i].logit = -INFINITY;
+            }
+        }
+
+        return;
+    }
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].logit /= temp;
+    }
+}
+
 static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
     GGML_ASSERT(cur_p->size > 0);
@@ -427,6 +451,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
 }
@@ -912,9 +939,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
 static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     const auto * ctx = (llama_sampler_temp *) smpl->ctx;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        cur_p->data[i].logit /= ctx->temp;
-    }
+
+    llama_sampler_temp_impl(cur_p, ctx->temp);
 }
 
 static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
@@ -961,6 +987,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     if (ctx->delta > 0) {
         const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
         const float max_temp = ctx->temp + ctx->delta;
+
         float exponent_val = ctx->exponent;
 
         // no need to do anything if there is only one (or zero) candidates
@@ -998,9 +1025,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         #endif
 
         // Apply the dynamically calculated temperature scaling
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            cur_p->data[i].logit /= dyn_temp;
-        }
+        llama_sampler_temp_impl(cur_p, dyn_temp);
 
         // Re-compute softmax probabilities after scaling logits with dynamic temperature
         const double max_l_double = cur_p->data[0].logit;
@@ -1024,9 +1049,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         }
         #endif
     } else {
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            cur_p->data[i].logit /= ctx->temp;
-        }
+        llama_sampler_temp_impl(cur_p, ctx->temp);
     }
 }
 
@@ -1745,6 +1768,9 @@ struct llama_sampler * llama_sampler_init_logit_bias(
 struct llama_sampler_infill {
     const struct llama_vocab * vocab;
+
+    std::vector<char> buf0;
+    std::vector<char> buf1;
 };
 
 static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
@@ -1810,27 +1836,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     size_t n_combined = 0; GGML_UNUSED(n_combined);
 
     // combine tokens with common prefix
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        for (size_t j = 0; j < cur_p->size; ++j) {
-            if (cur_p->data[i].logit == -INFINITY) {
+    for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
+        for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
+            if (cur_p->data[i0].logit == -INFINITY) {
                 break;
             }
 
-            if (i == j || cur_p->data[j].logit == -INFINITY) {
+            if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
                 continue;
             }
 
-            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
-                if (cur_p->data[i].p > cur_p->data[j].p) {
-                    cur_p->data[i].p += cur_p->data[j].p;
-                    cur_p->data[j].logit = -INFINITY;
-                    cur_p->data[j].p = 0.0f;
-                } else {
-                    cur_p->data[j].p += cur_p->data[i].p;
-                    cur_p->data[i].logit = -INFINITY;
-                    cur_p->data[i].p = 0.0f;
-                }
+            int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+            if (len0 < 0) {
+                ctx->buf0.resize(len0);
+                len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+                assert(len0 > 0);
+            }
+
+            int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+            if (len1 < 0) {
+                ctx->buf1.resize(len1);
+                len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+                assert(len1 > 0);
+            }
+
+            // token i0 is a prefix of token i1
+            if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
+                int dst = i0;
+                int src = i1;
+
+                // merge into the token with higher probability
+                if (cur_p->data[i1].p > cur_p->data[i0].p) {
+                    std::swap(dst, src);
+                }
+
+                cur_p->data[dst].p += cur_p->data[src].p;
+                cur_p->data[src].logit = -INFINITY;
+                cur_p->data[src].p     = 0.0f;
+
                 n_combined++;
             }
         }
@@ -1936,6 +1979,8 @@ struct llama_sampler * llama_sampler_init_infill_impl(
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
             /* .vocab = */ &vocab,
+            /* .buf0  = */ std::vector<char>(512),
+            /* .buf1  = */ std::vector<char>(512),
         },
     };
 }

View file

@@ -2118,23 +2118,6 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }
 
-bool llama_token_is_prefix_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token0,
-                     llama_token   token1) {
-    char text_buf_0[128];
-    char text_buf_1[128];
-
-    const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
-    const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
-
-    if (len0 <= 0 || len1 <= 0) {
-        return false;
-    }
-
-    return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
-}
-
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
                const llama_token * tokens,

File diff suppressed because it is too large