Add wstring conversion for MPT token loading

This commit is contained in:
Concedo 2023-06-24 11:43:42 +08:00
parent 6d718525c4
commit 0485fa65a2
5 changed files with 30 additions and 7 deletions

View file

@ -86,6 +86,16 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
fin.read((char *) buf.data(), len);
word.assign(buf.data(), len);
// Convert token from utf-8
std::wstring word_multibytes = convert_to_wstring(word);
if(word_multibytes!=L"")
{
word.resize(word_multibytes.size());
for (int w = 0; w < word_multibytes.size(); w++) {
word[w] = uint8_t(word_multibytes[w]);
}
}
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
@ -123,8 +133,8 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
ctx_size += n_layer * (4 * n_embd * n_embd * ggml_type_sizef(wtype)); // mlp_mlp_up_weight
ctx_size += n_layer * (n_embd * n_embd * 4 * ggml_type_sizef(wtype)); // mlp_mlp_down_weight
ctx_size += (n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16)); // memory_k
ctx_size += (n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16)); // memory_v
ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_k
ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_v
ctx_size += (6 + 6 * n_layer) * 512; // object overhead