mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 17:44:38 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # tests/test-tokenizer-0.cpp # tests/test-tokenizer-random.py
This commit is contained in:
commit
5e43ebd151
3 changed files with 104 additions and 64 deletions
|
@ -373,6 +373,29 @@ class Model:
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
|
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
|
||||||
|
|
||||||
|
def does_token_look_special(self, token: str | bytes) -> bool:
|
||||||
|
if isinstance(token, (bytes, bytearray)):
|
||||||
|
token_text = token.decode(encoding="utf-8")
|
||||||
|
elif isinstance(token, memoryview):
|
||||||
|
token_text = token.tobytes().decode(encoding="utf-8")
|
||||||
|
else:
|
||||||
|
token_text = token
|
||||||
|
|
||||||
|
# Some models mark some added tokens which ought to be control tokens as not special.
|
||||||
|
# (e.g. command-r, command-r-plus, deepseek-coder, gemma{,-2})
|
||||||
|
seems_special = token_text in (
|
||||||
|
"<pad>", # deepseek-coder
|
||||||
|
"<mask>", "<2mass>", "[@BOS@]", # gemma{,-2}
|
||||||
|
)
|
||||||
|
|
||||||
|
seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>"))
|
||||||
|
seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>")) # deepseek-coder
|
||||||
|
|
||||||
|
# TODO: should these be marked as UNUSED instead? (maybe not)
|
||||||
|
seems_special = seems_special or (token_text.startswith("<unused") and token_text.endswith(">")) # gemma{,-2}
|
||||||
|
|
||||||
|
return seems_special
|
||||||
|
|
||||||
# used for GPT-2 BPE and WordPiece vocabs
|
# used for GPT-2 BPE and WordPiece vocabs
|
||||||
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
|
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
|
||||||
tokens: list[str] = []
|
tokens: list[str] = []
|
||||||
|
@ -391,16 +414,18 @@ class Model:
|
||||||
for i in range(vocab_size):
|
for i in range(vocab_size):
|
||||||
if i not in reverse_vocab:
|
if i not in reverse_vocab:
|
||||||
tokens.append(f"[PAD{i}]")
|
tokens.append(f"[PAD{i}]")
|
||||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
toktypes.append(gguf.TokenType.UNUSED)
|
||||||
elif reverse_vocab[i] in added_vocab:
|
|
||||||
tokens.append(reverse_vocab[i])
|
|
||||||
if tokenizer.added_tokens_decoder[i].special:
|
|
||||||
toktypes.append(gguf.TokenType.CONTROL)
|
|
||||||
else:
|
|
||||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
|
||||||
else:
|
else:
|
||||||
tokens.append(reverse_vocab[i])
|
token: str = reverse_vocab[i]
|
||||||
toktypes.append(gguf.TokenType.NORMAL)
|
if token in added_vocab:
|
||||||
|
if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token):
|
||||||
|
toktypes.append(gguf.TokenType.CONTROL)
|
||||||
|
else:
|
||||||
|
token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces
|
||||||
|
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||||
|
else:
|
||||||
|
toktypes.append(gguf.TokenType.NORMAL)
|
||||||
|
tokens.append(token)
|
||||||
|
|
||||||
return tokens, toktypes, tokpre
|
return tokens, toktypes, tokpre
|
||||||
|
|
||||||
|
@ -559,7 +584,7 @@ class Model:
|
||||||
for i in range(vocab_size):
|
for i in range(vocab_size):
|
||||||
if i not in reverse_vocab:
|
if i not in reverse_vocab:
|
||||||
tokens.append(f"[PAD{i}]")
|
tokens.append(f"[PAD{i}]")
|
||||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
toktypes.append(gguf.TokenType.UNUSED)
|
||||||
elif reverse_vocab[i] in added_vocab:
|
elif reverse_vocab[i] in added_vocab:
|
||||||
tokens.append(reverse_vocab[i])
|
tokens.append(reverse_vocab[i])
|
||||||
toktypes.append(gguf.TokenType.CONTROL)
|
toktypes.append(gguf.TokenType.CONTROL)
|
||||||
|
@ -609,7 +634,7 @@ class Model:
|
||||||
|
|
||||||
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
||||||
scores: list[float] = [-10000.0] * vocab_size
|
scores: list[float] = [-10000.0] * vocab_size
|
||||||
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
|
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
|
||||||
|
|
||||||
for token_id in range(tokenizer.vocab_size()):
|
for token_id in range(tokenizer.vocab_size()):
|
||||||
piece = tokenizer.IdToPiece(token_id)
|
piece = tokenizer.IdToPiece(token_id)
|
||||||
|
@ -644,6 +669,25 @@ class Model:
|
||||||
scores[token_id] = -1000.0
|
scores[token_id] = -1000.0
|
||||||
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
|
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
|
||||||
|
|
||||||
|
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
|
||||||
|
if tokenizer_config_file.is_file():
|
||||||
|
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
|
||||||
|
tokenizer_config_json = json.load(f)
|
||||||
|
added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {})
|
||||||
|
for token_id, token_data in added_tokens_decoder.items():
|
||||||
|
token_id = int(token_id)
|
||||||
|
token: str = token_data["content"]
|
||||||
|
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
|
||||||
|
assert tokens[token_id] == token.encode("utf-8")
|
||||||
|
if token_data.get("special") or self.does_token_look_special(token):
|
||||||
|
toktypes[token_id] = SentencePieceTokenTypes.CONTROL
|
||||||
|
else:
|
||||||
|
token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces
|
||||||
|
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
|
||||||
|
|
||||||
|
scores[token_id] = -1000.0
|
||||||
|
tokens[token_id] = token.encode("utf-8")
|
||||||
|
|
||||||
if vocab_size > len(tokens):
|
if vocab_size > len(tokens):
|
||||||
pad_count = vocab_size - len(tokens)
|
pad_count = vocab_size - len(tokens)
|
||||||
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
|
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
|
||||||
|
@ -1266,7 +1310,7 @@ class StableLMModel(Model):
|
||||||
if (self.dir_model / "tokenizer.json").is_file():
|
if (self.dir_model / "tokenizer.json").is_file():
|
||||||
self._set_vocab_gpt2()
|
self._set_vocab_gpt2()
|
||||||
else:
|
else:
|
||||||
# StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab
|
# StableLM 2 1.6B used to have a vocab in a similar format to Qwen's vocab
|
||||||
self._set_vocab_qwen()
|
self._set_vocab_qwen()
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
|
@ -1578,7 +1622,6 @@ class DbrxModel(Model):
|
||||||
self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])
|
self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])
|
||||||
|
|
||||||
self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
|
self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
|
||||||
self.gguf_writer.add_file_type(self.ftype)
|
|
||||||
|
|
||||||
self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
|
self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
|
||||||
self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])
|
self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])
|
||||||
|
@ -1872,7 +1915,7 @@ class Phi3MiniModel(Model):
|
||||||
|
|
||||||
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
||||||
scores: list[float] = [-10000.0] * vocab_size
|
scores: list[float] = [-10000.0] * vocab_size
|
||||||
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
|
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
|
||||||
|
|
||||||
for token_id in range(tokenizer.vocab_size()):
|
for token_id in range(tokenizer.vocab_size()):
|
||||||
|
|
||||||
|
@ -1917,7 +1960,7 @@ class Phi3MiniModel(Model):
|
||||||
for token_id, foken_data in added_tokens_decoder.items():
|
for token_id, foken_data in added_tokens_decoder.items():
|
||||||
token_id = int(token_id)
|
token_id = int(token_id)
|
||||||
token = foken_data["content"].encode("utf-8")
|
token = foken_data["content"].encode("utf-8")
|
||||||
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
|
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
|
||||||
assert tokens[token_id] == token
|
assert tokens[token_id] == token
|
||||||
tokens[token_id] = token
|
tokens[token_id] = token
|
||||||
scores[token_id] = -1000.0
|
scores[token_id] = -1000.0
|
||||||
|
@ -1933,7 +1976,7 @@ class Phi3MiniModel(Model):
|
||||||
for foken_data in added_tokens:
|
for foken_data in added_tokens:
|
||||||
token_id = int(foken_data["id"])
|
token_id = int(foken_data["id"])
|
||||||
token = foken_data["content"].encode("utf-8")
|
token = foken_data["content"].encode("utf-8")
|
||||||
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
|
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
|
||||||
assert tokens[token_id] == token
|
assert tokens[token_id] == token
|
||||||
tokens[token_id] = token
|
tokens[token_id] = token
|
||||||
scores[token_id] = -1000.0
|
scores[token_id] = -1000.0
|
||||||
|
@ -2145,7 +2188,7 @@ class InternLM2Model(Model):
|
||||||
toktype = SentencePieceTokenTypes.BYTE
|
toktype = SentencePieceTokenTypes.BYTE
|
||||||
# take care of ununsed raw token
|
# take care of ununsed raw token
|
||||||
if piece.startswith('[UNUSED'):
|
if piece.startswith('[UNUSED'):
|
||||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
toktype = SentencePieceTokenTypes.UNUSED
|
||||||
|
|
||||||
tokens.append(text)
|
tokens.append(text)
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
|
@ -2175,7 +2218,7 @@ class InternLM2Model(Model):
|
||||||
if token == chat_eos_token:
|
if token == chat_eos_token:
|
||||||
chat_eos_token_id = token_id
|
chat_eos_token_id = token_id
|
||||||
token = token.encode("utf-8")
|
token = token.encode("utf-8")
|
||||||
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
|
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
|
||||||
assert(tokens[token_id] == token)
|
assert(tokens[token_id] == token)
|
||||||
tokens[token_id] = token
|
tokens[token_id] = token
|
||||||
scores[token_id] = -1000.0
|
scores[token_id] = -1000.0
|
||||||
|
@ -2194,7 +2237,7 @@ class InternLM2Model(Model):
|
||||||
if token == chat_eos_token:
|
if token == chat_eos_token:
|
||||||
chat_eos_token_id = token_id
|
chat_eos_token_id = token_id
|
||||||
token = token.encode("utf-8")
|
token = token.encode("utf-8")
|
||||||
if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
|
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
|
||||||
assert(tokens[token_id] == token)
|
assert(tokens[token_id] == token)
|
||||||
tokens[token_id] = token
|
tokens[token_id] = token
|
||||||
scores[token_id] = -1000.0
|
scores[token_id] = -1000.0
|
||||||
|
@ -2434,19 +2477,7 @@ class Gemma2Model(Model):
|
||||||
model_arch = gguf.MODEL_ARCH.GEMMA2
|
model_arch = gguf.MODEL_ARCH.GEMMA2
|
||||||
|
|
||||||
def set_vocab(self):
|
def set_vocab(self):
|
||||||
tokens, scores, toktypes = self._create_vocab_sentencepiece()
|
self._set_vocab_sentencepiece()
|
||||||
# hack: This is required so that we can properly use start/end-of-turn for chat template
|
|
||||||
for i in range(108):
|
|
||||||
# including <unusedX>, <start_of_turn>, <end_of_turn>
|
|
||||||
toktypes[i] = SentencePieceTokenTypes.CONTROL
|
|
||||||
self.gguf_writer.add_tokenizer_model("llama")
|
|
||||||
self.gguf_writer.add_tokenizer_pre("default")
|
|
||||||
self.gguf_writer.add_token_list(tokens)
|
|
||||||
self.gguf_writer.add_token_scores(scores)
|
|
||||||
self.gguf_writer.add_token_types(toktypes)
|
|
||||||
|
|
||||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
|
||||||
special_vocab.add_to_gguf(self.gguf_writer)
|
|
||||||
|
|
||||||
self.gguf_writer.add_add_space_prefix(False)
|
self.gguf_writer.add_add_space_prefix(False)
|
||||||
|
|
||||||
|
@ -2770,7 +2801,7 @@ class ArcticModel(Model):
|
||||||
|
|
||||||
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
||||||
scores: list[float] = [-10000.0] * vocab_size
|
scores: list[float] = [-10000.0] * vocab_size
|
||||||
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
|
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
|
||||||
|
|
||||||
for token_id in range(tokenizer.vocab_size()):
|
for token_id in range(tokenizer.vocab_size()):
|
||||||
|
|
||||||
|
@ -3025,7 +3056,7 @@ class T5Model(Model):
|
||||||
|
|
||||||
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
||||||
scores: list[float] = [-10000.0] * vocab_size
|
scores: list[float] = [-10000.0] * vocab_size
|
||||||
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
|
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
|
||||||
|
|
||||||
for token_id in range(tokenizer.vocab_size()):
|
for token_id in range(tokenizer.vocab_size()):
|
||||||
piece = tokenizer.IdToPiece(token_id)
|
piece = tokenizer.IdToPiece(token_id)
|
||||||
|
@ -3243,15 +3274,14 @@ class ChatGLMModel(Model):
|
||||||
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
|
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
|
||||||
score = tokenizer.tokenizer.sp_model.get_score(token_id)
|
score = tokenizer.tokenizer.sp_model.get_score(token_id)
|
||||||
|
|
||||||
if len(piece) == 0:
|
|
||||||
text = f"[PAD{token_id}]".encode("utf-8")
|
|
||||||
|
|
||||||
if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
|
if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
|
||||||
if piece in special_tokens:
|
if piece in special_tokens:
|
||||||
# show special tokens in prompt
|
toktype = SentencePieceTokenTypes.CONTROL
|
||||||
toktype = SentencePieceTokenTypes.USER_DEFINED
|
elif len(piece) == 0:
|
||||||
|
text = f"[PAD{token_id}]".encode("utf-8")
|
||||||
|
toktype = SentencePieceTokenTypes.UNUSED
|
||||||
else:
|
else:
|
||||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
toktype = SentencePieceTokenTypes.USER_DEFINED
|
||||||
tokens.append(text)
|
tokens.append(text)
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
toktypes.append(toktype)
|
toktypes.append(toktype)
|
||||||
|
@ -3340,7 +3370,7 @@ class ChatGLMModel(Model):
|
||||||
for i in range(vocab_size):
|
for i in range(vocab_size):
|
||||||
if i not in reverse_vocab:
|
if i not in reverse_vocab:
|
||||||
tokens.append(f"[PAD{i}]")
|
tokens.append(f"[PAD{i}]")
|
||||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
toktypes.append(gguf.TokenType.UNUSED)
|
||||||
elif reverse_vocab[i] in added_vocab:
|
elif reverse_vocab[i] in added_vocab:
|
||||||
tokens.append(reverse_vocab[i])
|
tokens.append(reverse_vocab[i])
|
||||||
if tokenizer.added_tokens_decoder[i].special:
|
if tokenizer.added_tokens_decoder[i].special:
|
||||||
|
|
|
@ -27,8 +27,9 @@ UUID_NAMESPACE_LLAMA_CPP = uuid.UUID('ef001206-dadc-5f6d-a15f-3359e577d4e5')
|
||||||
|
|
||||||
# For more information about what field.parts and field.data represent,
|
# For more information about what field.parts and field.data represent,
|
||||||
# please see the comments in the modify_gguf.py example.
|
# please see the comments in the modify_gguf.py example.
|
||||||
def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar) -> None:
|
def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar: bool, no_layer: bool) -> None:
|
||||||
sha1 = hashlib.sha1()
|
sha1 = hashlib.sha1()
|
||||||
|
sha256 = hashlib.sha256()
|
||||||
uuidv5_sha1 = hashlib.sha1()
|
uuidv5_sha1 = hashlib.sha1()
|
||||||
uuidv5_sha1.update(UUID_NAMESPACE_LLAMA_CPP.bytes)
|
uuidv5_sha1.update(UUID_NAMESPACE_LLAMA_CPP.bytes)
|
||||||
|
|
||||||
|
@ -50,7 +51,7 @@ def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar) -> None:
|
||||||
bar = tqdm(desc="Hashing", total=total_weights, unit="weights", unit_scale=True, disable=disable_progress_bar)
|
bar = tqdm(desc="Hashing", total=total_weights, unit="weights", unit_scale=True, disable=disable_progress_bar)
|
||||||
|
|
||||||
# Hashing Process
|
# Hashing Process
|
||||||
for n, tensor in enumerate(reader.tensors, 1):
|
for tensor in reader.tensors:
|
||||||
|
|
||||||
# We don't need these
|
# We don't need these
|
||||||
if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
||||||
|
@ -62,29 +63,39 @@ def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar) -> None:
|
||||||
sum_weights_in_tensor *= dim
|
sum_weights_in_tensor *= dim
|
||||||
bar.update(sum_weights_in_tensor)
|
bar.update(sum_weights_in_tensor)
|
||||||
|
|
||||||
sha1_layer = hashlib.sha1()
|
if not no_layer:
|
||||||
sha1_layer.update(tensor.data.data)
|
|
||||||
|
sha1_layer = hashlib.sha1()
|
||||||
|
sha1_layer.update(tensor.data.data)
|
||||||
|
print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
|
||||||
|
|
||||||
|
sha256_layer = hashlib.sha256()
|
||||||
|
sha256_layer.update(tensor.data.data)
|
||||||
|
print("sha256 {0} {1}:{2}".format(sha256_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
|
||||||
|
|
||||||
sha1.update(tensor.data.data)
|
sha1.update(tensor.data.data)
|
||||||
|
sha256.update(tensor.data.data)
|
||||||
uuidv5_sha1.update(tensor.data.data)
|
uuidv5_sha1.update(tensor.data.data)
|
||||||
print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
|
|
||||||
|
|
||||||
# Flush Hash Progress Bar
|
# Flush Hash Progress Bar
|
||||||
bar.close()
|
bar.close()
|
||||||
|
|
||||||
# Display Hash Output
|
# Display Hash Output
|
||||||
print("sha1 {0} {1}".format(sha1.hexdigest(), filename)) # noqa: NP100
|
print("sha1 {0} {1}".format(sha1.hexdigest(), filename)) # noqa: NP100
|
||||||
print("UUIDv5 {0} {1}".format(uuid.UUID(bytes=uuidv5_sha1.digest()[:16], version=5), filename)) # noqa: NP100
|
print("sha256 {0} {1}".format(sha256.hexdigest(), filename)) # noqa: NP100
|
||||||
|
print("uuid {0} {1}".format(uuid.UUID(bytes=uuidv5_sha1.digest()[:16], version=5), filename)) # noqa: NP100
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(description="Dump GGUF file metadata")
|
parser = argparse.ArgumentParser(description="Dump GGUF file metadata")
|
||||||
parser.add_argument("model", type=str, help="GGUF format model filename")
|
parser.add_argument("model", type=str, help="GGUF format model filename")
|
||||||
|
parser.add_argument("--no-layer", action="store_true", help="exclude per layer hash")
|
||||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||||
parser.add_argument("--progressbar", action="store_true", help="enable progressbar")
|
parser.add_argument("--progressbar", action="store_true", help="enable progressbar")
|
||||||
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
|
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
|
||||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||||
reader = GGUFReader(args.model, 'r')
|
reader = GGUFReader(args.model, 'r')
|
||||||
gguf_hash(reader, args.model, not args.progressbar)
|
gguf_hash(reader, args.model, not args.progressbar, args.no_layer)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -5482,6 +5482,7 @@ static void llm_load_vocab(
|
||||||
} else if (
|
} else if (
|
||||||
tokenizer_pre == "command-r") {
|
tokenizer_pre == "command-r") {
|
||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
|
||||||
|
vocab.tokenizer_clean_spaces = false;
|
||||||
} else if (
|
} else if (
|
||||||
tokenizer_pre == "qwen2") {
|
tokenizer_pre == "qwen2") {
|
||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
|
||||||
|
@ -5724,7 +5725,7 @@ static void llm_load_vocab(
|
||||||
// build special tokens cache
|
// build special tokens cache
|
||||||
{
|
{
|
||||||
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
|
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
|
||||||
if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
|
if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
|
||||||
vocab.cache_special_tokens.push_back(id);
|
vocab.cache_special_tokens.push_back(id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15713,17 +15714,6 @@ struct llm_tokenizer_bpe {
|
||||||
"[0-9][0-9][0-9]",
|
"[0-9][0-9][0-9]",
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_MPT:
|
|
||||||
// TODO: MPT pre-tokenization regexes are unknown
|
|
||||||
// the following are close, but not exact. run the following:
|
|
||||||
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
|
|
||||||
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
|
|
||||||
regex_exprs = {
|
|
||||||
"\\s?\\p{L}+",
|
|
||||||
"\\s?\\p{P}+",
|
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
|
||||||
};
|
|
||||||
break;
|
|
||||||
case LLAMA_VOCAB_PRE_TYPE_STARCODER:
|
case LLAMA_VOCAB_PRE_TYPE_STARCODER:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_REFACT:
|
case LLAMA_VOCAB_PRE_TYPE_REFACT:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
|
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
|
||||||
|
@ -15733,6 +15723,7 @@ struct llm_tokenizer_bpe {
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_GPT2:
|
case LLAMA_VOCAB_PRE_TYPE_GPT2:
|
||||||
|
case LLAMA_VOCAB_PRE_TYPE_MPT:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_OLMO:
|
case LLAMA_VOCAB_PRE_TYPE_OLMO:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_JAIS:
|
case LLAMA_VOCAB_PRE_TYPE_JAIS:
|
||||||
regex_exprs = {
|
regex_exprs = {
|
||||||
|
@ -15759,8 +15750,8 @@ struct llm_tokenizer_bpe {
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_VIKING:
|
case LLAMA_VOCAB_PRE_TYPE_VIKING:
|
||||||
regex_exprs = {
|
regex_exprs = {
|
||||||
"\\p{N}",
|
|
||||||
" ?[^(\\s|.,!?…。,、।۔،)]+",
|
" ?[^(\\s|.,!?…。,、।۔،)]+",
|
||||||
|
"\\p{N}",
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -16480,12 +16471,20 @@ struct fragment_buffer_variant {
|
||||||
|
|
||||||
// #define PRETOKENIZERDEBUG
|
// #define PRETOKENIZERDEBUG
|
||||||
|
|
||||||
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
|
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
|
||||||
// for each special token
|
// for each special token
|
||||||
for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
|
for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
|
||||||
const auto & data = vocab.id_to_token[special_id];
|
const auto & data = vocab.id_to_token[special_id];
|
||||||
const auto & special_token = data.text;
|
const auto & special_token = data.text;
|
||||||
|
|
||||||
|
if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
|
||||||
|
// Ignore control and unknown tokens when parse_special == false
|
||||||
|
continue;
|
||||||
|
// User-defined tokens are still pre-tokenized before everything else
|
||||||
|
// ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
|
||||||
|
// This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
|
||||||
|
}
|
||||||
|
|
||||||
// for each text fragment
|
// for each text fragment
|
||||||
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
|
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
|
||||||
while (it != buffer.end()) {
|
while (it != buffer.end()) {
|
||||||
|
@ -16598,7 +16597,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
||||||
|
|
||||||
if (!raw_text.empty()) {
|
if (!raw_text.empty()) {
|
||||||
fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
|
fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
|
||||||
if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
|
tokenizer_st_partition(vocab, fragment_buffer, parse_special);
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (vocab.type) {
|
switch (vocab.type) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue