mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
Alone in the darkness
They're coming for you I know they will try to catch me too Alone in the darkness They're calling for you There's nowhere to run for cover
This commit is contained in:
commit
94a5a27b85
44 changed files with 6803 additions and 2143 deletions
|
@ -1098,7 +1098,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
|
||||
add_opt(common_arg(
|
||||
{"--attention"}, "{causal,non,causal}",
|
||||
{"--attention"}, "{causal,non-causal}",
|
||||
"attention type for embeddings, use model default if unspecified",
|
||||
[](common_params & params, const std::string & value) {
|
||||
/**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
|
||||
|
@ -1696,7 +1696,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_examples({LLAMA_EXAMPLE_BENCH}));
|
||||
add_opt(common_arg(
|
||||
{"--embd-normalize"}, "N",
|
||||
string_format("normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize),
|
||||
string_format("normalisation for embeddings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize),
|
||||
[](common_params & params, int value) {
|
||||
params.embd_normalize = value;
|
||||
}
|
||||
|
@ -1710,7 +1710,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
|
||||
add_opt(common_arg(
|
||||
{"--embd-separator"}, "STRING",
|
||||
"separator of embendings (default \\n) for example \"<#sep#>\"",
|
||||
"separator of embeddings (default \\n) for example \"<#sep#>\"",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.embd_sep = value;
|
||||
}
|
||||
|
|
|
@ -957,7 +957,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
}
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
|
||||
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
|
||||
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
||||
if (decoder_start_token_id == -1) {
|
||||
decoder_start_token_id = bos;
|
||||
|
@ -966,7 +966,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
tmp.push_back(decoder_start_token_id);
|
||||
}
|
||||
if (llama_model_has_decoder(model)) {
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
|
||||
}
|
||||
llama_kv_cache_clear(lctx);
|
||||
llama_synchronize(lctx);
|
||||
|
@ -1037,7 +1037,7 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
|
|||
return GGML_TYPE_Q5_1;
|
||||
}
|
||||
|
||||
throw std::runtime_error("Invalid cache type: " + s);
|
||||
throw std::runtime_error("Unsupported cache type: " + s);
|
||||
}
|
||||
|
||||
struct llama_context_params common_context_params_to_llama(const common_params & params) {
|
||||
|
|
|
@ -270,9 +270,9 @@ struct common_params {
|
|||
|
||||
// embedding
|
||||
bool embedding = false; // get only sentence embedding
|
||||
int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
||||
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
||||
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
|
||||
std::string embd_sep = "\n"; // separator of embendings
|
||||
std::string embd_sep = "\n"; // separator of embeddings
|
||||
bool reranking = false; // enable reranking support on server
|
||||
|
||||
// server params
|
||||
|
|
|
@ -171,7 +171,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
params.penalize_nl,
|
||||
params.ignore_eos));
|
||||
|
||||
if (params.temp > 0.0f) {
|
||||
if (params.mirostat == 0) {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
|
@ -203,7 +202,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
}
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
} else if (params.mirostat == 1) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||
|
@ -214,18 +212,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
} else {
|
||||
GGML_ASSERT(false && "unknown mirostat version");
|
||||
}
|
||||
} else {
|
||||
if (params.n_probs > 0) {
|
||||
// some use cases require to sample greedily, but still obtain the probabilities of the top tokens
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/9605
|
||||
//
|
||||
// the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
|
||||
// it is much faster, since we avoid sorting all tokens and should give a good approximation
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
||||
}
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -2864,6 +2864,9 @@ class Rwkv6Model(Model):
|
|||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
|
||||
special_vocab.chat_template = "rwkv-world"
|
||||
# hack: Add '\n\n' as the EOT token to make it chat normally
|
||||
special_vocab._set_special_token("eot", 261)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
|
|
|
@ -348,6 +348,9 @@ if __name__ == '__main__':
|
|||
if ".base_layer.weight" in name:
|
||||
continue
|
||||
logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
|
||||
if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
|
||||
logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
|
||||
logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()")
|
||||
sys.exit(1)
|
||||
|
||||
if base_name in tensor_map:
|
||||
|
|
|
@ -74,7 +74,6 @@ int main(int argc, char ** argv) {
|
|||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
|
|
|
@ -339,7 +339,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
|
|||
|
||||
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
|
||||
llama_kv_cache_clear(ctx);
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -131,7 +131,7 @@ static bool run(llama_context * ctx, const common_params & params) {
|
|||
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -283,9 +283,6 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
|
|||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
};
|
||||
|
||||
if (embd) {
|
||||
|
|
|
@ -46,7 +46,6 @@ actor LlamaContext {
|
|||
let sparams = llama_sampler_chain_default_params()
|
||||
self.sampling = llama_sampler_chain_init(sparams)
|
||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
|
||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
|
||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
|
||||
}
|
||||
|
||||
|
|
783
examples/llama.vim
Normal file
783
examples/llama.vim
Normal file
|
@ -0,0 +1,783 @@
|
|||
" LLM-based text completion using llama.cpp
|
||||
"
|
||||
" requires:
|
||||
"
|
||||
" - neovim or vim
|
||||
" - curl
|
||||
" - llama.cpp server instance
|
||||
" - FIM-compatible model
|
||||
"
|
||||
" sample config:
|
||||
"
|
||||
" - Tab - accept the current suggestion
|
||||
" - Shift+Tab - accept just the first line of the suggestion
|
||||
" - Ctrl+F - toggle FIM completion manually
|
||||
"
|
||||
" make symlink or copy this file to ~/.config/nvim/autoload/llama.vim
|
||||
"
|
||||
" start the llama.cpp server with a FIM-compatible model. for example:
|
||||
"
|
||||
" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa -dt 0.1 --ubatch-size 512 --batch-size 1024 --cache-reuse 256
|
||||
"
|
||||
" --batch-size [512, model max context]
|
||||
"
|
||||
" adjust the batch size to control how much of the provided local context will be used during the inference
|
||||
" lower values will use smaller part of the context around the cursor, which will result in faster processing
|
||||
"
|
||||
" --ubatch-size [64, 2048]
|
||||
"
|
||||
" chunks the batch into smaller chunks for faster processing
|
||||
" depends on the specific hardware. use llama-bench to profile and determine the best size
|
||||
"
|
||||
" --cache-reuse (ge:llama_config.n_predict, 1024]
|
||||
"
|
||||
" this should be either 0 (disabled) or strictly larger than g:llama_config.n_predict
|
||||
" using non-zero value enables context reuse on the server side which dramatically improves the performance at
|
||||
" large contexts. a value of 256 should be good for all cases
|
||||
"
|
||||
" run this once to initialise llama.vim:
|
||||
"
|
||||
" :call llama#init()
|
||||
"
|
||||
" more info: https://github.com/ggerganov/llama.cpp/pull/9787
|
||||
"
|
||||
|
||||
" colors (adjust to your liking)
|
||||
highlight llama_hl_hint guifg=#ff772f ctermfg=202
|
||||
highlight llama_hl_info guifg=#77ff2f ctermfg=119
|
||||
|
||||
" general parameters:
|
||||
"
|
||||
" endpoint: llama.cpp server endpoint
|
||||
" n_prefix: number of lines before the cursor location to include in the local prefix
|
||||
" n_suffix: number of lines after the cursor location to include in the local suffix
|
||||
" n_predict: max number of tokens to predict
|
||||
" t_max_prompt_ms: max alloted time for the prompt processing (TODO: not yet supported)
|
||||
" t_max_predict_ms: max alloted time for the prediction
|
||||
" show_info: show extra info about the inference (0 - disabled, 1 - statusline, 2 - inline)
|
||||
" auto_fim: trigger FIM completion automatically on cursor movement
|
||||
" max_line_suffix: do not auto-trigger FIM completion if there are more than this number of characters to the right of the cursor
|
||||
"
|
||||
" ring buffer of chunks, accumulated with time upon:
|
||||
"
|
||||
" - completion request
|
||||
" - yank
|
||||
" - entering a buffer
|
||||
" - leaving a buffer
|
||||
" - writing a file
|
||||
"
|
||||
" parameters for the ring-buffer with extra context:
|
||||
"
|
||||
" ring_n_chunks: max number of chunks to pass as extra context to the server (0 to disable)
|
||||
" ring_chunk_size: max size of the chunks (in number of lines)
|
||||
" note: adjust these numbers so that you don't overrun your context
|
||||
" at ring_n_chunks = 64 and ring_chunk_size = 64 you need ~32k context
|
||||
" ring_scope: the range around the cursor position (in number of lines) for gathering chunks after FIM
|
||||
" ring_update_ms: how often to process queued chunks in normal mode
|
||||
"
|
||||
let s:default_config = {
|
||||
\ 'endpoint': 'http://127.0.0.1:8012/infill',
|
||||
\ 'n_prefix': 256,
|
||||
\ 'n_suffix': 64,
|
||||
\ 'n_predict': 128,
|
||||
\ 't_max_prompt_ms': 500,
|
||||
\ 't_max_predict_ms': 3000,
|
||||
\ 'show_info': 2,
|
||||
\ 'auto_fim': v:true,
|
||||
\ 'max_line_suffix': 8,
|
||||
\ 'ring_n_chunks': 64,
|
||||
\ 'ring_chunk_size': 64,
|
||||
\ 'ring_scope': 1024,
|
||||
\ 'ring_update_ms': 1000,
|
||||
\ }
|
||||
|
||||
let g:llama_config = get(g:, 'llama_config', s:default_config)
|
||||
|
||||
function! s:get_indent(str)
|
||||
let l:count = 0
|
||||
for i in range(len(a:str))
|
||||
if a:str[i] == "\t"
|
||||
let l:count += &tabstop - 1
|
||||
else
|
||||
break
|
||||
endif
|
||||
endfor
|
||||
return l:count
|
||||
endfunction
|
||||
|
||||
function! s:rand(i0, i1) abort
|
||||
return a:i0 + rand() % (a:i1 - a:i0 + 1)
|
||||
endfunction
|
||||
|
||||
function! llama#init()
|
||||
if !executable('curl')
|
||||
echohl WarningMsg
|
||||
echo 'llama.vim requires the "curl" command to be available'
|
||||
echohl None
|
||||
return
|
||||
endif
|
||||
|
||||
let s:pos_x = 0 " cursor position upon start of completion
|
||||
let s:pos_y = 0
|
||||
|
||||
let s:line_cur = ''
|
||||
|
||||
let s:line_cur_prefix = ''
|
||||
let s:line_cur_suffix = ''
|
||||
|
||||
let s:ring_chunks = [] " current set of chunks used as extra context
|
||||
let s:ring_queued = [] " chunks that are queued to be sent for processing
|
||||
let s:ring_n_evict = 0
|
||||
|
||||
let s:hint_shown = v:false
|
||||
let s:pos_y_pick = -9999 " last y where we picked a chunk
|
||||
let s:pos_dx = 0
|
||||
let s:content = []
|
||||
let s:can_accept = v:false
|
||||
|
||||
let s:timer_fim = -1
|
||||
let s:t_fim_start = reltime() " used to measure total FIM time
|
||||
let s:t_last_move = reltime() " last time the cursor moved
|
||||
|
||||
let s:current_job = v:null
|
||||
|
||||
let s:ghost_text_nvim = exists('*nvim_buf_get_mark')
|
||||
let s:ghost_text_vim = has('textprop')
|
||||
|
||||
if s:ghost_text_vim
|
||||
let s:hlgroup_hint = 'llama_hl_hint'
|
||||
let s:hlgroup_info = 'llama_hl_info'
|
||||
|
||||
if empty(prop_type_get(s:hlgroup_hint))
|
||||
call prop_type_add(s:hlgroup_hint, {'highlight': s:hlgroup_hint})
|
||||
endif
|
||||
if empty(prop_type_get(s:hlgroup_info))
|
||||
call prop_type_add(s:hlgroup_info, {'highlight': s:hlgroup_info})
|
||||
endif
|
||||
endif
|
||||
|
||||
augroup llama
|
||||
autocmd!
|
||||
autocmd InsertEnter * inoremap <expr> <silent> <C-F> llama#fim_inline(v:false)
|
||||
autocmd InsertLeavePre * call llama#fim_cancel()
|
||||
|
||||
autocmd CursorMoved * call s:on_move()
|
||||
autocmd CursorMovedI * call s:on_move()
|
||||
autocmd CompleteChanged * call llama#fim_cancel()
|
||||
|
||||
if g:llama_config.auto_fim
|
||||
autocmd CursorMovedI * call llama#fim(v:true)
|
||||
endif
|
||||
|
||||
" gather chunks upon yanking
|
||||
autocmd TextYankPost * if v:event.operator ==# 'y' | call s:pick_chunk(v:event.regcontents, v:false, v:true) | endif
|
||||
|
||||
" gather chunks upon entering/leaving a buffer
|
||||
autocmd BufEnter * call timer_start(100, {-> s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)})
|
||||
autocmd BufLeave * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)
|
||||
|
||||
" gather chunk upon saving the file
|
||||
autocmd BufWritePost * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)
|
||||
augroup END
|
||||
|
||||
silent! call llama#fim_cancel()
|
||||
|
||||
" init background update of the ring buffer
|
||||
if g:llama_config.ring_n_chunks > 0
|
||||
call s:ring_update()
|
||||
endif
|
||||
endfunction
|
||||
|
||||
" compute how similar two chunks of text are
|
||||
" 0 - no similarity, 1 - high similarity
|
||||
" TODO: figure out something better
|
||||
function! s:chunk_sim(c0, c1)
|
||||
let l:lines0 = len(a:c0)
|
||||
let l:lines1 = len(a:c1)
|
||||
|
||||
let l:common = 0
|
||||
|
||||
for l:line0 in a:c0
|
||||
for l:line1 in a:c1
|
||||
if l:line0 == l:line1
|
||||
let l:common += 1
|
||||
break
|
||||
endif
|
||||
endfor
|
||||
endfor
|
||||
|
||||
return 2.0 * l:common / (l:lines0 + l:lines1)
|
||||
endfunction
|
||||
|
||||
" pick a random chunk of size g:llama_config.ring_chunk_size from the provided text and queue it for processing
|
||||
"
|
||||
" no_mod - do not pick chunks from buffers with pending changes
|
||||
" do_evict - evict chunks that are very similar to the new one
|
||||
"
|
||||
function! s:pick_chunk(text, no_mod, do_evict)
|
||||
" do not pick chunks from buffers with pending changes or buffers that are not files
|
||||
if a:no_mod && (getbufvar(bufnr('%'), '&modified') || !buflisted(bufnr('%')) || !filereadable(expand('%')))
|
||||
return
|
||||
endif
|
||||
|
||||
" if the extra context option is disabled - do nothing
|
||||
if g:llama_config.ring_n_chunks <= 0
|
||||
return
|
||||
endif
|
||||
|
||||
" don't pick very small chunks
|
||||
if len(a:text) < 3
|
||||
return
|
||||
endif
|
||||
|
||||
if len(a:text) + 1 < g:llama_config.ring_chunk_size
|
||||
let l:chunk = a:text
|
||||
else
|
||||
let l:l0 = s:rand(0, max([0, len(a:text) - g:llama_config.ring_chunk_size/2]))
|
||||
let l:l1 = min([l:l0 + g:llama_config.ring_chunk_size/2, len(a:text)])
|
||||
|
||||
let l:chunk = a:text[l:l0:l:l1]
|
||||
endif
|
||||
|
||||
let l:chunk_str = join(l:chunk, "\n") . "\n"
|
||||
|
||||
" check if this chunk is already added
|
||||
let l:exist = v:false
|
||||
|
||||
for i in range(len(s:ring_chunks))
|
||||
if s:ring_chunks[i].data == l:chunk
|
||||
let l:exist = v:true
|
||||
break
|
||||
endif
|
||||
endfor
|
||||
|
||||
for i in range(len(s:ring_queued))
|
||||
if s:ring_queued[i].data == l:chunk
|
||||
let l:exist = v:true
|
||||
break
|
||||
endif
|
||||
endfor
|
||||
|
||||
if l:exist
|
||||
return
|
||||
endif
|
||||
|
||||
" evict queued chunks that are very similar to the new one
|
||||
for i in range(len(s:ring_queued) - 1, 0, -1)
|
||||
if s:chunk_sim(s:ring_queued[i].data, l:chunk) > 0.9
|
||||
if a:do_evict
|
||||
call remove(s:ring_queued, i)
|
||||
let s:ring_n_evict += 1
|
||||
else
|
||||
return
|
||||
endif
|
||||
endif
|
||||
endfor
|
||||
|
||||
" also from s:ring_chunks
|
||||
for i in range(len(s:ring_chunks) - 1, 0, -1)
|
||||
if s:chunk_sim(s:ring_chunks[i].data, l:chunk) > 0.9
|
||||
if a:do_evict
|
||||
call remove(s:ring_chunks, i)
|
||||
let s:ring_n_evict += 1
|
||||
else
|
||||
return
|
||||
endif
|
||||
endif
|
||||
endfor
|
||||
|
||||
" TODO: become parameter ?
|
||||
if len(s:ring_queued) == 16
|
||||
call remove(s:ring_queued, 0)
|
||||
endif
|
||||
|
||||
call add(s:ring_queued, {'data': l:chunk, 'str': l:chunk_str, 'time': reltime(), 'filename': expand('%')})
|
||||
|
||||
"let &statusline = 'extra context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued)
|
||||
endfunction
|
||||
|
||||
" picks a queued chunk, sends it for processing and adds it to s:ring_chunks
|
||||
" called every g:llama_config.ring_update_ms
|
||||
function! s:ring_update()
|
||||
call timer_start(g:llama_config.ring_update_ms, {-> s:ring_update()})
|
||||
|
||||
" update only if in normal mode or if the cursor hasn't moved for a while
|
||||
if mode() !=# 'n' && reltimefloat(reltime(s:t_last_move)) < 3.0
|
||||
return
|
||||
endif
|
||||
|
||||
if len(s:ring_queued) == 0
|
||||
return
|
||||
endif
|
||||
|
||||
" move the first queued chunk to the ring buffer
|
||||
if len(s:ring_chunks) == g:llama_config.ring_n_chunks
|
||||
call remove(s:ring_chunks, 0)
|
||||
endif
|
||||
|
||||
call add(s:ring_chunks, remove(s:ring_queued, 0))
|
||||
|
||||
"let &statusline = 'updated context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued)
|
||||
|
||||
" send asynchronous job with the new extra context so that it is ready for the next FIM
|
||||
let l:extra_context = []
|
||||
for l:chunk in s:ring_chunks
|
||||
call add(l:extra_context, {
|
||||
\ 'text': l:chunk.str,
|
||||
\ 'time': l:chunk.time,
|
||||
\ 'filename': l:chunk.filename
|
||||
\ })
|
||||
endfor
|
||||
|
||||
" no samplers needed here
|
||||
let l:request = json_encode({
|
||||
\ 'input_prefix': "",
|
||||
\ 'input_suffix': "",
|
||||
\ 'input_extra': l:extra_context,
|
||||
\ 'prompt': "",
|
||||
\ 'n_predict': 1,
|
||||
\ 'temperature': 0.0,
|
||||
\ 'stream': v:false,
|
||||
\ 'samplers': ["temperature"],
|
||||
\ 'cache_prompt': v:true,
|
||||
\ 't_max_prompt_ms': 1,
|
||||
\ 't_max_predict_ms': 1
|
||||
\ })
|
||||
|
||||
let l:curl_command = [
|
||||
\ "curl",
|
||||
\ "--silent",
|
||||
\ "--no-buffer",
|
||||
\ "--request", "POST",
|
||||
\ "--url", g:llama_config.endpoint,
|
||||
\ "--header", "Content-Type: application/json",
|
||||
\ "--data", l:request
|
||||
\ ]
|
||||
|
||||
" no callbacks because we don't need to process the response
|
||||
if s:ghost_text_nvim
|
||||
call jobstart(l:curl_command, {})
|
||||
elseif s:ghost_text_vim
|
||||
call job_start(l:curl_command, {})
|
||||
endif
|
||||
endfunction
|
||||
|
||||
" necessary for 'inoremap <expr>'
|
||||
function! llama#fim_inline(is_auto) abort
|
||||
call llama#fim(a:is_auto)
|
||||
return ''
|
||||
endfunction
|
||||
|
||||
" the main FIM call
|
||||
" takes local context around the cursor and sends it together with the extra context to the server for completion
|
||||
function! llama#fim(is_auto) abort
|
||||
" we already have a suggestion for the current cursor position
|
||||
if s:hint_shown && !a:is_auto
|
||||
call llama#fim_cancel()
|
||||
return
|
||||
endif
|
||||
|
||||
call llama#fim_cancel()
|
||||
|
||||
" avoid sending repeated requests too fast
|
||||
if reltimefloat(reltime(s:t_fim_start)) < 0.6
|
||||
if s:timer_fim != -1
|
||||
call timer_stop(s:timer_fim)
|
||||
let s:timer_fim = -1
|
||||
endif
|
||||
|
||||
let s:t_fim_start = reltime()
|
||||
let s:timer_fim = timer_start(600, {-> llama#fim(v:true)})
|
||||
return
|
||||
endif
|
||||
|
||||
let s:t_fim_start = reltime()
|
||||
|
||||
let s:content = []
|
||||
let s:can_accept = v:false
|
||||
|
||||
let s:pos_x = col('.') - 1
|
||||
let s:pos_y = line('.')
|
||||
let l:max_y = line('$')
|
||||
|
||||
let l:lines_prefix = getline(max([1, s:pos_y - g:llama_config.n_prefix]), s:pos_y - 1)
|
||||
let l:lines_suffix = getline(s:pos_y + 1, min([l:max_y, s:pos_y + g:llama_config.n_suffix]))
|
||||
|
||||
let s:line_cur = getline('.')
|
||||
|
||||
let s:line_cur_prefix = strpart(s:line_cur, 0, s:pos_x)
|
||||
let s:line_cur_suffix = strpart(s:line_cur, s:pos_x)
|
||||
|
||||
if a:is_auto && len(s:line_cur_suffix) > g:llama_config.max_line_suffix
|
||||
return
|
||||
endif
|
||||
|
||||
let l:prefix = ""
|
||||
\ . join(l:lines_prefix, "\n")
|
||||
\ . "\n"
|
||||
|
||||
let l:prompt = ""
|
||||
\ . s:line_cur_prefix
|
||||
|
||||
let l:suffix = ""
|
||||
\ . s:line_cur_suffix
|
||||
\ . "\n"
|
||||
\ . join(l:lines_suffix, "\n")
|
||||
\ . "\n"
|
||||
|
||||
" prepare the extra context data
|
||||
let l:extra_context = []
|
||||
for l:chunk in s:ring_chunks
|
||||
call add(l:extra_context, {
|
||||
\ 'text': l:chunk.str,
|
||||
\ 'time': l:chunk.time,
|
||||
\ 'filename': l:chunk.filename
|
||||
\ })
|
||||
endfor
|
||||
|
||||
" the indentation of the current line
|
||||
let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*'))
|
||||
|
||||
let l:request = json_encode({
|
||||
\ 'input_prefix': l:prefix,
|
||||
\ 'input_suffix': l:suffix,
|
||||
\ 'input_extra': l:extra_context,
|
||||
\ 'prompt': l:prompt,
|
||||
\ 'n_predict': g:llama_config.n_predict,
|
||||
\ 'n_indent': l:indent,
|
||||
\ 'top_k': 40,
|
||||
\ 'top_p': 0.99,
|
||||
\ 'stream': v:false,
|
||||
\ 'samplers': ["top_k", "top_p", "infill"],
|
||||
\ 'cache_prompt': v:true,
|
||||
\ 't_max_prompt_ms': g:llama_config.t_max_prompt_ms,
|
||||
\ 't_max_predict_ms': g:llama_config.t_max_predict_ms
|
||||
\ })
|
||||
|
||||
let l:curl_command = [
|
||||
\ "curl",
|
||||
\ "--silent",
|
||||
\ "--no-buffer",
|
||||
\ "--request", "POST",
|
||||
\ "--url", g:llama_config.endpoint,
|
||||
\ "--header", "Content-Type: application/json",
|
||||
\ "--data", l:request
|
||||
\ ]
|
||||
|
||||
if s:current_job != v:null
|
||||
if s:ghost_text_nvim
|
||||
call jobstop(s:current_job)
|
||||
elseif s:ghost_text_vim
|
||||
call job_stop(s:current_job)
|
||||
endif
|
||||
endif
|
||||
|
||||
" send the request asynchronously
|
||||
if s:ghost_text_nvim
|
||||
let s:current_job = jobstart(l:curl_command, {
|
||||
\ 'on_stdout': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]),
|
||||
\ 'on_exit': function('s:fim_on_exit'),
|
||||
\ 'stdout_buffered': v:true
|
||||
\ })
|
||||
elseif s:ghost_text_vim
|
||||
let s:current_job = job_start(l:curl_command, {
|
||||
\ 'out_cb': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]),
|
||||
\ 'exit_cb': function('s:fim_on_exit')
|
||||
\ })
|
||||
endif
|
||||
|
||||
" TODO: per-file location
|
||||
let l:delta_y = abs(s:pos_y - s:pos_y_pick)
|
||||
|
||||
" gather some extra context nearby and process it in the background
|
||||
" only gather chunks if the cursor has moved a lot
|
||||
" TODO: something more clever? reranking?
|
||||
if a:is_auto && l:delta_y > 32
|
||||
" expand the prefix even further
|
||||
call s:pick_chunk(getline(max([1, s:pos_y - g:llama_config.ring_scope]), max([1, s:pos_y - g:llama_config.n_prefix])), v:false, v:false)
|
||||
|
||||
" pick a suffix chunk
|
||||
call s:pick_chunk(getline(min([l:max_y, s:pos_y + g:llama_config.n_suffix]), min([l:max_y, s:pos_y + g:llama_config.n_suffix + g:llama_config.ring_chunk_size])), v:false, v:false)
|
||||
|
||||
let s:pos_y_pick = s:pos_y
|
||||
endif
|
||||
endfunction
|
||||
|
||||
" if first_line == v:true accept only the first line of the response
|
||||
function! llama#fim_accept(first_line)
|
||||
" insert the suggestion at the cursor location
|
||||
if s:can_accept && len(s:content) > 0
|
||||
call setline(s:pos_y, s:line_cur[:(s:pos_x - 1)] . s:content[0])
|
||||
if len(s:content) > 1
|
||||
if !a:first_line
|
||||
call append(s:pos_y, s:content[1:-1])
|
||||
endif
|
||||
endif
|
||||
|
||||
" move the cursor to the end of the accepted text
|
||||
if !a:first_line && len(s:content) > 1
|
||||
call cursor(s:pos_y + len(s:content) - 1, s:pos_x + s:pos_dx + 1)
|
||||
else
|
||||
call cursor(s:pos_y, s:pos_x + len(s:content[0]))
|
||||
endif
|
||||
endif
|
||||
|
||||
call llama#fim_cancel()
|
||||
endfunction
|
||||
|
||||
function! llama#fim_cancel()
|
||||
let s:hint_shown = v:false
|
||||
|
||||
" clear the virtual text
|
||||
let l:bufnr = bufnr('%')
|
||||
|
||||
if s:ghost_text_nvim
|
||||
let l:id_vt_fim = nvim_create_namespace('vt_fim')
|
||||
call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1)
|
||||
elseif s:ghost_text_vim
|
||||
call prop_remove({'type': s:hlgroup_hint, 'all': v:true})
|
||||
call prop_remove({'type': s:hlgroup_info, 'all': v:true})
|
||||
endif
|
||||
|
||||
" remove the mappings
|
||||
silent! iunmap <buffer> <Tab>
|
||||
silent! iunmap <buffer> <S-Tab>
|
||||
silent! iunmap <buffer> <Esc>
|
||||
endfunction
|
||||
|
||||
function! s:on_move()
|
||||
let s:t_last_move = reltime()
|
||||
|
||||
call llama#fim_cancel()
|
||||
endfunction
|
||||
|
||||
" callback that processes the FIM result from the server and displays the suggestion
|
||||
function! s:fim_on_stdout(pos_x, pos_y, is_auto, job_id, data, event = v:null)
|
||||
if s:ghost_text_nvim
|
||||
let l:raw = join(a:data, "\n")
|
||||
elseif s:ghost_text_vim
|
||||
let l:raw = a:data
|
||||
endif
|
||||
|
||||
if len(l:raw) == 0
|
||||
return
|
||||
endif
|
||||
|
||||
if a:pos_x != col('.') - 1 || a:pos_y != line('.')
|
||||
return
|
||||
endif
|
||||
|
||||
" show the suggestion only in insert mode
|
||||
if mode() !=# 'i'
|
||||
return
|
||||
endif
|
||||
|
||||
let s:pos_x = a:pos_x
|
||||
let s:pos_y = a:pos_y
|
||||
|
||||
let s:can_accept = v:true
|
||||
let l:has_info = v:false
|
||||
|
||||
if s:can_accept && v:shell_error
|
||||
if !a:is_auto
|
||||
call add(s:content, "<| curl error: is the server on? |>")
|
||||
endif
|
||||
let s:can_accept = v:false
|
||||
endif
|
||||
|
||||
let l:n_prompt = 0
|
||||
let l:t_prompt_ms = 1.0
|
||||
let l:s_prompt = 0
|
||||
|
||||
let l:n_predict = 0
|
||||
let l:t_predict_ms = 1.0
|
||||
let l:s_predict = 0
|
||||
|
||||
" get the generated suggestion
|
||||
if s:can_accept
|
||||
let l:response = json_decode(l:raw)
|
||||
|
||||
for l:part in split(get(l:response, 'content', ''), "\n", 1)
|
||||
call add(s:content, l:part)
|
||||
endfor
|
||||
|
||||
" remove trailing new lines
|
||||
while len(s:content) > 0 && s:content[-1] == ""
|
||||
call remove(s:content, -1)
|
||||
endwhile
|
||||
|
||||
let l:generation_settings = get(l:response, 'generation_settings', {})
|
||||
let l:n_ctx = get(l:generation_settings, 'n_ctx', 0)
|
||||
|
||||
let l:n_cached = get(l:response, 'tokens_cached', 0)
|
||||
let l:truncated = get(l:response, 'truncated', v:false)
|
||||
|
||||
" if response.timings is available
|
||||
if len(get(l:response, 'timings', {})) > 0
|
||||
let l:has_info = v:true
|
||||
let l:timings = get(l:response, 'timings', {})
|
||||
|
||||
let l:n_prompt = get(l:timings, 'prompt_n', 0)
|
||||
let l:t_prompt_ms = get(l:timings, 'prompt_ms', 1)
|
||||
let l:s_prompt = get(l:timings, 'prompt_per_second', 0)
|
||||
|
||||
let l:n_predict = get(l:timings, 'predicted_n', 0)
|
||||
let l:t_predict_ms = get(l:timings, 'predicted_ms', 1)
|
||||
let l:s_predict = get(l:timings, 'predicted_per_second', 0)
|
||||
endif
|
||||
endif
|
||||
|
||||
if len(s:content) == 0
|
||||
call add(s:content, "")
|
||||
let s:can_accept = v:false
|
||||
endif
|
||||
|
||||
if len(s:content) == 0
|
||||
return
|
||||
endif
|
||||
|
||||
" NOTE: the following is logic for discarding predictions that repeat existing text
|
||||
" the code is quite ugly and there is very likely a simpler and more canonical way to implement this
|
||||
"
|
||||
" still, I wonder if there is some better way that avoids having to do these special hacks?
|
||||
" on one hand, the LLM 'sees' the contents of the file before we start editing, so it is normal that it would
|
||||
" start generating whatever we have given it via the extra context. but on the other hand, it's not very
|
||||
" helpful to re-generate the same code that is already there
|
||||
|
||||
" truncate the suggestion if the first line is empty
|
||||
if len(s:content) == 1 && s:content[0] == ""
|
||||
let s:content = [""]
|
||||
endif
|
||||
|
||||
" ... and the next lines are repeated
|
||||
if len(s:content) > 1 && s:content[0] == "" && s:content[1:] == getline(s:pos_y + 1, s:pos_y + len(s:content) - 1)
|
||||
let s:content = [""]
|
||||
endif
|
||||
|
||||
" truncate the suggestion if it repeats the suffix
|
||||
if len(s:content) == 1 && s:content[0] == s:line_cur_suffix
|
||||
let s:content = [""]
|
||||
endif
|
||||
|
||||
" find the first non-empty line (strip whitespace)
|
||||
let l:cmp_y = s:pos_y + 1
|
||||
while l:cmp_y < line('$') && getline(l:cmp_y) =~? '^\s*$'
|
||||
let l:cmp_y += 1
|
||||
endwhile
|
||||
|
||||
if (s:line_cur_prefix . s:content[0]) == getline(l:cmp_y)
|
||||
" truncate the suggestion if it repeats the next line
|
||||
if len(s:content) == 1
|
||||
let s:content = [""]
|
||||
endif
|
||||
|
||||
" ... or if the second line of the suggestion is the prefix of line l:cmp_y + 1
|
||||
if len(s:content) == 2 && s:content[-1] == getline(l:cmp_y + 1)[:len(s:content[-1]) - 1]
|
||||
let s:content = [""]
|
||||
endif
|
||||
|
||||
" ... or if the middle chunk of lines of the suggestion is the same as [l:cmp_y + 1, l:cmp_y + len(s:content) - 1)
|
||||
if len(s:content) > 2 && join(s:content[1:-1], "\n") == join(getline(l:cmp_y + 1, l:cmp_y + len(s:content) - 1), "\n")
|
||||
let s:content = [""]
|
||||
endif
|
||||
endif
|
||||
|
||||
" keep only lines that have the same or larger whitespace prefix as s:line_cur_prefix
|
||||
"let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*'))
|
||||
"for i in range(1, len(s:content) - 1)
|
||||
" if strlen(matchstr(s:content[i], '^\s*')) < l:indent
|
||||
" let s:content = s:content[:i - 1]
|
||||
" break
|
||||
" endif
|
||||
"endfor
|
||||
|
||||
let s:pos_dx = len(s:content[-1])
|
||||
|
||||
let s:content[-1] .= s:line_cur_suffix
|
||||
|
||||
call llama#fim_cancel()
|
||||
|
||||
" display virtual text with the suggestion
|
||||
let l:bufnr = bufnr('%')
|
||||
|
||||
if s:ghost_text_nvim
|
||||
let l:id_vt_fim = nvim_create_namespace('vt_fim')
|
||||
endif
|
||||
|
||||
" construct the info message
|
||||
if g:llama_config.show_info > 0 && l:has_info
|
||||
let l:prefix = ' '
|
||||
|
||||
if l:truncated
|
||||
let l:info = printf("%s | WARNING: the context is full: %d / %d, increase the server context size or reduce g:llama_config.ring_n_chunks",
|
||||
\ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim',
|
||||
\ l:n_cached, l:n_ctx
|
||||
\ )
|
||||
else
|
||||
let l:info = printf("%s | c: %d / %d, r: %d / %d, e: %d, q: %d / 16 | p: %d (%.2f ms, %.2f t/s) | g: %d (%.2f ms, %.2f t/s) | t: %.2f ms",
|
||||
\ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim',
|
||||
\ l:n_cached, l:n_ctx, len(s:ring_chunks), g:llama_config.ring_n_chunks, s:ring_n_evict, len(s:ring_queued),
|
||||
\ l:n_prompt, l:t_prompt_ms, l:s_prompt,
|
||||
\ l:n_predict, l:t_predict_ms, l:s_predict,
|
||||
\ 1000.0 * reltimefloat(reltime(s:t_fim_start))
|
||||
\ )
|
||||
endif
|
||||
|
||||
if g:llama_config.show_info == 1
|
||||
" display the info in the statusline
|
||||
let &statusline = l:info
|
||||
let l:info = ''
|
||||
endif
|
||||
endif
|
||||
|
||||
" display the suggestion and append the info to the end of the first line
|
||||
if s:ghost_text_nvim
|
||||
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, {
|
||||
\ 'virt_text': [[s:content[0], 'llama_hl_hint'], [l:info, 'llama_hl_info']],
|
||||
\ 'virt_text_win_col': virtcol('.') - 1
|
||||
\ })
|
||||
|
||||
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, 0, {
|
||||
\ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hl_hint']]}),
|
||||
\ 'virt_text_win_col': virtcol('.')
|
||||
\ })
|
||||
elseif s:ghost_text_vim
|
||||
let l:new_suffix = s:content[0]
|
||||
if !empty(l:new_suffix)
|
||||
call prop_add(s:pos_y, s:pos_x + 1, {
|
||||
\ 'type': s:hlgroup_hint,
|
||||
\ 'text': l:new_suffix
|
||||
\ })
|
||||
endif
|
||||
for line in s:content[1:]
|
||||
call prop_add(s:pos_y, 0, {
|
||||
\ 'type': s:hlgroup_hint,
|
||||
\ 'text': line,
|
||||
\ 'text_padding_left': s:get_indent(line),
|
||||
\ 'text_align': 'below'
|
||||
\ })
|
||||
endfor
|
||||
if !empty(l:info)
|
||||
call prop_add(s:pos_y, 0, {
|
||||
\ 'type': s:hlgroup_info,
|
||||
\ 'text': l:info,
|
||||
\ 'text_padding_left': col('$'),
|
||||
\ 'text_wrap': 'truncate'
|
||||
\ })
|
||||
endif
|
||||
endif
|
||||
|
||||
" setup accept shortcuts
|
||||
inoremap <buffer> <Tab> <C-O>:call llama#fim_accept(v:false)<CR>
|
||||
inoremap <buffer> <S-Tab> <C-O>:call llama#fim_accept(v:true)<CR>
|
||||
|
||||
let s:hint_shown = v:true
|
||||
endfunction
|
||||
|
||||
function! s:fim_on_exit(job_id, exit_code, event = v:null)
|
||||
if a:exit_code != 0
|
||||
echom "Job failed with exit code: " . a:exit_code
|
||||
endif
|
||||
|
||||
let s:current_job = v:null
|
||||
endfunction
|
|
@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
|
|||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
|
||||
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
|
||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -401,6 +401,39 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
|
|||
return true;
|
||||
}
|
||||
|
||||
struct llava_embd_batch {
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id> seq_id_0;
|
||||
std::vector<llama_seq_id *> seq_ids;
|
||||
std::vector<int8_t> logits;
|
||||
llama_batch batch;
|
||||
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
||||
pos .resize(n_tokens);
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_ids .resize(n_tokens + 1);
|
||||
logits .resize(n_tokens);
|
||||
seq_id_0.resize(1);
|
||||
seq_id_0[0] = seq_id;
|
||||
seq_ids [n_tokens] = nullptr;
|
||||
batch = {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ embd,
|
||||
/*pos =*/ pos.data(),
|
||||
/*n_seq_id =*/ n_seq_id.data(),
|
||||
/*seq_id =*/ seq_ids.data(),
|
||||
/*logits =*/ logits.data(),
|
||||
};
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
batch.pos [i] = pos_0 + i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id [i] = seq_id_0.data();
|
||||
batch.logits [i] = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
|
||||
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||
|
||||
|
@ -409,8 +442,9 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
|
|||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
float * embd = image_embed->embed+i*n_embd;
|
||||
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
|
||||
if (llama_decode(ctx_llama, llava_batch.batch)) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -97,7 +97,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
|
|||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
|
||||
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
|
||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -89,8 +89,8 @@ int main(int argc, char ** argv) {
|
|||
const auto t_enc_start = ggml_time_us();
|
||||
|
||||
// eval the prompt
|
||||
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
|
||||
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
|
||||
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
|
||||
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
|
||||
|
||||
for (int s = 1; s < W + G + 1; ++s) {
|
||||
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
|
||||
|
|
|
@ -89,8 +89,8 @@ int main(int argc, char ** argv){
|
|||
|
||||
const auto t_enc_start = ggml_time_us();
|
||||
|
||||
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
|
||||
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
|
||||
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
|
||||
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
|
||||
|
||||
const auto t_enc_end = ggml_time_us();
|
||||
|
||||
|
|
|
@ -529,7 +529,7 @@ int main(int argc, char ** argv) {
|
|||
int enc_input_size = embd_inp.size();
|
||||
llama_token * enc_input_buf = embd_inp.data();
|
||||
|
||||
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
|
||||
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
@ -649,7 +649,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
|
|
@ -132,6 +132,7 @@ struct slot_params {
|
|||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
|
||||
|
||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
@ -174,6 +175,8 @@ struct server_slot {
|
|||
std::vector<llama_token> prompt_tokens;
|
||||
std::vector<llama_token> extra_tokens;
|
||||
|
||||
size_t last_nl_pos = 0;
|
||||
|
||||
std::string generated_text;
|
||||
std::vector<llama_token> cache_tokens;
|
||||
std::vector<completion_token_output> generated_token_probs;
|
||||
|
@ -216,6 +219,7 @@ struct server_slot {
|
|||
SLT_DBG(*this, "%s", "\n");
|
||||
|
||||
n_prompt_tokens = 0;
|
||||
last_nl_pos = 0;
|
||||
generated_text = "";
|
||||
has_new_line = false;
|
||||
truncated = false;
|
||||
|
@ -861,6 +865,7 @@ struct server_context {
|
|||
slot.params.stream = json_value(data, "stream", false);
|
||||
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
|
||||
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
|
||||
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
|
||||
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
||||
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
||||
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||
|
@ -879,7 +884,7 @@ struct server_context {
|
|||
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
|
||||
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
|
||||
slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
|
||||
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
|
||||
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
|
||||
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
|
@ -1130,15 +1135,50 @@ struct server_context {
|
|||
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
|
||||
}
|
||||
|
||||
if (slot.has_new_line) {
|
||||
// if we have already seen a new line, we stop after a certain time limit
|
||||
if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
|
||||
(ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
|
||||
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
|
||||
slot.stopped_limit = true;
|
||||
slot.has_next_token = false;
|
||||
|
||||
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
|
||||
}
|
||||
|
||||
// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
|
||||
if (slot.params.n_indent > 0) {
|
||||
// check the current indentation
|
||||
// TODO: improve by not doing it more than once for each new line
|
||||
if (slot.last_nl_pos > 0) {
|
||||
size_t pos = slot.last_nl_pos;
|
||||
|
||||
int n_indent = 0;
|
||||
while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
|
||||
n_indent++;
|
||||
pos++;
|
||||
}
|
||||
|
||||
if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
|
||||
slot.stopped_limit = true;
|
||||
slot.has_next_token = false;
|
||||
|
||||
// cut the last line
|
||||
slot.generated_text.erase(pos, std::string::npos);
|
||||
|
||||
SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
|
||||
}
|
||||
}
|
||||
|
||||
// find the next new line
|
||||
{
|
||||
const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
|
||||
|
||||
if (pos != std::string::npos) {
|
||||
slot.last_nl_pos = pos + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check if there is a new line in the generated text
|
||||
if (result.text_to_send.find('\n') != std::string::npos) {
|
||||
slot.has_new_line = true;
|
||||
|
@ -2124,17 +2164,10 @@ struct server_context {
|
|||
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||
}
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
// reuse any previously computed tokens that are common with the new prompt
|
||||
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
|
||||
|
||||
// push the prompt into the sampling context (do not apply grammar)
|
||||
for (int i = 0; i < slot.n_past; ++i) {
|
||||
common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
|
||||
}
|
||||
|
||||
// reuse chunks from the cached prompt by shifting their KV cache in the new position
|
||||
if (params.n_cache_reuse > 0) {
|
||||
size_t head_c = slot.n_past; // cache
|
||||
|
@ -2167,8 +2200,6 @@ struct server_context {
|
|||
for (size_t i = 0; i < n_match; i++) {
|
||||
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
|
||||
|
||||
common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
|
||||
|
||||
slot.n_past++;
|
||||
}
|
||||
|
||||
|
@ -2220,8 +2251,6 @@ struct server_context {
|
|||
|
||||
// there is no common part left
|
||||
slot.n_past = 0;
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
}
|
||||
|
||||
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
|
||||
|
@ -2249,6 +2278,13 @@ struct server_context {
|
|||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
|
||||
// Process all prompt tokens through sampler system
|
||||
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
|
||||
common_sampler_accept(slot.smpl, prompt_tokens[i], false);
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
|
@ -2287,7 +2323,6 @@ struct server_context {
|
|||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
|
|
|
@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// prepare a batch for the prompt
|
||||
|
||||
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
|
||||
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
|
||||
|
||||
// main loop
|
||||
|
||||
|
@ -175,7 +175,7 @@ int main(int argc, char ** argv) {
|
|||
fflush(stdout);
|
||||
|
||||
// prepare the next batch with the sampled token
|
||||
batch = llama_batch_get_one(&new_token_id, 1, n_pos, 0);
|
||||
batch = llama_batch_get_one(&new_token_id, 1);
|
||||
|
||||
n_decode += 1;
|
||||
}
|
||||
|
|
25
ggml/include/ggml-amx.h
Normal file
25
ggml/include/ggml-amx.h
Normal file
|
@ -0,0 +1,25 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// buffer_type API
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
|
||||
|
||||
GGML_API bool ggml_backend_is_amx(ggml_backend_t backend);
|
||||
|
||||
// backend API
|
||||
GGML_API ggml_backend_t ggml_backend_amx_init(void);
|
||||
|
||||
GGML_API void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads);
|
||||
|
||||
GGML_API ggml_backend_reg_t ggml_backend_amx_reg(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
|
@ -34,6 +34,8 @@ extern "C" {
|
|||
*/
|
||||
#define GGML_CANN_MAX_DEVICES 16
|
||||
|
||||
GGML_API ggml_backend_reg_t ggml_backend_cann_reg(void);
|
||||
|
||||
/**
|
||||
* @brief Initializes the CANN backend for a specified device.
|
||||
*
|
||||
|
|
|
@ -19,6 +19,8 @@ extern "C" {
|
|||
// backend API
|
||||
GGML_API ggml_backend_t ggml_backend_sycl_init(int device);
|
||||
|
||||
GGML_API bool ggml_backend_is_sycl(ggml_backend_t backend);
|
||||
|
||||
// devide buffer
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
|
||||
|
||||
|
@ -29,14 +31,19 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const fl
|
|||
GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
|
||||
|
||||
GGML_API void ggml_backend_sycl_print_sycl_devices(void);
|
||||
GGML_API void ggml_sycl_get_gpu_list(int *id_list, int max_len);
|
||||
GGML_API void ggml_sycl_get_device_description(int device, char *description, size_t description_size);
|
||||
GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len);
|
||||
GGML_API void ggml_backend_sycl_get_device_description(int device,
|
||||
char *description,
|
||||
size_t description_size);
|
||||
GGML_API int ggml_backend_sycl_get_device_count();
|
||||
GGML_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
|
||||
|
||||
// SYCL doesn't support registering host memory, keep here for reference
|
||||
// GGML_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
|
||||
// GGML_API void ggml_backend_sycl_unregister_host_buffer(void * buffer);
|
||||
|
||||
GGML_API ggml_backend_reg_t ggml_backend_sycl_reg(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -2494,6 +2494,7 @@ extern "C" {
|
|||
GGML_API int ggml_cpu_has_avx512_vbmi(void);
|
||||
GGML_API int ggml_cpu_has_avx512_vnni(void);
|
||||
GGML_API int ggml_cpu_has_avx512_bf16(void);
|
||||
GGML_API int ggml_cpu_has_amx_int8 (void);
|
||||
GGML_API int ggml_cpu_has_fma (void);
|
||||
GGML_API int ggml_cpu_has_neon (void);
|
||||
GGML_API int ggml_cpu_has_sve (void);
|
||||
|
|
453
ggml/src/ggml-amx.cpp
Normal file
453
ggml/src/ggml-amx.cpp
Normal file
|
@ -0,0 +1,453 @@
|
|||
#include "ggml-amx.h"
|
||||
#include "ggml-amx/common.h"
|
||||
#include "ggml-amx/mmq.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#if defined(__gnu_linux__)
|
||||
#include <sys/syscall.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
|
||||
#if defined(__AMX_INT8__)
|
||||
|
||||
// AMX buffer interface
|
||||
static const char * ggml_backend_amx_buffer_get_name(ggml_backend_buffer_t buffer) {
|
||||
return "AMX";
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
free(buffer->context);
|
||||
}
|
||||
|
||||
static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
return (void *)(buffer->context);
|
||||
}
|
||||
|
||||
static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
memset((char *)tensor->data + offset, value, size);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
if (qtype_has_amx_kernels(tensor->type)) {
|
||||
ggml_backend_amx_convert_weight(tensor, data, offset, size);
|
||||
} else {
|
||||
memcpy((char *)tensor->data + offset, data, size);
|
||||
}
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
|
||||
memcpy(data, (const char *)tensor->data + offset, size);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
||||
if (ggml_backend_buffer_is_host(src->buffer)) {
|
||||
if (qtype_has_amx_kernels(src->type)) {
|
||||
ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst));
|
||||
} else {
|
||||
memcpy(dst->data, src->data, ggml_nbytes(src));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
memset(buffer->context, value, buffer->size);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
|
||||
/* .get_name = */ ggml_backend_amx_buffer_get_name,
|
||||
/* .free_buffer = */ ggml_backend_amx_buffer_free_buffer,
|
||||
/* .get_base = */ ggml_backend_amx_buffer_get_base,
|
||||
/* .init_tensor = */ NULL, // no initialization required
|
||||
/* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_amx_buffer_get_tensor,
|
||||
/* .cpy_tensor = */ ggml_backend_amx_buffer_cpy_tensor,
|
||||
/* .clear = */ ggml_backend_amx_buffer_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return "AMX";
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
|
||||
if (data == NULL) {
|
||||
fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return TENSOR_ALIGNMENT;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
|
||||
return ggml_backend_amx_get_alloc_size(tensor);
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
|
||||
static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_amx_buffer_type_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment,
|
||||
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
||||
/* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size,
|
||||
/* .is_host = */ ggml_backend_amx_buffer_type_is_host,
|
||||
},
|
||||
/* .device = */ NULL,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_buffer_type_amx;
|
||||
}
|
||||
|
||||
// backend interface
|
||||
|
||||
static const char * ggml_backend_amx_name(ggml_backend_t backend) {
|
||||
return "AMX";
|
||||
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_amx_free(ggml_backend_t backend) {
|
||||
ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
|
||||
delete ctx;
|
||||
delete backend;
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_amx_get_default_buffer_type(ggml_backend_t backend) {
|
||||
return ggml_backend_amx_buffer_type();
|
||||
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
struct ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
switch (node->op) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
ggml_backend_amx_mul_mat(ctx, node);
|
||||
break;
|
||||
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
break;
|
||||
|
||||
default:
|
||||
fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
static struct ggml_backend_i ggml_backend_amx_i = {
|
||||
/* .get_name = */ ggml_backend_amx_name,
|
||||
/* .free = */ ggml_backend_amx_free,
|
||||
/* .get_default_buffer_type = */ ggml_backend_amx_get_default_buffer_type,
|
||||
/* .set_tensor_async = */ NULL,
|
||||
/* .get_tensor_async = */ NULL,
|
||||
/* .cpy_tensor_async = */ NULL,
|
||||
/* .synchronize = */ NULL,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
/* .graph_plan_free = */ NULL,
|
||||
/* .graph_plan_update = */ NULL,
|
||||
/* .graph_plan_compute = */ NULL,
|
||||
/* .graph_compute = */ ggml_backend_amx_graph_compute,
|
||||
/* .supports_op = */ NULL,
|
||||
/* .supports_buft = */ NULL,
|
||||
/* .offload_op = */ NULL,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_amx_guid() {
|
||||
static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e };
|
||||
return &guid;
|
||||
}
|
||||
|
||||
#define ARCH_GET_XCOMP_PERM 0x1022
|
||||
#define ARCH_REQ_XCOMP_PERM 0x1023
|
||||
#define XFEATURE_XTILECFG 17
|
||||
#define XFEATURE_XTILEDATA 18
|
||||
|
||||
static bool ggml_amx_init() {
|
||||
#if defined(__gnu_linux__)
|
||||
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
|
||||
fprintf(stderr, "AMX is not ready to be used!\n");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
#elif defined(_WIN32)
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_amx_init() {
|
||||
|
||||
// invoke a Linux system call to request access to AMX features
|
||||
ggml_amx_init();
|
||||
|
||||
// backend context
|
||||
ggml_backend_amx_context * ctx = new ggml_backend_amx_context;
|
||||
|
||||
// ggml amx backend
|
||||
ggml_backend_t backend = new ggml_backend {
|
||||
/* .guid = */ ggml_backend_amx_guid(),
|
||||
/* .interface = */ ggml_backend_amx_i,
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
|
||||
/* .context = */ ctx,
|
||||
};
|
||||
|
||||
return backend;
|
||||
}
|
||||
|
||||
bool ggml_backend_is_amx(ggml_backend_t backend) {
|
||||
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid());
|
||||
}
|
||||
|
||||
void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
|
||||
GGML_ASSERT(ggml_backend_is_amx(backend_amx));
|
||||
|
||||
ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context;
|
||||
ctx->n_threads = n_threads;
|
||||
}
|
||||
|
||||
// device interface
|
||||
|
||||
static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
|
||||
return "AMX";
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
|
||||
return "Intel Advanced Matrix Extensions";
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
// TODO
|
||||
*free = 0;
|
||||
*total = 0;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
|
||||
return GGML_BACKEND_DEVICE_TYPE_CPU;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_amx_device_get_name(dev);
|
||||
props->description = ggml_backend_amx_device_get_description(dev);
|
||||
props->type = ggml_backend_amx_device_get_type(dev);
|
||||
ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
// `buffer_from_host_ptr` is intended to be used in mmap, when memory layout unchanged
|
||||
props->caps = {
|
||||
/* .async = */ false,
|
||||
/* .host_buffer = */ false,
|
||||
/* .buffer_from_host_ptr = */ false,
|
||||
/* .events = */ false,
|
||||
};
|
||||
}
|
||||
|
||||
static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
return ggml_backend_amx_init();
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
GGML_UNUSED(params);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||
return ggml_backend_amx_buffer_type();
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
||||
|
||||
// handle only 2d gemm for now
|
||||
auto is_contiguous_2d = [](const struct ggml_tensor * t) {
|
||||
return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
|
||||
};
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
return true;
|
||||
|
||||
case GGML_OP_MUL_MAT: {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
const int64_t ne0 = op->ne[0];
|
||||
|
||||
bool is_training = src0->grad || src1->grad;
|
||||
|
||||
// amx kernels enables for Q4_0, Q4_1, Q8_0, F16
|
||||
// Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
|
||||
bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
|
||||
|
||||
bool can_use_amx =
|
||||
is_contiguous_2d(src0) && // src0 must be contiguous
|
||||
is_contiguous_2d(src1) && // src1 must be contiguous
|
||||
!is_training && // inference only
|
||||
src1->type == GGML_TYPE_F32 && // src1 must be float32
|
||||
has_amx_kernels && // with amx kernel impls
|
||||
ne0 % (TILE_N * 2) == 0; // out_features is 32x
|
||||
|
||||
return can_use_amx;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||
return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
|
||||
/* .get_name = */ ggml_backend_amx_device_get_name,
|
||||
/* .get_description = */ ggml_backend_amx_device_get_description,
|
||||
/* .get_memory = */ ggml_backend_amx_device_get_memory,
|
||||
/* .get_type = */ ggml_backend_amx_device_get_type,
|
||||
/* .get_props = */ ggml_backend_amx_device_get_props,
|
||||
/* .init_backend = */ ggml_backend_amx_device_init,
|
||||
/* .get_buffer_type = */ ggml_backend_amx_device_get_buffer_type,
|
||||
/* .get_host_buffer_type = */ NULL,
|
||||
/* .buffer_from_host_ptr = */ NULL,
|
||||
/* .supports_op = */ ggml_backend_amx_device_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_amx_device_supports_buft,
|
||||
/* .offload_op = */ NULL,
|
||||
/* .event_new = */ NULL,
|
||||
/* .event_free = */ NULL,
|
||||
/* .event_synchronize = */ NULL,
|
||||
};
|
||||
|
||||
// backend reg interface
|
||||
|
||||
static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
|
||||
return "AMX";
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
|
||||
return 1;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
||||
GGML_ASSERT(index == 0);
|
||||
|
||||
static ggml_backend_device ggml_backend_amx_device = {
|
||||
/* .iface = */ ggml_backend_amx_device_i,
|
||||
/* .reg = */ reg,
|
||||
/* .context = */ nullptr,
|
||||
};
|
||||
|
||||
return &ggml_backend_amx_device;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
GGML_UNUSED(index);
|
||||
}
|
||||
|
||||
static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
||||
if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
|
||||
return (void *)ggml_backend_amx_set_n_threads;
|
||||
}
|
||||
return NULL;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
GGML_UNUSED(name);
|
||||
}
|
||||
|
||||
static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
|
||||
/* .get_name = */ ggml_backend_amx_reg_get_name,
|
||||
/* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
|
||||
/* .get_device = */ ggml_backend_amx_reg_get_device,
|
||||
/* .get_proc_address = */ ggml_backend_amx_get_proc_address,
|
||||
};
|
||||
|
||||
ggml_backend_reg_t ggml_backend_amx_reg(void) {
|
||||
static struct ggml_backend_reg ggml_backend_amx_reg = {
|
||||
/* .iface = */ ggml_backend_amx_reg_i,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_amx_reg;
|
||||
}
|
||||
|
||||
#else // if defined(__AMX_INT8__)
|
||||
|
||||
ggml_backend_t ggml_backend_amx_init(void) {
|
||||
fprintf(stderr, "GGML is not compiled with AMX support!\n");
|
||||
return ggml_backend_t{};
|
||||
}
|
||||
|
||||
void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
|
||||
fprintf(stderr, "GGML is not compiled with AMX support!\n");
|
||||
|
||||
GGML_UNUSED(backend_amx);
|
||||
GGML_UNUSED(n_threads);
|
||||
}
|
||||
|
||||
#endif
|
93
ggml/src/ggml-amx/common.h
Normal file
93
ggml/src/ggml-amx/common.h
Normal file
|
@ -0,0 +1,93 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpu-impl.h" // <immintrin.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
#define TILE_M 16
|
||||
#define TILE_N 16
|
||||
#define TILE_K 32
|
||||
#define VNNI_BLK 4
|
||||
|
||||
#define AMX_BLK_SIZE 32
|
||||
|
||||
#define TMM0 0
|
||||
#define TMM1 1
|
||||
#define TMM2 2
|
||||
#define TMM3 3
|
||||
#define TMM4 4
|
||||
#define TMM5 5
|
||||
#define TMM6 6
|
||||
#define TMM7 7
|
||||
|
||||
// parallel routines
|
||||
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
|
||||
inline T div_up(T x, T y) { return (x + y - 1) / y; }
|
||||
|
||||
template <typename T>
|
||||
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
|
||||
#if 0
|
||||
// onednn partition pattern
|
||||
T& n_my = n_end;
|
||||
if (nth <= 1 || n == 0) {
|
||||
n_start = 0;
|
||||
n_my = n;
|
||||
} else {
|
||||
T n1 = div_up(n, nth);
|
||||
T n2 = n1 - 1;
|
||||
T T1 = n - n2 * nth;
|
||||
n_my = ith < T1 ? n1 : n2;
|
||||
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
|
||||
}
|
||||
n_end += n_start;
|
||||
#else
|
||||
// pytorch aten partition pattern
|
||||
T n_my = div_up(n, nth);
|
||||
n_start = ith * n_my;
|
||||
n_end = std::min(n_start + n_my, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_for(int nth, int n, const func_t& f) {
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel num_threads(nth)
|
||||
{
|
||||
//int nth = omp_get_num_threads();
|
||||
int ith = omp_get_thread_num();
|
||||
int tbegin, tend;
|
||||
balance211(n, nth, ith, tbegin, tend);
|
||||
f(tbegin, tend);
|
||||
}
|
||||
#else
|
||||
f(0, n);
|
||||
|
||||
GGML_UNUSED(nth);
|
||||
#endif
|
||||
}
|
||||
|
||||
// quantized types that have AMX support
|
||||
inline bool qtype_has_amx_kernels(const enum ggml_type type) {
|
||||
// TODO: fix padding for vnni format
|
||||
return (type == GGML_TYPE_Q4_0) ||
|
||||
(type == GGML_TYPE_Q4_1);
|
||||
//(type == GGML_TYPE_Q8_0) ||
|
||||
//(type == GGML_TYPE_Q4_K) ||
|
||||
//(type == GGML_TYPE_Q5_K) ||
|
||||
//(type == GGML_TYPE_Q6_K) ||
|
||||
//(type == GGML_TYPE_IQ4_XS);
|
||||
}
|
||||
|
||||
// ggml backend context
|
||||
struct ggml_backend_amx_context {
|
||||
int n_threads = GGML_DEFAULT_N_THREADS;
|
||||
std::unique_ptr<char[]> work_data;
|
||||
size_t work_size = 0;
|
||||
};
|
2509
ggml/src/ggml-amx/mmq.cpp
Normal file
2509
ggml/src/ggml-amx/mmq.cpp
Normal file
File diff suppressed because it is too large
Load diff
17
ggml/src/ggml-amx/mmq.h
Normal file
17
ggml/src/ggml-amx/mmq.h
Normal file
|
@ -0,0 +1,17 @@
|
|||
#pragma once
|
||||
#include "common.h"
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor);
|
||||
|
||||
void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
|
||||
void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
|
@ -329,7 +329,6 @@ bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type
|
|||
if (backend->device) {
|
||||
return ggml_backend_dev_supports_buft(backend->device, buft);
|
||||
}
|
||||
|
||||
return backend->iface.supports_buft(backend, buft);
|
||||
}
|
||||
|
||||
|
@ -538,6 +537,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
|
|||
#include "ggml-metal.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_SYCL
|
||||
#include "ggml-sycl.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_VULKAN
|
||||
#include "ggml-vulkan.h"
|
||||
#endif
|
||||
|
@ -550,6 +553,18 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
|
|||
#include "ggml-rpc.h"
|
||||
#endif
|
||||
|
||||
#ifndef __AMX_INT8__
|
||||
#undef GGML_USE_AMX
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_AMX
|
||||
# include "ggml-amx.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CANN
|
||||
#include "ggml-cann.h"
|
||||
#endif
|
||||
|
||||
struct ggml_backend_registry {
|
||||
std::vector<ggml_backend_reg_t> backends;
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
|
@ -561,6 +576,9 @@ struct ggml_backend_registry {
|
|||
#ifdef GGML_USE_METAL
|
||||
register_backend(ggml_backend_metal_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_SYCL
|
||||
register_backend(ggml_backend_sycl_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_VULKAN
|
||||
register_backend(ggml_backend_vk_reg());
|
||||
#endif
|
||||
|
@ -570,8 +588,14 @@ struct ggml_backend_registry {
|
|||
#ifdef GGML_USE_RPC
|
||||
register_backend(ggml_backend_rpc_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_AMX
|
||||
register_backend(ggml_backend_amx_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_CANN
|
||||
register_backend(ggml_backend_cann_reg());
|
||||
#endif
|
||||
|
||||
// TODO: sycl, kompute, cann
|
||||
// TODO: kompute
|
||||
|
||||
register_backend(ggml_backend_cpu_reg());
|
||||
}
|
||||
|
@ -2250,6 +2274,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
|
|||
sched->backends[b] = backends[b];
|
||||
sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
|
||||
GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b]));
|
||||
|
||||
if (sched->n_copies > 1) {
|
||||
for (int c = 0; c < sched->n_copies; c++) {
|
||||
sched->events[b][c] = ggml_backend_event_new(backends[b]->device);
|
||||
|
|
|
@ -39,6 +39,8 @@
|
|||
|
||||
#include "ggml-common.h"
|
||||
|
||||
#define GGML_CANN_NAME "CANN"
|
||||
|
||||
/**
|
||||
* @brief Handles CANN errors by printing an error message and aborting.
|
||||
*
|
||||
|
@ -851,13 +853,6 @@ static void ggml_backend_cann_buffer_set_tensor(
|
|||
void *transform_buffer = malloc(size);
|
||||
ggml_backend_cann_transform(tensor, data, transform_buffer);
|
||||
|
||||
#ifndef NDEBUG
|
||||
void *check_buffer = malloc(size);
|
||||
ggml_backend_cann_transform_back(tensor, transform_buffer,
|
||||
check_buffer);
|
||||
GGML_ASSERT(memcmp(data, check_buffer, size) == 0);
|
||||
free(check_buffer);
|
||||
#endif
|
||||
ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
|
||||
transform_buffer, size,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE));
|
||||
|
@ -969,7 +964,7 @@ static void ggml_backend_cann_buffer_clear(
|
|||
* This structure defines function pointers to operations that can be performed
|
||||
* on a CANN buffer within the backend.
|
||||
*/
|
||||
static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
|
||||
static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
|
||||
/* .get_name = */ ggml_backend_cann_buffer_get_name,
|
||||
/* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
|
||||
/* .get_base = */ ggml_backend_cann_buffer_get_base,
|
||||
|
@ -1105,19 +1100,25 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
|
|||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Interface for managing CANN buffer types in the GGML backend.
|
||||
*
|
||||
* Provides function pointers for allocating, querying properties, and managing
|
||||
* memory for CANN buffer types in the GGML backend.
|
||||
*/
|
||||
static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
|
||||
static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
|
||||
/* .get_name = */ ggml_backend_cann_buffer_type_name,
|
||||
/* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
|
||||
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
||||
/* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
|
||||
/* .is_host = */ NULL,
|
||||
/* .is_host = */ ggml_backend_cann_buffer_type_is_host,
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -1148,7 +1149,7 @@ ggml_backend_cann_buffer_type(int32_t device) {
|
|||
for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
|
||||
ggml_backend_cann_buffer_types[i] = {
|
||||
/* .iface = */ ggml_backend_cann_buffer_type_interface,
|
||||
/* .device = */ nullptr,
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
|
||||
/* .context = */
|
||||
new ggml_backend_cann_buffer_type_context{
|
||||
i, "CANN" + std::to_string(i)},
|
||||
|
@ -1264,7 +1265,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
|
|||
/* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
|
||||
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
|
||||
},
|
||||
/* .device = */ nullptr,
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
|
||||
/* .context = */ nullptr,
|
||||
};
|
||||
|
||||
|
@ -1511,13 +1512,6 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
|
|||
void *transform_buffer = malloc(size);
|
||||
ggml_backend_cann_transform(tensor, data, transform_buffer);
|
||||
|
||||
#ifndef NDEBUG
|
||||
void *check_buffer = malloc(size);
|
||||
ggml_backend_cann_transform_back(tensor, transform_buffer,
|
||||
check_buffer);
|
||||
GGML_ASSERT(memcmp(data, check_buffer, size));
|
||||
free(check_buffer);
|
||||
#endif
|
||||
ACL_CHECK(aclrtMemcpyAsync(
|
||||
(char *)tensor->data + offset, size, transform_buffer, size,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
|
||||
|
@ -1692,7 +1686,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
|
|||
* @return bool Returns true if the operation is supported by the backend,
|
||||
* otherwise false.
|
||||
*/
|
||||
static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
|
||||
static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
const ggml_tensor* op) {
|
||||
switch (op->op) {
|
||||
case GGML_OP_UNARY:
|
||||
|
@ -1783,7 +1777,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
|
|||
return false;
|
||||
}
|
||||
|
||||
GGML_UNUSED(backend);
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1801,31 +1795,6 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
|
|||
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks if the CANN backend supports a specific backend buffer type.
|
||||
*
|
||||
* This function determines whether the CANN backend supports the given backend
|
||||
* buffer type by comparing the device context of the backend and buffer type.
|
||||
* It returns true if the devices are same between the backend context and
|
||||
* buffer type context.
|
||||
*
|
||||
* @param backend Pointer to the CANN backend.
|
||||
* @param buft Pointer to the backend buffer type to check.
|
||||
* @return bool Returns true if the CANN backend supports the buffer type,
|
||||
* otherwise false.
|
||||
*/
|
||||
static bool ggml_backend_cann_supports_buft(
|
||||
ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
||||
if (ggml_backend_buft_is_cann(buft)) {
|
||||
ggml_backend_cann_context * cann_ctx =
|
||||
(ggml_backend_cann_context *)backend->context;
|
||||
ggml_backend_cann_buffer_type_context * buft_ctx =
|
||||
(ggml_backend_cann_buffer_type_context *)buft->context;
|
||||
return buft_ctx->device == cann_ctx->device;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Determines if a tensor operation should be offloaded to the CANN
|
||||
* backend.
|
||||
|
@ -1840,54 +1809,14 @@ static bool ggml_backend_cann_supports_buft(
|
|||
* @return bool Returns true if the operation should be offloaded, otherwise
|
||||
* false.
|
||||
*/
|
||||
static bool ggml_backend_cann_offload_op(ggml_backend_t backend,
|
||||
static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
|
||||
const ggml_tensor* op) {
|
||||
const int min_batch_size = 32;
|
||||
GGML_UNUSED(backend);
|
||||
GGML_UNUSED(dev);
|
||||
|
||||
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Creates a new event for the CANN backend.
|
||||
*
|
||||
* This function initializes a new event for the CANN backend by setting the
|
||||
* device and creating an ACL runtime event. The created event is then wrapped
|
||||
* in a ggml_backend_event structure and returned.
|
||||
*
|
||||
* @param backend Pointer to the CANN backend.
|
||||
* @return ggml_backend_event_t Returns a pointer to the new event structure.
|
||||
*/
|
||||
static ggml_backend_event_t ggml_backend_cann_event_new(
|
||||
ggml_backend_t backend) {
|
||||
ggml_backend_cann_context* cann_ctx =
|
||||
(ggml_backend_cann_context*)backend->context;
|
||||
|
||||
ggml_cann_set_device(cann_ctx->device);
|
||||
|
||||
aclrtEvent event;
|
||||
ACL_CHECK(aclrtCreateEvent(&event));
|
||||
|
||||
return new ggml_backend_event{
|
||||
/* .device = */ nullptr,
|
||||
/* .context = */ event,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Frees a CANN backend event.
|
||||
*
|
||||
* This function destroys the ACL runtime event associated with the given CANN
|
||||
* backend event and then deletes the event structure itself.
|
||||
*
|
||||
* @param event Pointer to the event structure to be freed.
|
||||
*/
|
||||
static void ggml_backend_cann_event_free(ggml_backend_event_t event) {
|
||||
ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
|
||||
|
||||
delete event;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Records an event on the CANN backend stream.
|
||||
*
|
||||
|
@ -1924,17 +1853,6 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Synchronizes the given event on the CANN backend.
|
||||
*
|
||||
* This function waits for the specified event to complete on the ACL runtime.
|
||||
*
|
||||
* @param event Pointer to the event structure to be synchronized.
|
||||
*/
|
||||
static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
|
||||
ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Structure defining the interface for the CANN backend.
|
||||
*
|
||||
|
@ -1942,7 +1860,7 @@ static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
|
|||
* supported by the CANN backend, including name retrieval, memory
|
||||
* management, tensor operations, synchronization, and event handling.
|
||||
*/
|
||||
static ggml_backend_i ggml_backend_cann_interface = {
|
||||
static const ggml_backend_i ggml_backend_cann_interface = {
|
||||
/* .get_name = */ ggml_backend_cann_name,
|
||||
/* .free = */ ggml_backend_cann_free,
|
||||
/* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type,
|
||||
|
@ -1955,9 +1873,9 @@ static ggml_backend_i ggml_backend_cann_interface = {
|
|||
/* .graph_plan_update = */ NULL,
|
||||
/* .graph_plan_compute = */ NULL,
|
||||
/* .graph_compute = */ ggml_backend_cann_graph_compute,
|
||||
/* .supports_op = */ ggml_backend_cann_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_cann_supports_buft,
|
||||
/* .offload_op = */ ggml_backend_cann_offload_op,
|
||||
/* .supports_op = */ NULL, // moved to device
|
||||
/* .supports_buft = */ NULL, // moved to device
|
||||
/* .offload_op = */ NULL, // moved to device
|
||||
/* .event_record = */ ggml_backend_cann_event_record,
|
||||
/* .event_wait = */ ggml_backend_cann_event_wait,
|
||||
};
|
||||
|
@ -1976,6 +1894,234 @@ static ggml_guid_t ggml_backend_cann_guid() {
|
|||
return &guid;
|
||||
}
|
||||
|
||||
// backend device
|
||||
struct ggml_backend_cann_device_context {
|
||||
int device;
|
||||
std::string name;
|
||||
std::string description;
|
||||
};
|
||||
|
||||
static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
|
||||
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
return ctx->name.c_str();
|
||||
}
|
||||
|
||||
static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
|
||||
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
return ctx->description.c_str();
|
||||
}
|
||||
|
||||
static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
ggml_backend_cann_get_device_memory(ctx->device, free, total);
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
|
||||
GGML_UNUSED(dev);
|
||||
return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
|
||||
}
|
||||
|
||||
static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_cann_device_get_name(dev);
|
||||
props->description = ggml_backend_cann_device_get_description(dev);
|
||||
props->type = ggml_backend_cann_device_get_type(dev);
|
||||
ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
|
||||
|
||||
props->caps = {
|
||||
/* .async = */ false,
|
||||
/* .host_buffer = */ host_buffer,
|
||||
/* .buffer_from_host_ptr = */ false,
|
||||
/* .events = */ true,
|
||||
};
|
||||
}
|
||||
|
||||
static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
GGML_UNUSED(params);
|
||||
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
return ggml_backend_cann_init(ctx->device);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks if the CANN backend supports a specific backend buffer type.
|
||||
*
|
||||
* This function determines whether the CANN backend supports the given backend
|
||||
* buffer type by comparing the device context of the backend and buffer type.
|
||||
* It returns true if the devices are same between the backend context and
|
||||
* buffer type context.
|
||||
*
|
||||
* @param backend Pointer to the CANN backend.
|
||||
* @param buft Pointer to the backend buffer type to check.
|
||||
* @return bool Returns true if the CANN backend supports the buffer type,
|
||||
* otherwise false.
|
||||
*/
|
||||
static bool ggml_backend_cann_supports_buft(
|
||||
ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||
if (ggml_backend_buft_is_cann(buft)) {
|
||||
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
ggml_backend_cann_buffer_type_context * buft_ctx =
|
||||
(ggml_backend_cann_buffer_type_context *)buft->context;
|
||||
return buft_ctx->device == dev_ctx->device;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
return ggml_backend_cann_buffer_type(ctx->device);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
|
||||
GGML_UNUSED(dev);
|
||||
return ggml_backend_cann_host_buffer_type();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Creates a new event for the CANN backend device.
|
||||
*
|
||||
* This function initializes a new event for the CANN backend by setting the
|
||||
* device and creating an ACL runtime event. The created event is then wrapped
|
||||
* in a ggml_backend_event structure and returned.
|
||||
*
|
||||
* @param backend Pointer to the CANN backend.
|
||||
* @return ggml_backend_event_t Returns a pointer to the new event structure.
|
||||
*/
|
||||
static ggml_backend_event_t ggml_backend_cann_device_event_new(
|
||||
ggml_backend_dev_t dev) {
|
||||
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
|
||||
|
||||
ggml_cann_set_device(dev_ctx->device);
|
||||
|
||||
aclrtEvent event;
|
||||
ACL_CHECK(aclrtCreateEvent(&event));
|
||||
|
||||
return new ggml_backend_event{
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
|
||||
/* .context = */ event,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Frees a CANN backend event.
|
||||
*
|
||||
* This function destroys the ACL runtime event associated with the given CANN
|
||||
* backend event and then deletes the event structure itself.
|
||||
*
|
||||
* @param event Pointer to the event structure to be freed.
|
||||
*/
|
||||
static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
|
||||
ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
|
||||
|
||||
delete event;
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Synchronizes the given event on the CANN backend.
|
||||
*
|
||||
* This function waits for the specified event to complete on the ACL runtime.
|
||||
*
|
||||
* @param event Pointer to the event structure to be synchronized.
|
||||
*/
|
||||
static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
|
||||
ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static const ggml_backend_device_i ggml_backend_cann_device_interface = {
|
||||
/* .get_name = */ ggml_backend_cann_device_get_name,
|
||||
/* .get_description = */ ggml_backend_cann_device_get_description,
|
||||
/* .get_memory = */ ggml_backend_cann_device_get_memory,
|
||||
/* .get_type = */ ggml_backend_cann_device_get_type,
|
||||
/* .get_props = */ ggml_backend_cann_device_get_props,
|
||||
/* .init_backend = */ ggml_backend_cann_device_init, // called for every card
|
||||
/* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type,
|
||||
/* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
|
||||
/* .buffer_from_host_ptr = */ NULL, // not supported for CANN
|
||||
/* .supports_op = */ ggml_backend_cann_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_cann_supports_buft,
|
||||
/* .offload_op = */ ggml_backend_cann_offload_op,
|
||||
/* .event_new = */ ggml_backend_cann_device_event_new,
|
||||
/* .event_free = */ ggml_backend_cann_device_event_free,
|
||||
/* .event_synchronize = */ ggml_backend_cann_device_event_synchronize,
|
||||
};
|
||||
|
||||
|
||||
// backend reg
|
||||
struct ggml_backend_cann_reg_context {
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
};
|
||||
|
||||
static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
|
||||
GGML_UNUSED(reg);
|
||||
return GGML_CANN_NAME;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
|
||||
ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
|
||||
return ctx->devices.size();
|
||||
}
|
||||
|
||||
static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
||||
ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
|
||||
GGML_ASSERT(index < ctx->devices.size());
|
||||
return ctx->devices[index];
|
||||
}
|
||||
|
||||
static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
||||
GGML_UNUSED(reg);
|
||||
GGML_UNUSED(name);
|
||||
// reserved for future use
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
|
||||
/* .get_name = */ ggml_backend_cann_reg_get_name,
|
||||
/* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
|
||||
/* .get_device_get = */ ggml_backend_cann_reg_get_device,
|
||||
/* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
|
||||
};
|
||||
|
||||
// backend registry, called only once for cann backend
|
||||
ggml_backend_reg_t ggml_backend_cann_reg() {
|
||||
static ggml_backend_reg reg;
|
||||
static bool initialized = false;
|
||||
|
||||
{
|
||||
static std::mutex mutex;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (!initialized) {
|
||||
aclInit(nullptr);
|
||||
ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
|
||||
|
||||
for (int i = 0; i < ggml_cann_info().device_count; i++) {
|
||||
ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
|
||||
dev_ctx->description = aclrtGetSocName();
|
||||
dev_ctx->device = i;
|
||||
dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
|
||||
ggml_cann_set_device(i);
|
||||
ggml_backend_dev_t dev = new ggml_backend_device {
|
||||
/* .interface = */ ggml_backend_cann_device_interface,
|
||||
/* .reg = */ ®,
|
||||
/* .context = */ dev_ctx
|
||||
};
|
||||
ctx->devices.push_back(dev);
|
||||
}
|
||||
|
||||
reg = ggml_backend_reg {
|
||||
/* .interface = */ ggml_backend_cann_reg_interface,
|
||||
/* .context = */ ctx
|
||||
};
|
||||
}
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
|
||||
return ®
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_cann_init(int32_t device) {
|
||||
aclInit(nullptr);
|
||||
if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
|
||||
|
@ -1992,7 +2138,7 @@ ggml_backend_t ggml_backend_cann_init(int32_t device) {
|
|||
ggml_backend_t cann_backend =
|
||||
new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
|
||||
/* .interface = */ ggml_backend_cann_interface,
|
||||
/* .device = */ nullptr,
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
|
||||
/* .context = */ ctx};
|
||||
|
||||
return cann_backend;
|
||||
|
|
|
@ -1151,7 +1151,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|||
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
|
||||
char * src_ptr = (char *) src->data;
|
||||
const char * src_ptr = (const char *) src->data;
|
||||
char * dst_ptr = (char *) dst;
|
||||
|
||||
const int64_t ne0 = src->ne[0];
|
||||
|
@ -1162,7 +1162,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|||
const enum ggml_type type = src->type;
|
||||
const int64_t ts = ggml_type_size(type);
|
||||
const int64_t bs = ggml_blck_size(type);
|
||||
int64_t i1_diff = i1_high - i1_low;
|
||||
const int64_t i1_diff = i1_high - i1_low;
|
||||
|
||||
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
|
||||
if (nb0 == ts && nb1 == ts*ne0/bs) {
|
||||
|
@ -1479,13 +1479,18 @@ static void ggml_cuda_op_mul_mat(
|
|||
if (src0_is_contiguous) {
|
||||
dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
|
||||
} else {
|
||||
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
|
||||
// If src0 is not contiguous it will be copied to a temporary buffer.
|
||||
// This buffer needs to be cleared entirely because multiple regions will function as padding.
|
||||
const size_t nbytes_data = ggml_nbytes(src0);
|
||||
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
|
||||
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
|
||||
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
|
||||
}
|
||||
|
||||
// If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared:
|
||||
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
|
||||
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
|
||||
const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
|
||||
const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
|
||||
const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
|
||||
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
|
||||
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
|
||||
}
|
||||
|
||||
|
@ -3145,7 +3150,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_ROPE:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_IM2COL:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
|
|
|
@ -92,8 +92,8 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
const int64_t OW = dst->ne[1];
|
||||
|
||||
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
||||
const int64_t batch = src1->ne[3];
|
||||
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
|
||||
const int64_t batch = src1->ne[is_2D ? 3 : 2];
|
||||
const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
|
||||
|
||||
if(dst->type == GGML_TYPE_F16) {
|
||||
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
|
||||
|
|
|
@ -8,8 +8,6 @@ void ggml_cuda_op_mul_mat_q(
|
|||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
|
||||
const int64_t nb01 = src0->nb[1];
|
||||
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
const int64_t ne11 = src1->ne[1];
|
||||
GGML_ASSERT(ne10 % QK8_1 == 0);
|
||||
|
@ -17,7 +15,7 @@ void ggml_cuda_op_mul_mat_q(
|
|||
const int64_t ne0 = dst->ne[0];
|
||||
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
const int64_t stride00 = nb01 / ggml_type_size(src0->type);
|
||||
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
|
||||
|
||||
int id = ggml_cuda_get_device();
|
||||
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
||||
|
|
|
@ -241,6 +241,8 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
|
||||
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
||||
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
||||
|
@ -272,6 +274,8 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_SIN,
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
|
||||
GGML_METAL_KERNEL_TYPE_COUNT
|
||||
};
|
||||
|
@ -685,6 +689,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
||||
|
@ -716,6 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
||||
}
|
||||
|
||||
[metal_library release];
|
||||
|
@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||
case GGML_OP_IM2COL:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_POOL_1D:
|
||||
case GGML_OP_POOL_2D:
|
||||
return false;
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARANGE:
|
||||
|
@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
|
|||
} break;
|
||||
case GGML_OP_IM2COL:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
||||
|
@ -2574,11 +2584,23 @@ static void ggml_metal_encode_node(
|
|||
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
||||
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
|
||||
|
||||
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
|
||||
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
||||
case GGML_TYPE_F32: {
|
||||
pipeline = (is_gt_mttpt ?
|
||||
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
|
||||
:
|
||||
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
pipeline = (is_gt_mttpt ?
|
||||
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
|
||||
:
|
||||
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
|
||||
} break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
|
@ -2597,7 +2619,19 @@ static void ggml_metal_encode_node(
|
|||
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
|
||||
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
|
||||
|
||||
if (is_gt_mttpt) {
|
||||
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
||||
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
||||
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
||||
|
||||
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
|
||||
|
||||
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
||||
} else {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_UPSCALE:
|
||||
{
|
||||
|
@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
|
|||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_POOL_2D:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
|
||||
|
||||
const int32_t * opts = dst->op_params;
|
||||
enum ggml_op_pool op = opts[0];
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32: {
|
||||
switch(op) {
|
||||
case GGML_OP_POOL_AVG:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
|
||||
case GGML_OP_POOL_MAX:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
} break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
|
||||
const int32_t k0 = opts[1];
|
||||
const int32_t k1 = opts[2];
|
||||
const int32_t s0 = opts[3];
|
||||
const int32_t s1 = opts[4];
|
||||
const int32_t p0 = opts[5];
|
||||
const int32_t p1 = opts[6];
|
||||
|
||||
const int64_t IH = src0->ne[1];
|
||||
const int64_t IW = src0->ne[0];
|
||||
|
||||
const int64_t N = dst->ne[3];
|
||||
const int64_t OC = dst->ne[2];
|
||||
const int64_t OH = dst->ne[1];
|
||||
const int64_t OW = dst->ne[0];
|
||||
|
||||
const int64_t parallel_elements = N * OC * OH * OW;
|
||||
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
|
||||
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
|
||||
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
|
||||
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
|
||||
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
|
||||
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
|
||||
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
|
||||
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
|
||||
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
|
||||
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
|
||||
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
|
||||
[encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||
|
|
|
@ -1933,6 +1933,85 @@ kernel void kernel_im2col(
|
|||
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
||||
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
||||
|
||||
typedef void (im2col_ext_t)(
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
constant int32_t & ofs0,
|
||||
constant int32_t & ofs1,
|
||||
constant int32_t & IW,
|
||||
constant int32_t & IH,
|
||||
constant int32_t & CHW,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int32_t & d0,
|
||||
constant int32_t & d1,
|
||||
constant int32_t & N,
|
||||
constant int32_t & KH,
|
||||
constant int32_t & KW,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_im2col_ext(
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
constant int32_t & ofs0,
|
||||
constant int32_t & ofs1,
|
||||
constant int32_t & IW,
|
||||
constant int32_t & IH,
|
||||
constant int32_t & CHW,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int32_t & d0,
|
||||
constant int32_t & d1,
|
||||
constant int32_t & N,
|
||||
constant int32_t & KH,
|
||||
constant int32_t & KW,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
||||
const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
|
||||
|
||||
const int32_t d = tgpig[0] / CHW;
|
||||
const int32_t chw = tgpig[0] % CHW;
|
||||
const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
||||
const int32_t HW = tgpig[0] % KHW;
|
||||
|
||||
const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
||||
if (tpitg_0 >= N) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t tpitg_1 = HW / KW;
|
||||
const int32_t tpitg_2 = HW % KW;
|
||||
|
||||
const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
|
||||
const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
|
||||
|
||||
const int32_t offset_dst =
|
||||
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
||||
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
|
||||
|
||||
device T * pdst = (device T *) (dst);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
pdst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
|
||||
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
||||
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
||||
|
||||
kernel void kernel_upscale_f32(
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
|
@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
|
|||
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
||||
|
||||
kernel void kernel_pool_2d_max_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int32_t & k0,
|
||||
constant int32_t & k1,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int64_t & IH,
|
||||
constant int64_t & IW,
|
||||
constant int64_t & OH,
|
||||
constant int64_t & OW,
|
||||
constant int64_t & parallel_elements,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= parallel_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = IH * IW;
|
||||
const int O_HW = OH * OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / OW;
|
||||
const int cur_ow = idx % O_HW % OW;
|
||||
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * s1 - p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(IH, start_h + k1);
|
||||
const int start_w = cur_ow * s0 - p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(IW, start_w + k0);
|
||||
|
||||
float res = -INFINITY;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
res = MAX(res, i_ptr[i * IW + j]);
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
kernel void kernel_pool_2d_avg_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int32_t & k0,
|
||||
constant int32_t & k1,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int64_t & IH,
|
||||
constant int64_t & IW,
|
||||
constant int64_t & OH,
|
||||
constant int64_t & OW,
|
||||
constant int64_t & parallel_elements,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= parallel_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = IH * IW;
|
||||
const int O_HW = OH * OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / OW;
|
||||
const int cur_ow = idx % O_HW % OW;
|
||||
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * s1 - p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(IH, start_h + k1);
|
||||
const int start_w = cur_ow * s0 - p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(IW, start_w + k0);
|
||||
// const float scale = 1. / ((eh - bh) * (ew - bw));
|
||||
const float scale = 1. / (k0 * k1);
|
||||
|
||||
float res = 0;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
float cur = i_ptr[i * IW + j];
|
||||
res += cur * scale;
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * OW + cur_ow] = res;
|
||||
}
|
||||
|
|
|
@ -57,8 +57,9 @@ struct socket_t {
|
|||
}
|
||||
};
|
||||
|
||||
// ggml_tensor is serialized into rpc_tensor
|
||||
// all RPC structures must be packed
|
||||
#pragma pack(push, 1)
|
||||
// ggml_tensor is serialized into rpc_tensor
|
||||
struct rpc_tensor {
|
||||
uint64_t id;
|
||||
uint32_t type;
|
||||
|
@ -76,7 +77,6 @@ struct rpc_tensor {
|
|||
|
||||
char padding[4];
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
|
||||
|
||||
|
@ -96,6 +96,65 @@ enum rpc_cmd {
|
|||
RPC_CMD_COUNT,
|
||||
};
|
||||
|
||||
struct rpc_msg_alloc_buffer_req {
|
||||
uint64_t size;
|
||||
};
|
||||
|
||||
struct rpc_msg_alloc_buffer_rsp {
|
||||
uint64_t remote_ptr;
|
||||
uint64_t remote_size;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_alignment_rsp {
|
||||
uint64_t alignment;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_max_size_rsp {
|
||||
uint64_t max_size;
|
||||
};
|
||||
|
||||
struct rpc_msg_buffer_get_base_req {
|
||||
uint64_t remote_ptr;
|
||||
};
|
||||
|
||||
struct rpc_msg_buffer_get_base_rsp {
|
||||
uint64_t base_ptr;
|
||||
};
|
||||
|
||||
struct rpc_msg_free_buffer_req {
|
||||
uint64_t remote_ptr;
|
||||
};
|
||||
|
||||
struct rpc_msg_buffer_clear_req {
|
||||
uint64_t remote_ptr;
|
||||
uint8_t value;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_tensor_req {
|
||||
rpc_tensor tensor;
|
||||
uint64_t offset;
|
||||
uint64_t size;
|
||||
};
|
||||
|
||||
struct rpc_msg_copy_tensor_req {
|
||||
rpc_tensor src;
|
||||
rpc_tensor dst;
|
||||
};
|
||||
|
||||
struct rpc_msg_copy_tensor_rsp {
|
||||
uint8_t result;
|
||||
};
|
||||
|
||||
struct rpc_msg_graph_compute_rsp {
|
||||
uint8_t result;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_device_memory_rsp {
|
||||
uint64_t free_mem;
|
||||
uint64_t total_mem;
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
// RPC data structures
|
||||
|
||||
static ggml_guid_t ggml_backend_rpc_guid() {
|
||||
|
@ -240,6 +299,38 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
|
||||
if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
|
||||
return false;
|
||||
}
|
||||
return send_data(sockfd, msg, msg_size);
|
||||
}
|
||||
|
||||
static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
|
||||
uint64_t size;
|
||||
if (!recv_data(sockfd, &size, sizeof(size))) {
|
||||
return false;
|
||||
}
|
||||
if (size != msg_size) {
|
||||
return false;
|
||||
}
|
||||
return recv_data(sockfd, msg, msg_size);
|
||||
}
|
||||
|
||||
static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
|
||||
uint64_t size;
|
||||
if (!recv_data(sockfd, &size, sizeof(size))) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
input.resize(size);
|
||||
} catch (const std::bad_alloc & e) {
|
||||
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
|
||||
return false;
|
||||
}
|
||||
return recv_data(sockfd, input.data(), size);
|
||||
}
|
||||
|
||||
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
|
||||
size_t pos = endpoint.find(':');
|
||||
if (pos == std::string::npos) {
|
||||
|
@ -252,28 +343,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
|
|||
|
||||
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
|
||||
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
|
||||
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
|
||||
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
|
||||
uint8_t cmd_byte = cmd;
|
||||
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
|
||||
return false;
|
||||
}
|
||||
uint64_t input_size = input.size();
|
||||
if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
|
||||
return false;
|
||||
}
|
||||
if (!send_data(sock->fd, input.data(), input.size())) {
|
||||
if (!send_data(sock->fd, input, input_size)) {
|
||||
return false;
|
||||
}
|
||||
uint64_t output_size;
|
||||
if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
|
||||
// TODO: currently the output_size is always known, do we need support for commands with variable output size?
|
||||
// even if we do, we can skip sending output_size from the server for commands with known output size
|
||||
uint64_t out_size;
|
||||
if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
|
||||
return false;
|
||||
}
|
||||
if (output_size == 0) {
|
||||
output.clear();
|
||||
return true;
|
||||
if (out_size != output_size) {
|
||||
return false;
|
||||
}
|
||||
output.resize(output_size);
|
||||
if (!recv_data(sock->fd, output.data(), output_size)) {
|
||||
if (!recv_data(sock->fd, output, output_size)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@ -326,14 +416,9 @@ static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffe
|
|||
|
||||
static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
// input serialization format: | remote_ptr (8 bytes) |
|
||||
std::vector<uint8_t> input(sizeof(uint64_t), 0);
|
||||
uint64_t remote_ptr = ctx->remote_ptr;
|
||||
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
|
||||
rpc_msg_free_buffer_req request = {ctx->remote_ptr};
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.empty());
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
|
@ -342,20 +427,13 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|||
if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
|
||||
return ctx->base_cache[buffer];
|
||||
}
|
||||
// input serialization format: | remote_ptr (8 bytes) |
|
||||
std::vector<uint8_t> input(sizeof(uint64_t), 0);
|
||||
uint64_t remote_ptr = ctx->remote_ptr;
|
||||
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
|
||||
rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
|
||||
rpc_msg_buffer_get_base_rsp response;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||
// output serialization format: | base_ptr (8 bytes) |
|
||||
uint64_t base_ptr;
|
||||
memcpy(&base_ptr, output.data(), sizeof(base_ptr));
|
||||
void * base = reinterpret_cast<void *>(base_ptr);
|
||||
ctx->base_cache[buffer] = base;
|
||||
return base;
|
||||
void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
|
||||
ctx->base_cache[buffer] = base_ptr;
|
||||
return base_ptr;
|
||||
}
|
||||
|
||||
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
||||
|
@ -405,26 +483,18 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
|
|||
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
|
||||
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
||||
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
|
||||
GGML_ASSERT(status);
|
||||
}
|
||||
|
||||
static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
// input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
|
||||
int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t);
|
||||
std::vector<uint8_t> input(input_size, 0);
|
||||
rpc_tensor rpc_tensor = serialize_tensor(tensor);
|
||||
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
|
||||
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
||||
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
|
||||
rpc_msg_get_tensor_req request;
|
||||
request.tensor = serialize_tensor(tensor);
|
||||
request.offset = offset;
|
||||
request.size = size;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == size);
|
||||
// output serialization format: | data (size bytes) |
|
||||
memcpy(data, output.data(), size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
|
@ -437,30 +507,19 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
|
|||
return false;
|
||||
}
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
// input serialization format: | rpc_tensor src | rpc_tensor dst |
|
||||
int input_size = 2*sizeof(rpc_tensor);
|
||||
std::vector<uint8_t> input(input_size, 0);
|
||||
rpc_tensor rpc_src = serialize_tensor(src);
|
||||
rpc_tensor rpc_dst = serialize_tensor(dst);
|
||||
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
|
||||
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
|
||||
rpc_msg_copy_tensor_req request;
|
||||
request.src = serialize_tensor(src);
|
||||
request.dst = serialize_tensor(dst);
|
||||
rpc_msg_copy_tensor_rsp response;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
// output serialization format: | result (1 byte) |
|
||||
GGML_ASSERT(output.size() == 1);
|
||||
return output[0];
|
||||
return response.result;
|
||||
}
|
||||
|
||||
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
// serialization format: | bufptr (8 bytes) | value (1 byte) |
|
||||
int input_size = sizeof(uint64_t) + sizeof(uint8_t);
|
||||
std::vector<uint8_t> input(input_size, 0);
|
||||
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
|
||||
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
|
||||
rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
|
||||
GGML_ASSERT(status);
|
||||
}
|
||||
|
||||
|
@ -484,25 +543,16 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
|
|||
|
||||
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
||||
// input serialization format: | size (8 bytes) |
|
||||
int input_size = sizeof(uint64_t);
|
||||
std::vector<uint8_t> input(input_size, 0);
|
||||
memcpy(input.data(), &size, sizeof(size));
|
||||
std::vector<uint8_t> output;
|
||||
rpc_msg_alloc_buffer_req request = {size};
|
||||
rpc_msg_alloc_buffer_rsp response;
|
||||
auto sock = get_socket(buft_ctx->endpoint);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
||||
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
||||
uint64_t remote_ptr;
|
||||
memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
|
||||
size_t remote_size;
|
||||
memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
|
||||
if (remote_ptr != 0) {
|
||||
if (response.remote_ptr != 0) {
|
||||
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
||||
ggml_backend_rpc_buffer_interface,
|
||||
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
|
||||
remote_size);
|
||||
new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
|
||||
response.remote_size);
|
||||
return buffer;
|
||||
} else {
|
||||
return nullptr;
|
||||
|
@ -510,16 +560,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
|
|||
}
|
||||
|
||||
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
|
||||
// input serialization format: | 0 bytes |
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
|
||||
rpc_msg_get_alignment_rsp response;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||
// output serialization format: | alignment (8 bytes) |
|
||||
uint64_t alignment;
|
||||
memcpy(&alignment, output.data(), sizeof(alignment));
|
||||
return alignment;
|
||||
return response.alignment;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
|
@ -528,16 +572,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
|
|||
}
|
||||
|
||||
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
|
||||
// input serialization format: | 0 bytes |
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
|
||||
rpc_msg_get_max_size_rsp response;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||
// output serialization format: | max_size (8 bytes) |
|
||||
uint64_t max_size;
|
||||
memcpy(&max_size, output.data(), sizeof(max_size));
|
||||
return max_size;
|
||||
return response.max_size;
|
||||
}
|
||||
|
||||
static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
|
@ -622,12 +660,11 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
|
|||
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
||||
std::vector<uint8_t> input;
|
||||
serialize_graph(cgraph, input);
|
||||
std::vector<uint8_t> output;
|
||||
rpc_msg_graph_compute_rsp response;
|
||||
auto sock = get_socket(rpc_ctx->endpoint);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == 1);
|
||||
return (enum ggml_status)output[0];
|
||||
return (enum ggml_status)response.result;
|
||||
}
|
||||
|
||||
static ggml_backend_i ggml_backend_rpc_interface = {
|
||||
|
@ -702,19 +739,11 @@ GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
|||
}
|
||||
|
||||
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
|
||||
// input serialization format: | 0 bytes |
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
|
||||
rpc_msg_get_device_memory_rsp response;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
||||
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
||||
uint64_t free_mem;
|
||||
memcpy(&free_mem, output.data(), sizeof(free_mem));
|
||||
uint64_t total_mem;
|
||||
memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
|
||||
*free = free_mem;
|
||||
*total = total_mem;
|
||||
*free = response.free_mem;
|
||||
*total = response.total_mem;
|
||||
}
|
||||
|
||||
GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
||||
|
@ -734,16 +763,16 @@ public:
|
|||
rpc_server(ggml_backend_t backend) : backend(backend) {}
|
||||
~rpc_server();
|
||||
|
||||
bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
||||
void get_alignment(std::vector<uint8_t> & output);
|
||||
void get_max_size(std::vector<uint8_t> & output);
|
||||
bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
||||
bool free_buffer(const std::vector<uint8_t> & input);
|
||||
bool buffer_clear(const std::vector<uint8_t> & input);
|
||||
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
||||
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
||||
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
||||
bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
|
||||
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
||||
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
||||
bool set_tensor(const std::vector<uint8_t> & input);
|
||||
bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
||||
bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
||||
bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
||||
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
||||
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
||||
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
||||
|
||||
private:
|
||||
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
|
||||
|
@ -757,80 +786,50 @@ private:
|
|||
std::unordered_set<ggml_backend_buffer_t> buffers;
|
||||
};
|
||||
|
||||
bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
|
||||
// input serialization format: | size (8 bytes) |
|
||||
if (input.size() != sizeof(uint64_t)) {
|
||||
return false;
|
||||
}
|
||||
uint64_t size;
|
||||
memcpy(&size, input.data(), sizeof(size));
|
||||
void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
||||
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
|
||||
uint64_t remote_ptr = 0;
|
||||
uint64_t remote_size = 0;
|
||||
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
|
||||
response.remote_ptr = 0;
|
||||
response.remote_size = 0;
|
||||
if (buffer != nullptr) {
|
||||
remote_ptr = reinterpret_cast<uint64_t>(buffer);
|
||||
remote_size = buffer->size;
|
||||
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
|
||||
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
|
||||
response.remote_size = buffer->size;
|
||||
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
|
||||
buffers.insert(buffer);
|
||||
} else {
|
||||
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
|
||||
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
|
||||
}
|
||||
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
||||
output.resize(2*sizeof(uint64_t), 0);
|
||||
memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
|
||||
memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
|
||||
return true;
|
||||
}
|
||||
|
||||
void rpc_server::get_alignment(std::vector<uint8_t> & output) {
|
||||
void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
||||
size_t alignment = ggml_backend_buft_get_alignment(buft);
|
||||
GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
|
||||
// output serialization format: | alignment (8 bytes) |
|
||||
output.resize(sizeof(uint64_t), 0);
|
||||
memcpy(output.data(), &alignment, sizeof(alignment));
|
||||
response.alignment = alignment;
|
||||
}
|
||||
|
||||
void rpc_server::get_max_size(std::vector<uint8_t> & output) {
|
||||
void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
||||
size_t max_size = ggml_backend_buft_get_max_size(buft);
|
||||
GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
|
||||
// output serialization format: | max_size (8 bytes) |
|
||||
output.resize(sizeof(uint64_t), 0);
|
||||
memcpy(output.data(), &max_size, sizeof(max_size));
|
||||
response.max_size = max_size;
|
||||
}
|
||||
|
||||
bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
|
||||
// input serialization format: | remote_ptr (8 bytes) |
|
||||
if (input.size() != sizeof(uint64_t)) {
|
||||
return false;
|
||||
}
|
||||
uint64_t remote_ptr;
|
||||
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
|
||||
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
|
||||
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
|
||||
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
|
||||
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
||||
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
||||
if (buffers.find(buffer) == buffers.end()) {
|
||||
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
||||
return false;
|
||||
}
|
||||
void * base = ggml_backend_buffer_get_base(buffer);
|
||||
// output serialization format: | base_ptr (8 bytes) |
|
||||
uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
|
||||
output.resize(sizeof(uint64_t), 0);
|
||||
memcpy(output.data(), &base_ptr, sizeof(base_ptr));
|
||||
response.base_ptr = reinterpret_cast<uint64_t>(base);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
|
||||
// input serialization format: | remote_ptr (8 bytes) |
|
||||
if (input.size() != sizeof(uint64_t)) {
|
||||
return false;
|
||||
}
|
||||
uint64_t remote_ptr;
|
||||
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
|
||||
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
|
||||
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
|
||||
bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
|
||||
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
|
||||
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
||||
if (buffers.find(buffer) == buffers.end()) {
|
||||
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
||||
return false;
|
||||
|
@ -840,22 +839,14 @@ bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) {
|
||||
// input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
|
||||
if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) {
|
||||
return false;
|
||||
}
|
||||
uint64_t remote_ptr;
|
||||
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
|
||||
uint8_t value;
|
||||
memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
|
||||
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
|
||||
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
|
||||
bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
|
||||
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
|
||||
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
|
||||
if (buffers.find(buffer) == buffers.end()) {
|
||||
GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
|
||||
return false;
|
||||
}
|
||||
ggml_backend_buffer_clear(buffer, value);
|
||||
ggml_backend_buffer_clear(buffer, request.value);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -930,74 +921,55 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
|
||||
// serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
|
||||
if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
|
||||
return false;
|
||||
}
|
||||
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
|
||||
uint64_t offset;
|
||||
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
|
||||
uint64_t size;
|
||||
memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));
|
||||
|
||||
bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
|
||||
struct ggml_init_params params {
|
||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
|
||||
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
||||
if (tensor == nullptr) {
|
||||
GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
|
||||
ggml_free(ctx);
|
||||
return false;
|
||||
}
|
||||
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
|
||||
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
|
||||
|
||||
// sanitize tensor->data
|
||||
{
|
||||
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
||||
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
||||
|
||||
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
|
||||
if (request.tensor.data + request.offset < p0 ||
|
||||
request.tensor.data + request.offset >= p1 ||
|
||||
request.size > (p1 - request.tensor.data - request.offset)) {
|
||||
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// output serialization format: | data (size bytes) |
|
||||
output.resize(size, 0);
|
||||
ggml_backend_tensor_get(tensor, output.data(), offset, size);
|
||||
response.resize(request.size, 0);
|
||||
ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
|
||||
// serialization format: | rpc_tensor src | rpc_tensor dst |
|
||||
if (input.size() != 2*sizeof(rpc_tensor)) {
|
||||
return false;
|
||||
}
|
||||
const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
|
||||
const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
|
||||
|
||||
bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
|
||||
struct ggml_init_params params {
|
||||
/*.mem_size =*/ 2*ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
|
||||
ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
|
||||
ggml_tensor * src = deserialize_tensor(ctx, &request.src);
|
||||
ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
|
||||
if (src == nullptr || dst == nullptr) {
|
||||
GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
|
||||
ggml_free(ctx);
|
||||
return false;
|
||||
}
|
||||
GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
|
||||
bool result = ggml_backend_buffer_copy_tensor(src, dst);
|
||||
// output serialization format: | result (1 byte) |
|
||||
output.resize(1, 0);
|
||||
output[0] = result;
|
||||
response.result = ggml_backend_buffer_copy_tensor(src, dst);
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
}
|
||||
|
@ -1026,7 +998,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
|
|||
return result;
|
||||
}
|
||||
|
||||
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
|
||||
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
|
||||
// serialization format:
|
||||
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
||||
if (input.size() < sizeof(uint32_t)) {
|
||||
|
@ -1066,9 +1038,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
|
|||
graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
|
||||
}
|
||||
ggml_status status = ggml_backend_graph_compute(backend, graph);
|
||||
// output serialization format: | status (1 byte) |
|
||||
output.resize(1, 0);
|
||||
output[0] = status;
|
||||
response.result = status;
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
}
|
||||
|
@ -1091,85 +1061,153 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
|||
fprintf(stderr, "Unknown command: %d\n", cmd);
|
||||
break;
|
||||
}
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
uint64_t input_size;
|
||||
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
|
||||
break;
|
||||
}
|
||||
try {
|
||||
input.resize(input_size);
|
||||
} catch (const std::bad_alloc & e) {
|
||||
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
|
||||
break;
|
||||
}
|
||||
if (!recv_data(sockfd, input.data(), input_size)) {
|
||||
break;
|
||||
}
|
||||
bool ok = true;
|
||||
switch (cmd) {
|
||||
case RPC_CMD_ALLOC_BUFFER: {
|
||||
ok = server.alloc_buffer(input, output);
|
||||
rpc_msg_alloc_buffer_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_alloc_buffer_rsp response;
|
||||
server.alloc_buffer(request, response);
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_GET_ALIGNMENT: {
|
||||
server.get_alignment(output);
|
||||
if (!recv_msg(sockfd, nullptr, 0)) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_get_alignment_rsp response;
|
||||
server.get_alignment(response);
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_GET_MAX_SIZE: {
|
||||
server.get_max_size(output);
|
||||
if (!recv_msg(sockfd, nullptr, 0)) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_get_max_size_rsp response;
|
||||
server.get_max_size(response);
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_BUFFER_GET_BASE: {
|
||||
ok = server.buffer_get_base(input, output);
|
||||
rpc_msg_buffer_get_base_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_buffer_get_base_rsp response;
|
||||
if (!server.buffer_get_base(request, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_FREE_BUFFER: {
|
||||
ok = server.free_buffer(input);
|
||||
rpc_msg_free_buffer_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
if (!server.free_buffer(request)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, nullptr, 0)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_BUFFER_CLEAR: {
|
||||
ok = server.buffer_clear(input);
|
||||
rpc_msg_buffer_clear_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
if (!server.buffer_clear(request)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, nullptr, 0)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_SET_TENSOR: {
|
||||
ok = server.set_tensor(input);
|
||||
std::vector<uint8_t> input;
|
||||
if (!recv_msg(sockfd, input)) {
|
||||
return;
|
||||
}
|
||||
if (!server.set_tensor(input)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, nullptr, 0)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_GET_TENSOR: {
|
||||
ok = server.get_tensor(input, output);
|
||||
rpc_msg_get_tensor_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
std::vector<uint8_t> response;
|
||||
if (!server.get_tensor(request, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, response.data(), response.size())) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_COPY_TENSOR: {
|
||||
ok = server.copy_tensor(input, output);
|
||||
rpc_msg_copy_tensor_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_copy_tensor_rsp response;
|
||||
if (!server.copy_tensor(request, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_GRAPH_COMPUTE: {
|
||||
ok = server.graph_compute(input, output);
|
||||
std::vector<uint8_t> input;
|
||||
if (!recv_msg(sockfd, input)) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_graph_compute_rsp response;
|
||||
if (!server.graph_compute(input, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_GET_DEVICE_MEMORY: {
|
||||
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
||||
output.resize(2*sizeof(uint64_t), 0);
|
||||
memcpy(output.data(), &free_mem, sizeof(free_mem));
|
||||
memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
|
||||
if (!recv_msg(sockfd, nullptr, 0)) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_get_device_memory_rsp response;
|
||||
response.free_mem = free_mem;
|
||||
response.total_mem = total_mem;
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
fprintf(stderr, "Unknown command: %d\n", cmd);
|
||||
ok = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (!ok) {
|
||||
break;
|
||||
}
|
||||
uint64_t output_size = output.size();
|
||||
if (!send_data(sockfd, &output_size, sizeof(output_size))) {
|
||||
break;
|
||||
}
|
||||
if (!send_data(sockfd, output.data(), output_size)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,6 @@
|
|||
#include "mmvq.hpp"
|
||||
#include "vecdotq.hpp"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
|
||||
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
|
||||
|
@ -13,7 +13,8 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
|
|||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
|
||||
assert(blocks_per_warp>0);
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
@ -37,7 +38,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -61,7 +62,8 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
|
|||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
|
||||
assert(blocks_per_warp>0);
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
@ -85,7 +87,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -109,8 +111,8 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
|
|||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
|
||||
assert(blocks_per_warp>0);
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
|
@ -133,7 +135,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -157,8 +159,8 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
|
|||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
|
||||
assert(blocks_per_warp>0);
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
|
@ -181,7 +183,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -205,8 +207,8 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
|
|||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
|
||||
assert(blocks_per_warp>0);
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
|
@ -229,7 +231,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -253,8 +255,8 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
|
|||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
|
||||
assert(blocks_per_warp>0);
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
|
@ -277,7 +279,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -301,8 +303,8 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
|
|||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
|
||||
assert(blocks_per_warp>0);
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
|
@ -325,7 +327,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -349,8 +351,8 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
|
|||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
|
||||
assert(blocks_per_warp>0);
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
|
@ -373,7 +375,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -397,8 +399,8 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
|
|||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
|
||||
assert(blocks_per_warp>0);
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
|
@ -421,7 +423,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -446,8 +448,8 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
|
|||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
|
||||
assert(blocks_per_warp>0);
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
|
@ -470,7 +472,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
|
|||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
@ -487,7 +489,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK4_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -495,7 +497,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
|
||||
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
|
@ -511,7 +513,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK4_1 == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -519,7 +521,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
|
||||
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
|
@ -535,7 +537,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK5_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -543,7 +545,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
|
||||
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
|
@ -559,7 +561,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK5_1 == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -567,7 +569,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
|
||||
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
|
@ -583,7 +585,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK8_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -591,7 +593,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
|
||||
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
|
@ -607,7 +609,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -615,7 +617,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
|
||||
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
|
@ -631,7 +633,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -639,7 +641,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
|
||||
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
|
@ -655,7 +657,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -663,7 +665,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
|
||||
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
|
@ -679,7 +681,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -687,7 +689,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
|
||||
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
|
@ -703,7 +705,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -711,7 +713,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
|
||||
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
|
@ -728,13 +730,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -749,7 +751,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -759,7 +761,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -774,7 +776,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -784,7 +786,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -799,7 +801,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -809,7 +811,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -824,7 +826,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -833,7 +835,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -848,7 +850,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
@ -858,7 +860,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
|
|||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -873,13 +875,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -894,14 +896,14 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK4_NL == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
@ -916,14 +918,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
|
|||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
|
|
|
@ -324,8 +324,9 @@ struct ggml_logger_state {
|
|||
static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
|
||||
|
||||
static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
|
||||
if (format == NULL)
|
||||
if (format == NULL) {
|
||||
return;
|
||||
}
|
||||
va_list args_copy;
|
||||
va_copy(args_copy, args);
|
||||
char buffer[128];
|
||||
|
@ -3483,7 +3484,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
|
|||
|
||||
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
|
||||
size_t nbytes;
|
||||
size_t blck_size = ggml_blck_size(tensor->type);
|
||||
const size_t blck_size = ggml_blck_size(tensor->type);
|
||||
if (blck_size == 1) {
|
||||
nbytes = ggml_type_size(tensor->type);
|
||||
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
||||
|
@ -3872,10 +3873,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||
},
|
||||
};
|
||||
|
||||
for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
|
||||
g_state.contexts[i].used = false;
|
||||
}
|
||||
|
||||
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
|
||||
|
||||
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
||||
|
@ -15778,6 +15775,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
|
||||
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
|
||||
|
||||
GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
|
||||
GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
|
||||
|
||||
// loop over n_batch and n_head
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
|
@ -23356,6 +23356,14 @@ int ggml_cpu_has_avx512_bf16(void) {
|
|||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_amx_int8(void) {
|
||||
#if defined(__AMX_INT8__)
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_fma(void) {
|
||||
#if defined(__FMA__)
|
||||
return 1;
|
||||
|
|
|
@ -1489,7 +1489,7 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num
|
|||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||
llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr,};
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
fprintf(stderr, "\n%s : failed to eval image\n", __func__);
|
||||
return false;
|
||||
|
@ -2004,7 +2004,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
//determine mem per token
|
||||
std::vector<int> tmp = {1, 2, 3, 4};
|
||||
llama_kv_cache_clear(llama_ctx_v4);
|
||||
auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
|
||||
auto er = llama_decode(llama_ctx_v4, llama_batch_get_one(tmp.data(), tmp.size()));
|
||||
if(er!=0)
|
||||
{
|
||||
printf("\nLLAMA EVAL returned nonzero: %d\n",er);
|
||||
|
@ -3061,7 +3061,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
}
|
||||
else if(file_format == FileFormat::GGUF_GENERIC)
|
||||
{
|
||||
evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize, n_past, 0))==0);
|
||||
evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize))==0);
|
||||
}
|
||||
else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||
{
|
||||
|
@ -3432,7 +3432,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
if(i>0 && sepsize>0)
|
||||
{
|
||||
//add a separator between each image
|
||||
auto evr = llama_decode(llama_ctx_v4, llama_batch_get_one(llava_sep.data(), sepsize, n_past, 0));
|
||||
auto evr = llama_decode(llama_ctx_v4, llama_batch_get_one(llava_sep.data(), sepsize));
|
||||
if(evr!=0)
|
||||
{
|
||||
printf("\nError when appending llava separator: %d\n",evr);
|
||||
|
|
|
@ -217,6 +217,7 @@ extern "C" {
|
|||
|
||||
typedef struct llama_token_data_array {
|
||||
// TODO: consider SoA
|
||||
// NOTE: this pointer can be modified by the samplers
|
||||
llama_token_data * data;
|
||||
size_t size;
|
||||
int64_t selected; // this is the index in the data array (i.e. not the token id)
|
||||
|
@ -232,8 +233,11 @@ extern "C" {
|
|||
// - token : the token ids of the input (used when embd is NULL)
|
||||
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
||||
// - pos : the positions of the respective token in the sequence
|
||||
// (if set to NULL, the token position will be tracked automatically by llama_decode)
|
||||
// - seq_id : the sequence to which the respective token belongs
|
||||
// (if set to NULL, the sequence ID will be assumed to be 0)
|
||||
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
||||
// (if set to NULL, only the logits for last token will be returned)
|
||||
//
|
||||
typedef struct llama_batch {
|
||||
int32_t n_tokens;
|
||||
|
@ -244,15 +248,6 @@ extern "C" {
|
|||
int32_t * n_seq_id;
|
||||
llama_seq_id ** seq_id;
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
|
||||
// NOTE: helpers for smooth API transition - can be deprecated in the future
|
||||
// for future-proof code, use the above fields instead and ignore everything below
|
||||
//
|
||||
// pos[i] = all_pos_0 + i*all_pos_1
|
||||
//
|
||||
llama_pos all_pos_0; // used if pos == NULL
|
||||
llama_pos all_pos_1; // used if pos == NULL
|
||||
llama_seq_id all_seq_id; // used if seq_id == NULL
|
||||
} llama_batch;
|
||||
|
||||
enum llama_model_kv_override_type {
|
||||
|
@ -778,15 +773,15 @@ extern "C" {
|
|||
// Decoding
|
||||
//
|
||||
|
||||
// Return batch for single sequence of tokens starting at pos_0
|
||||
// Return batch for single sequence of tokens
|
||||
// The sequence ID will be fixed to 0
|
||||
// The position of the tokens will be tracked automatically by llama_decode
|
||||
//
|
||||
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
|
||||
//
|
||||
LLAMA_API struct llama_batch llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
llama_pos pos_0,
|
||||
llama_seq_id seq_id);
|
||||
int32_t n_tokens);
|
||||
|
||||
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
||||
// Each token can be assigned up to n_seq_max sequence ids
|
||||
|
@ -955,12 +950,6 @@ extern "C" {
|
|||
int32_t lstrip,
|
||||
bool special);
|
||||
|
||||
// check if token0 is contained as a prefix in token1
|
||||
LLAMA_API bool llama_token_is_prefix(
|
||||
const struct llama_model * model,
|
||||
llama_token token0,
|
||||
llama_token token1);
|
||||
|
||||
/// @details Convert the provided tokens into text (inverse of llama_tokenize()).
|
||||
/// @param text The char pointer must be large enough to hold the resulting text.
|
||||
/// @return Returns the number of chars/bytes on success, no more than text_len_max.
|
||||
|
@ -1088,7 +1077,8 @@ extern "C" {
|
|||
|
||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
|
||||
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
|
||||
"will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
|
||||
|
@ -1104,6 +1094,8 @@ extern "C" {
|
|||
|
||||
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
|
||||
|
||||
/// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
|
||||
|
||||
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
|
||||
|
|
|
@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
|
|||
}
|
||||
*/
|
||||
|
||||
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
|
||||
if (temp <= 0.0f) {
|
||||
// find the token with the highest logit and set the rest to -inf
|
||||
size_t max_i = 0;
|
||||
float max_l = cur_p->data[0].logit;
|
||||
|
||||
for (size_t i = 1; i < cur_p->size; ++i) {
|
||||
if (cur_p->data[i ].logit > max_l) {
|
||||
cur_p->data[max_i].logit = -INFINITY;
|
||||
max_i = i;
|
||||
max_l = cur_p->data[i].logit;
|
||||
} else {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].logit /= temp;
|
||||
}
|
||||
}
|
||||
|
||||
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
|
||||
GGML_ASSERT(cur_p->size > 0);
|
||||
|
||||
|
@ -427,6 +451,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
|
|||
|
||||
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
||||
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
||||
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
||||
}
|
||||
|
||||
|
@ -912,9 +939,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
|
|||
|
||||
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].logit /= ctx->temp;
|
||||
}
|
||||
|
||||
llama_sampler_temp_impl(cur_p, ctx->temp);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
|
||||
|
@ -961,6 +987,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|||
if (ctx->delta > 0) {
|
||||
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
|
||||
const float max_temp = ctx->temp + ctx->delta;
|
||||
|
||||
float exponent_val = ctx->exponent;
|
||||
|
||||
// no need to do anything if there is only one (or zero) candidates
|
||||
|
@ -998,9 +1025,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|||
#endif
|
||||
|
||||
// Apply the dynamically calculated temperature scaling
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].logit /= dyn_temp;
|
||||
}
|
||||
llama_sampler_temp_impl(cur_p, dyn_temp);
|
||||
|
||||
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
||||
const double max_l_double = cur_p->data[0].logit;
|
||||
|
@ -1024,9 +1049,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
|
|||
}
|
||||
#endif
|
||||
} else {
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].logit /= ctx->temp;
|
||||
}
|
||||
llama_sampler_temp_impl(cur_p, ctx->temp);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1745,6 +1768,9 @@ struct llama_sampler * llama_sampler_init_logit_bias(
|
|||
|
||||
struct llama_sampler_infill {
|
||||
const struct llama_vocab * vocab;
|
||||
|
||||
std::vector<char> buf0;
|
||||
std::vector<char> buf1;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
|
||||
|
@ -1810,27 +1836,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
|||
size_t n_combined = 0; GGML_UNUSED(n_combined);
|
||||
|
||||
// combine tokens with common prefix
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
for (size_t j = 0; j < cur_p->size; ++j) {
|
||||
if (cur_p->data[i].logit == -INFINITY) {
|
||||
for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
|
||||
for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
|
||||
if (cur_p->data[i0].logit == -INFINITY) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (i == j || cur_p->data[j].logit == -INFINITY) {
|
||||
if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
|
||||
if (cur_p->data[i].p > cur_p->data[j].p) {
|
||||
cur_p->data[i].p += cur_p->data[j].p;
|
||||
cur_p->data[j].logit = -INFINITY;
|
||||
cur_p->data[j].p = 0.0f;
|
||||
} else {
|
||||
cur_p->data[j].p += cur_p->data[i].p;
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
cur_p->data[i].p = 0.0f;
|
||||
int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
||||
if (len0 < 0) {
|
||||
ctx->buf0.resize(len0);
|
||||
len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
||||
assert(len0 > 0);
|
||||
}
|
||||
|
||||
int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
||||
if (len1 < 0) {
|
||||
ctx->buf1.resize(len1);
|
||||
len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
||||
assert(len1 > 0);
|
||||
}
|
||||
|
||||
// token i0 is a prefix of token i1
|
||||
if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
|
||||
int dst = i0;
|
||||
int src = i1;
|
||||
|
||||
// merge into the token with higher probability
|
||||
if (cur_p->data[i1].p > cur_p->data[i0].p) {
|
||||
std::swap(dst, src);
|
||||
}
|
||||
|
||||
cur_p->data[dst].p += cur_p->data[src].p;
|
||||
cur_p->data[src].logit = -INFINITY;
|
||||
cur_p->data[src].p = 0.0f;
|
||||
|
||||
n_combined++;
|
||||
}
|
||||
}
|
||||
|
@ -1936,6 +1979,8 @@ struct llama_sampler * llama_sampler_init_infill_impl(
|
|||
/* .iface = */ &llama_sampler_infill_i,
|
||||
/* .ctx = */ new llama_sampler_infill {
|
||||
/* .vocab = */ &vocab,
|
||||
/* .buf0 = */ std::vector<char>(512),
|
||||
/* .buf1 = */ std::vector<char>(512),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
@ -2118,23 +2118,6 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
|
|||
return 0;
|
||||
}
|
||||
|
||||
bool llama_token_is_prefix_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
llama_token token0,
|
||||
llama_token token1) {
|
||||
char text_buf_0[128];
|
||||
char text_buf_1[128];
|
||||
|
||||
const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
|
||||
const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
|
||||
|
||||
if (len0 <= 0 || len1 <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
|
||||
}
|
||||
|
||||
int32_t llama_detokenize_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
const llama_token * tokens,
|
||||
|
|
488
src/llama.cpp
488
src/llama.cpp
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue