mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-14 10:59:41 +00:00
model : add PLaMo-2 support (#14560)
* Add PLaMo-2 model using hybrid memory module * Fix z shape * Add cmath to include from llama-vocab.h * Explicitly dequantize normalization weights before RoPE apply * Revert unnecessary cast because the problem can be solved by excluding attn_k, attn_q when quantizing * Use ATTN_K/Q_NORM for k,q weights to prevent quantization * Remove SSM_BCDT that is not used from anywhere * Do not duplicate embedding weights for output.weight * Fix tokenizer encoding problem for multibyte strings * Apply suggestion from @CISC Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Use LLM_FFN_SWIGLU instead of splitting ffn_gate and ffn_up * Remove unnecessary part for Grouped Query Attention * Fix how to load special token id to gguf * Remove unused tensor mapping * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Remove llama_vocab_plamo2 class and replace it with llm_tokenizer_plamo2_session to follow the other tokenizer implementations * Update src/llama-vocab.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Fix plamo2 tokenizer session to prevent multiple calls of build() --------- Co-authored-by: Francis Couture-Harpin <git@compilade.net> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
cbc68be51d
commit
68e37a61a7
8 changed files with 1048 additions and 44 deletions
|
@ -11,6 +11,7 @@
|
|||
#include <cassert>
|
||||
#include <cctype>
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
#include <forward_list>
|
||||
|
@ -1196,6 +1197,284 @@ private:
|
|||
const llm_tokenizer_rwkv & tokenizer;
|
||||
};
|
||||
|
||||
struct llm_tokenizer_plamo2 : llm_tokenizer {
|
||||
llm_tokenizer_plamo2(const llama_vocab & vocab) {
|
||||
build(vocab);
|
||||
}
|
||||
|
||||
void build(const llama_vocab & vocab) {
|
||||
// Reset internal structures
|
||||
tokens_.clear();
|
||||
bytes_.assign(256, 0);
|
||||
to_suffix_id_.clear();
|
||||
table_.clear();
|
||||
|
||||
// Build token list and byte mapping
|
||||
std::unordered_map<std::string, float> suffix_to_score;
|
||||
std::unordered_map<std::string, llama_token> token_to_id;
|
||||
|
||||
for (size_t token_id = 0; token_id < vocab.n_tokens(); ++token_id) {
|
||||
const auto & entry = vocab.get_token_data(token_id);
|
||||
tokens_.push_back(entry.text);
|
||||
token_to_id[entry.text] = static_cast<llama_token>(token_id);
|
||||
|
||||
// Handle byte tokens
|
||||
if (vocab.is_byte(token_id)) {
|
||||
if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') {
|
||||
std::string hex_str = entry.text.substr(3, 2);
|
||||
int byte_val = std::stoi(hex_str, nullptr, 16);
|
||||
bytes_[byte_val] = static_cast<llama_token>(token_id);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Add token and all its suffixes to suffix_to_score
|
||||
suffix_to_score[entry.text] = entry.score;
|
||||
|
||||
// Extract suffixes character by character (UTF-8 aware)
|
||||
std::vector<uint32_t> cpts = unicode_cpts_from_utf8(entry.text);
|
||||
for (size_t i = 1; i < cpts.size(); ++i) {
|
||||
std::string suffix;
|
||||
for (size_t j = i; j < cpts.size(); ++j) {
|
||||
suffix += unicode_cpt_to_utf8(cpts[j]);
|
||||
}
|
||||
if (suffix_to_score.find(suffix) == suffix_to_score.end()) {
|
||||
suffix_to_score[suffix] = std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all byte tokens are set
|
||||
for (int i = 0; i < 256; ++i) {
|
||||
if (bytes_[i] == 0) {
|
||||
throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set");
|
||||
}
|
||||
}
|
||||
|
||||
// Build suffix list in lexicographical order of reversed strings
|
||||
std::vector<std::string> suffixes;
|
||||
for (const auto & pair : suffix_to_score) {
|
||||
suffixes.push_back(pair.first);
|
||||
}
|
||||
suffixes.push_back(""); // Empty suffix
|
||||
|
||||
std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) {
|
||||
std::string rev_a(a.rbegin(), a.rend());
|
||||
std::string rev_b(b.rbegin(), b.rend());
|
||||
return rev_a < rev_b;
|
||||
});
|
||||
|
||||
// Build suffix_to_id and to_suffix_id_
|
||||
std::unordered_map<std::string, int32_t> suffix_to_id;
|
||||
int32_t num_pieces = 0;
|
||||
|
||||
for (const auto & suffix : suffixes) {
|
||||
suffix_to_id[suffix] = num_pieces;
|
||||
if (!suffix.empty()) {
|
||||
std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
|
||||
|
||||
std::string remaining;
|
||||
for (size_t i = 1; i < cpts.size(); ++i) {
|
||||
remaining += unicode_cpt_to_utf8(cpts[i]);
|
||||
}
|
||||
|
||||
int64_t piece_code = (static_cast<int64_t>(cpts[0]) << 32) | suffix_to_id[remaining];
|
||||
to_suffix_id_[piece_code] = num_pieces;
|
||||
|
||||
// Count number of pieces for this suffix
|
||||
int32_t pieces_for_suffix = 1; // sentinel row
|
||||
for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
|
||||
std::string piece;
|
||||
for (int32_t i = 0; i < piece_length; ++i) {
|
||||
piece += unicode_cpt_to_utf8(cpts[i]);
|
||||
}
|
||||
if (suffix_to_score.find(piece) != suffix_to_score.end()) {
|
||||
pieces_for_suffix++;
|
||||
}
|
||||
}
|
||||
num_pieces += pieces_for_suffix;
|
||||
} else {
|
||||
num_pieces++; // Empty suffix contributes one piece (sentinel row)
|
||||
}
|
||||
}
|
||||
|
||||
// Build flattened table
|
||||
table_.resize(num_pieces, std::vector<int32_t>(4, 0));
|
||||
int32_t table_idx = 0;
|
||||
|
||||
for (const auto & suffix : suffixes) {
|
||||
// Add all prefixes of the suffix to the table (in decreasing order of length)
|
||||
std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
|
||||
for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
|
||||
std::string piece;
|
||||
for (int32_t i = 0; i < piece_length; ++i) {
|
||||
piece += unicode_cpt_to_utf8(cpts[i]);
|
||||
}
|
||||
|
||||
auto score_it = suffix_to_score.find(piece);
|
||||
if (score_it == suffix_to_score.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
table_[table_idx][TABLE_PIECE_LENGTH] = piece_length;
|
||||
auto token_it = token_to_id.find(piece);
|
||||
table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1;
|
||||
|
||||
float score = score_it->second;
|
||||
table_[table_idx][TABLE_SCORE] = std::isfinite(score) ?
|
||||
static_cast<int32_t>(std::round(score * 1e4)) : INVALID_SCORE;
|
||||
table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece];
|
||||
|
||||
table_idx++;
|
||||
}
|
||||
|
||||
// Add sentinel row
|
||||
table_[table_idx][TABLE_PIECE_LENGTH] = 1;
|
||||
table_[table_idx][TABLE_TOKEN_ID] = -1;
|
||||
table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE;
|
||||
table_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<llama_token> encode(const std::string & text) const {
|
||||
std::vector<uint32_t> unicode_data = unicode_cpts_from_utf8(text);
|
||||
// Skip the first code point if it is a BOM (Byte Order Mark)
|
||||
if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) {
|
||||
unicode_data.erase(unicode_data.begin());
|
||||
}
|
||||
|
||||
if (unicode_data.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const size_t data_len = unicode_data.size();
|
||||
|
||||
// Initialize scores array (dynamic programming)
|
||||
std::vector<int64_t> scores(data_len + 1, static_cast<int64_t>(1) << 60);
|
||||
scores[data_len] = 0;
|
||||
|
||||
// Path array to track best tokenization
|
||||
std::vector<std::vector<int32_t>> path(data_len + 1, std::vector<int32_t>(3, 0));
|
||||
|
||||
int32_t suffix_id = 0;
|
||||
|
||||
// Process from end to beginning
|
||||
for (int i = static_cast<int>(data_len) - 1; i >= 0; --i) {
|
||||
uint32_t c = unicode_data[i];
|
||||
|
||||
// Find next suffix ID
|
||||
for (size_t p = suffix_id; p < table_.size(); ++p) {
|
||||
int64_t piece_code = (static_cast<int64_t>(c) << 32) | table_[p][TABLE_PIECE_ID];
|
||||
auto it = to_suffix_id_.find(piece_code);
|
||||
suffix_id = (it != to_suffix_id_.end()) ? it->second : 0;
|
||||
|
||||
if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Update best path
|
||||
for (size_t p = suffix_id; p < table_.size(); ++p) {
|
||||
int32_t score = table_[p][TABLE_SCORE];
|
||||
if (score > INVALID_SCORE) {
|
||||
int32_t piece_length = table_[p][TABLE_PIECE_LENGTH];
|
||||
int64_t s = scores[i + piece_length] - score;
|
||||
|
||||
if (s < scores[i]) {
|
||||
scores[i] = s;
|
||||
path[i][PATH_TOKEN_LENGTH] = piece_length;
|
||||
path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID];
|
||||
path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1;
|
||||
|
||||
if (score == UNKNOWN_SCORE) {
|
||||
// Add UTF-8 byte count
|
||||
path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (score == UNKNOWN_SCORE) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode the best path
|
||||
std::vector<llama_token> token_ids;
|
||||
token_ids.reserve(path[0][PATH_NUM_TOKENS]);
|
||||
|
||||
int pos = 0;
|
||||
while (pos < static_cast<int>(data_len)) {
|
||||
if (path[pos][PATH_TOKEN_ID] >= 0) {
|
||||
token_ids.push_back(path[pos][PATH_TOKEN_ID]);
|
||||
} else {
|
||||
// Fall back to byte tokens
|
||||
uint32_t c = unicode_data[pos];
|
||||
int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
|
||||
|
||||
for (int i = 0; i < s; ++i) {
|
||||
uint8_t b;
|
||||
if (s == 1) {
|
||||
b = c;
|
||||
} else {
|
||||
if (i == 0) {
|
||||
b = (0xF00 >> s) & 0xFF;
|
||||
} else {
|
||||
b = 0x80;
|
||||
}
|
||||
}
|
||||
token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]);
|
||||
}
|
||||
}
|
||||
|
||||
assert(path[pos][PATH_TOKEN_LENGTH] > 0);
|
||||
pos += path[pos][PATH_TOKEN_LENGTH];
|
||||
}
|
||||
|
||||
return token_ids;
|
||||
}
|
||||
private:
|
||||
// Constants for table structure
|
||||
static constexpr int32_t TABLE_PIECE_LENGTH = 0;
|
||||
static constexpr int32_t TABLE_TOKEN_ID = 1;
|
||||
static constexpr int32_t TABLE_SCORE = 2;
|
||||
static constexpr int32_t TABLE_PIECE_ID = 3;
|
||||
|
||||
// Constants for path array
|
||||
static constexpr int32_t PATH_TOKEN_LENGTH = 0;
|
||||
static constexpr int32_t PATH_TOKEN_ID = 1;
|
||||
static constexpr int32_t PATH_NUM_TOKENS = 2;
|
||||
|
||||
// Score constants
|
||||
static constexpr int32_t INVALID_SCORE = -20000000;
|
||||
static constexpr int32_t UNKNOWN_SCORE = -10000000;
|
||||
|
||||
// List of tokens in the vocabulary
|
||||
std::vector<std::string> tokens_;
|
||||
|
||||
// Mapping from byte code point to token ID (for byte fallback)
|
||||
std::vector<llama_token> bytes_;
|
||||
|
||||
// Mapping from piece code to suffix ID
|
||||
std::unordered_map<int64_t, int32_t> to_suffix_id_;
|
||||
|
||||
// Flattened table representing the Trie structure
|
||||
// Each row contains: [piece_length, token_id, score, piece_id]
|
||||
std::vector<std::vector<int32_t>> table_;
|
||||
};
|
||||
|
||||
struct llm_tokenizer_plamo2_session {
|
||||
llm_tokenizer_plamo2_session(const llm_tokenizer_plamo2 & tokenizer) : tokenizer(tokenizer) {}
|
||||
|
||||
void tokenize(const std::string & text, std::vector<llama_token> & output) {
|
||||
std::vector<llama_token> tokens = tokenizer.encode(text);
|
||||
output.insert(output.end(), tokens.begin(), tokens.end());
|
||||
}
|
||||
|
||||
private:
|
||||
const llm_tokenizer_plamo2 & tokenizer;
|
||||
};
|
||||
|
||||
//
|
||||
// impl
|
||||
//
|
||||
|
@ -1499,6 +1778,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
special_unk_id = LLAMA_TOKEN_NULL;
|
||||
special_sep_id = LLAMA_TOKEN_NULL;
|
||||
special_pad_id = LLAMA_TOKEN_NULL;
|
||||
} else if (tokenizer_model == "plamo2") {
|
||||
type = LLAMA_VOCAB_TYPE_PLAMO2;
|
||||
|
||||
// PLaMo-2 default special tokens (these will be overridden by model config)
|
||||
special_bos_id = 1; // <|plamo:bos|>
|
||||
special_eos_id = 2; // <|plamo:eos|>
|
||||
special_unk_id = 0; // <|plamo:unk|>
|
||||
special_sep_id = LLAMA_TOKEN_NULL;
|
||||
special_pad_id = 3; // <|plamo:pad|>
|
||||
special_mask_id = LLAMA_TOKEN_NULL;
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
|
||||
}
|
||||
|
@ -2145,13 +2434,14 @@ enum llama_vocab_type llama_vocab::impl::get_type() const {
|
|||
|
||||
std::string llama_vocab::impl::type_name() const{
|
||||
switch (type) {
|
||||
case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
|
||||
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
|
||||
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
|
||||
case LLAMA_VOCAB_TYPE_WPM: return "WPM";
|
||||
case LLAMA_VOCAB_TYPE_UGM: return "UGM";
|
||||
case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
|
||||
default: return "unknown";
|
||||
case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
|
||||
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
|
||||
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
|
||||
case LLAMA_VOCAB_TYPE_WPM: return "WPM";
|
||||
case LLAMA_VOCAB_TYPE_UGM: return "UGM";
|
||||
case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
|
||||
case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2234,6 +2524,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
|
|||
case LLAMA_VOCAB_TYPE_RWKV:
|
||||
tokenizer = std::make_unique<llm_tokenizer_rwkv>(vocab);
|
||||
break;
|
||||
case LLAMA_VOCAB_TYPE_PLAMO2:
|
||||
tokenizer = std::make_unique<llm_tokenizer_plamo2>(vocab);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("unsupported vocab type");
|
||||
}
|
||||
|
@ -2566,6 +2859,23 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
|
|||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||
std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||
|
||||
#ifdef PRETOKENIZERDEBUG
|
||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
|
||||
#endif
|
||||
|
||||
session.tokenize(text, output);
|
||||
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||
output.push_back(fragment.token);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLAMA_VOCAB_TYPE_PLAMO2:
|
||||
{
|
||||
llm_tokenizer_plamo2_session session(*static_cast<const llm_tokenizer_plamo2 *>(tokenizer.get()));
|
||||
for (const auto & fragment : fragment_buffer) {
|
||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||
std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||
|
||||
#ifdef PRETOKENIZERDEBUG
|
||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
|
||||
#endif
|
||||
|
@ -2664,6 +2974,24 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
|
|||
memcpy(buf, result.data(), result.size());
|
||||
return (int)result.size();
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_PLAMO2: {
|
||||
// PLaMo-2 uses similar token handling as BPE/SPM
|
||||
if (vocab.is_byte(token)) {
|
||||
// Handle byte tokens like <0xXX>
|
||||
if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') {
|
||||
int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16);
|
||||
if (length < 1) {
|
||||
return -1;
|
||||
}
|
||||
buf[0] = static_cast<char>(hex_val);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Normal token - just copy the text
|
||||
std::string result = token_text;
|
||||
return _try_copy(result.data(), result.size());
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
@ -2908,6 +3236,12 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const {
|
|||
case LLAMA_VOCAB_TYPE_BPE: {
|
||||
return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_PLAMO2: {
|
||||
// PLaMo-2 uses byte tokens in format <0xXX>
|
||||
char hex_str[8];
|
||||
snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch);
|
||||
return pimpl->token_to_id.at(hex_str);
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
@ -3385,4 +3719,3 @@ int32_t llama_detokenize(
|
|||
bool unparse_special) {
|
||||
return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue