Auto Healing Tokenizer (#283)

* Update gemma.py (×22)
* llama
* Update llama.py
* gemma
* Update cross_entropy_loss.py (×19)
* Update gemma.py (×10)
* Update save.py
* RoPE
* Update llama.py (×3)
* Update gemma.py
* correct_dtype
* Update gemma.py
* Update cross_entropy_loss.py (×2)
* Chat Templates
* Update README.md (×2)
* Update llama.py
* DoRA
* Update _utils.py
* Update chat_templates.py
* Update llama.py
* Hotfix - fix DoRA, Gemma prompt template (#202) (#203)
* Update save.py
* saving
* Update save.py (×13)
* Update __init__.py
* Update save.py (×3)
* save
* trainer
* spaces
* original
* Gemma
* Update pyproject.toml
* Update mapper.py
* Update fast_lora.py
* FastGemmaModel
* model_type
* Update llama.py (×2)
* Update gemma.py (×3)
* Update llama.py (×2)
* Update fast_lora.py
* Update llama.py (×2)
* Update cross_entropy_loss.py
* Update llama.py (×2)
* gemma
* Update llama.py (×4)
* Update fast_lora.py (×2)
* Fast CE Loss
* Update cross_entropy_loss.py (×7)
* Update llama.py (×4)
* CE
* Update llama.py (×2)
* Update cross_entropy_loss.py
* Update geglu.py
* Update cross_entropy_loss.py
* revert
* Update llama.py (×2)
* norm
* Update gemma.py (×2)
* position_ids
* Update gemma.py (×2)
* pos
* Update llama.py
* Update gemma.py (×8)
* Update cross_entropy_loss.py
* Update gemma.py (×7)
* Update llama.py
* Update gemma.py (×7)
* Update llama.py
* Update cross_entropy_loss.py (×2)
* revert (×2)
* Update gemma.py (×14)
* Update llama.py
* Update gemma.py (×11)
* Update cross_entropy_loss.py
* Update gemma.py (×5)
* Update llama.py (×4)
* rope
* Update gemma.py (×55)

* Update pyproject.toml
* Small fixes
* Update pyproject.toml
* Approx gelu
* Update geglu.py
* Approx gelu
* Update llama.py
* Update __init__.py (×2)
* Update _utils.py
* Update geglu.py
* Update gemma.py
* Update rms_layernorm.py (×3)
* Update gemma.py (×5)
* Fix Gemma merging
* Update rms_layernorm.py
* Update gemma.py
* Update pyproject.toml
* Layernorms
* Gemma precision
* Update gemma.py
* sqrt
* Update gemma.py
* Update save.py
* RoPE and Gemma precision
* Update rms_layernorm.py
* Fix warning
* Update chat_templates.py (×2)
* Update save.py (×3)
* Update chat_templates.py
* Update llama.py
* model_name
* Update loader.py
* Tokenizer overwritten
* Update llama.py (×3)
* Update save.py
* Accuracy
* Revert
* Update save.py
* Update fast_lora.py (×5)
* Update chat_templates.py
* Update save.py (×2)
* Update llama.py (×2)
* Account for DoRA
* Update llama.py
* Update save.py
* GGUF incorrect
* Update save.py
* Update pyproject.toml
* kaggle new
* Update pyproject.toml (×2)
* upcasting
* Fix Colab
* Update pyproject.toml (×8)
* Update chat_templates.py (×5)
* Update pyproject.toml (×3)
* Update rope_embedding.py (×2)
* Fix bugs
* Update fast_lora.py (×2)
* Update README.md (×2)
* GGUF
* Update save.py (×4)
* Update README.md (×2)
* Bugs
* Update fast_lora.py
* Update pyproject.toml
* Update fast_lora.py
* Update __init__.py
* Update fast_lora.py
* dtype
* Update llama.py (×3)
* dtype
* Update mistral.py
* trust_remote_code
* lm_head
* Update llama.py
* save_pretrained_settings
* Update save.py (×12)
* state_dict
* Update save.py
* whoami
* Update llama.py
* Update save.py
* Update llama.py
* Patch tokenizer
* Update chat_templates.py
* Heal tokenizers
* Update chat_templates.py
* Update mapper.py
* Update tokenizer_utils.py (×5)
* Update chat_templates.py
* tokenizer patching
* patch_tokenizer
* Update chat_templates.py
* Update tokenizer_utils.py
* Update chat_templates.py (×3)
* Update tokenizer_utils.py
* Edit
Daniel Han, 2024-03-28 04:16:50 +11:00, committed by GitHub
parent bb45fdabb6
commit 74f79684da
9 changed files with 464 additions and 345 deletions
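
Net effect of this PR: FastLlamaModel and FastMistralModel now load tokenizers through the new load_correct_tokenizer in unsloth/tokenizer_utils.py, which cross-checks the slow and fast tokenizers and repairs ("heals") mismatched special tokens. A minimal sketch of the new entry point, assuming this commit is installed (the model name is only an example):

from unsloth.tokenizer_utils import load_correct_tokenizer

# Loads both the slow and fast tokenizers, verifies they tokenize
# identically, and falls back to converting the slow one if not.
tokenizer = load_correct_tokenizer(
    tokenizer_name = "unsloth/llama-2-7b-bnb-4bit",
    model_max_length = 4096,
    padding_side = "right",
)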

.gitignore (vendored, 160 lines deleted)

@@ -1,160 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

LICENSE

@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Copyright [2024-] [Unsloth AI, Daniel Han-Chen & Michael Han-Chen]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

unsloth/__init__.py

@@ -113,3 +113,4 @@ pass
from .models import *
from .save import *
from .chat_templates import *
from .tokenizer_utils import *
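
Since tokenizer_utils.py defines __all__ as load_correct_tokenizer, fix_sentencepiece_tokenizer, and check_tokenizer, this wildcard import re-exports the healing helpers from the top-level package. Illustrative only, assuming this commit is installed:

from unsloth import load_correct_tokenizer, fix_sentencepiece_tokenizer, check_tokenizer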

unsloth/chat_templates.py

@@ -15,15 +15,19 @@
__all__ = [
"get_chat_template",
"test_chat_templates",
"fix_sentencepiece_tokenizer",
]
from transformers import StoppingCriteria, StoppingCriteriaList
from torch import LongTensor, FloatTensor
from transformers.models.llama.modeling_llama import logger
from .models._utils import patch_tokenizer
from .save import patch_saving_functions
import os
import shutil
from .tokenizer_utils import (
load_correct_tokenizer,
fix_sentencepiece_tokenizer,
)
from .models._utils import patch_tokenizer
CHAT_TEMPLATES = {}
@@ -251,84 +255,23 @@ gemma_chatml_eos_token = (
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token,)
def fix_sentencepiece_tokenizer(
old_tokenizer,
new_tokenizer,
token_mapping,
temporary_location = "_unsloth_sentencepiece_temp",
):
# From https://github.com/google/sentencepiece/issues/121
# We need to manually edit the sentencepiece tokenizer!
try:
import sentencepiece.sentencepiece_model_pb2 as sentencepiece_model_pb2
except:
if not os.path.exists(temporary_location):
os.system("git clone https://github.com/google/sentencepiece.git unsloth_sentencepiece_temp")
os.system(f"cd {temporary_location}/src && protoc --python_out=. sentencepiece_model.proto")
shutil.rmtree(temporary_location)
pass
import sentencepiece.sentencepiece_model_pb2 as sentencepiece_model_pb2
pass
if not os.path.exists(temporary_location):
os.makedirs(temporary_location)
pass
# First save the old tokenizer
old_tokenizer.save_pretrained(temporary_location)
from sentencepiece import SentencePieceProcessor
tokenizer_file = sentencepiece_model_pb2.ModelProto()
tokenizer_file.ParseFromString(open(f"{temporary_location}/tokenizer.model", "rb").read())
# Now save the new tokenizer
new_tokenizer.save_pretrained(temporary_location)
# Now correct the old tokenizer's .model file
for old_token, new_token in token_mapping.items():
ids = old_tokenizer([old_token], add_special_tokens = False).input_ids
ids = ids[0]
if (len(ids) != 1):
# Skip this token!
print(f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!")
continue
pass
ids = ids[0]
tokenizer_piece = tokenizer_file.pieces[ids]
assert(tokenizer_piece.piece == old_token)
tokenizer_piece.piece = new_token
pass
# And now write it
with open(f"{temporary_location}/tokenizer.model", "wb") as file:
file.write(tokenizer_file.SerializeToString())
pass
# And load it!
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(temporary_location, eos_token = new_tokenizer.eos_token)
return tokenizer
pass
def get_chat_template(
tokenizer,
chat_template = "chatml",
mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
map_eos_token = True,
):
assert(type(map_eos_token) is bool)
old_tokenizer = tokenizer
if map_eos_token is False:
assert("Unsloth: Can only map new tokens to EOS for now. Adding new tokens is not yet supported.")
pass
IS_GEMMA = False
if tokenizer.__class__.__name__.startswith("Gemma"):
if chat_template == "chatml": chat_template = "gemma_chatml"
IS_GEMMA = True
pass
# We first check if the tokenizer is a fast one. If not, we cannot convert this!
is_fast_tokenizer = getattr(tokenizer, "is_fast", False)
old_padding_side = tokenizer.padding_side
if type(chat_template) in (list, tuple,):
@@ -348,9 +291,17 @@ def get_chat_template(
assert(type(stop_word) is str)
# token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
# For Gemma :)
if token_mapping is not None:
# Check fast tokenizer
if not is_fast_tokenizer:
logger.warning_once(
f"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
"Please log a Github issue if you want this as a new feature!\n"\
"Your chat template will still work, but it won't add or edit tokens."
)
elif token_mapping is not None:
# token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
# For Gemma :)
string_vocab = tokenizer._tokenizer.to_str()
@@ -368,7 +319,7 @@ def get_chat_template(
pass
pass
if not stop_word in token_mapping.values():
if map_eos_token and (not stop_word in token_mapping.values()):
# Do not map 107 = <|im_end|> and 1 = <|im_end|>. This will reduce the vocab size by 1
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
@@ -376,14 +327,19 @@
if skipped != len(token_mapping):
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)
if map_eos_token:
new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)
else:
new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer)
pass
# Must fix the sentence piece tokenizer since there's no tokenizer.model file!
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
else:
pass
elif stop_word != "eos_token":
elif map_eos_token and (stop_word != "eos_token"):
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
# Replaces the old EOS token with a new one.
@@ -393,9 +349,14 @@
# This is a HACK!
# Idea from https://huggingface.co/cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser
string_vocab = tokenizer._tokenizer.to_str()
string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
old_eos_token = tokenizer.eos_token
string_vocab = string_vocab.replace(old_eos_token, stop_word)
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)
new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)
# Must fix the sentence piece tokenizer since there's no tokenizer.model file!
token_mapping = { old_eos_token : stop_word, }
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
pass
else:
@@ -433,7 +394,10 @@ def get_chat_template(
if old_bos_token != new_bos_token: tokenizer.bos_token = old_bos_token
if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token
#stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
# stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
# Patch saving functions
tokenizer = patch_saving_functions(tokenizer)
return tokenizer#, stopping_criteria
pass
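
With map_eos_token now threaded through every branch above, the EOS remapping can be skipped while still applying the template. A hedged usage sketch (the tokenizer variable is assumed to come from a prior from_pretrained or load_correct_tokenizer call):

# Applies the ChatML template and, because map_eos_token = True,
# rewrites the old EOS token to the template's stop word.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "chatml",   # auto-switched to "gemma_chatml" for Gemma tokenizers
    map_eos_token = True,
)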

unsloth/models/_utils.py

@@ -60,22 +60,16 @@ from xformers import __version__ as xformers_version
__all__ = [
"prepare_model_for_kbit_training",
"patch_tokenizer",
"check_tokenizer",
"xformers",
"xformers_attention",
"xformers_version",
"__version__",
"HAS_FLASH_ATTENTION",
"platform_system",
"patch_tokenizer",
]
IGNORED_TOKENIZER_CHECKING = frozenset((
"CodeLlamaTokenizerFast",
"CodeLlamaTokenizer",
))
def prepare_model_for_kbit_training(
model : Any,
use_gradient_checkpointing : bool = True,
@@ -144,103 +138,6 @@ def patch_tokenizer(model, tokenizer):
pass
def check_tokenizer(
model,
tokenizer,
model_name = "unsloth/llama-2-7b-bnb-4bit",
model_max_length = 4096,
padding_side = "right",
token = None,
_reload = True,
):
# Checks tokenizer for out of bounds ids.
# Mainly a fix for https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha
# where <sep> had token id=32002.
# See https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha/discussions/25
# Seems like the Fast tokenizer in Rust breaks things!
# We ignore some of them!
if tokenizer.__repr__().split("(", 1)[0] in IGNORED_TOKENIZER_CHECKING:
return tokenizer
pass
max_embedding_size = model.model.embed_tokens.weight.shape[0]
added_tokens_fast = tokenizer.added_tokens_decoder
added_tokens_fast = {index : str(value) for index, value in added_tokens_fast.items()}
sorted_keys = sorted(added_tokens_fast)
added_tokens_fast = {key : added_tokens_fast[key] for key in sorted_keys}
for j, index in enumerate(added_tokens_fast.keys()):
if index >= max_embedding_size:
bad_indices = list(added_tokens_fast.keys ())[j:]
bad_tokens = list(added_tokens_fast.values())[j:]
if not _reload:
# Try removing the token
added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
special_tokens = tokenizer.special_tokens_map
import itertools
special_tokens = frozenset(
itertools.chain.from_iterable(
[x] if type(x) is str else x for x in special_tokens.values()
)
)
can_be_removed1 = [x for x in bad_tokens if x not in special_tokens]
can_be_removed2 = [x for x in can_be_removed1 if x in tokenizer._added_tokens_encoder.keys()]
# Check if extra tokens can in fact be removed!
if (len(can_be_removed1) == len(bad_tokens)) and \
(len(can_be_removed2) == len(bad_tokens)):
# Yes it can be fixed!
for bad_token in can_be_removed1:
remove_id = tokenizer._added_tokens_encoder[bad_token]
del tokenizer._added_tokens_decoder[remove_id]
del tokenizer._added_tokens_encoder[bad_token]
pass
# Confirm 1 more time!
if max(tokenizer.added_tokens_decoder.keys()) < max_embedding_size:
logger.warning_once(
f"Unsloth loaded a broken tokenizer `{model_name}`, but managed to repair it!\n"\
f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
"We removed these bad tokens. If you think this is incorrect, fix your tokenizer first."
)
return tokenizer
pass
pass
# :( Failure
raise RuntimeError(
f"Unsloth tried to load `{model_name}`, but cannot succeed.\n"\
f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
f"Fix your tokenizer since it'll perform out of bounds memory accesses."
)
pass
# Try slow tokenizer which can fix things!
tokenizer = AutoTokenizer.from_pretrained(
model_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
use_fast = False,
)
return check_tokenizer(
model = model,
tokenizer = tokenizer,
model_name = model_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
_reload = False,
)
break
pass
pass
return tokenizer
pass
# Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??
# For mixed precision, we need it to be in float32 not float16.
from peft.tuners.lora.layer import LoraLayer

unsloth/models/llama.py

@@ -26,6 +26,7 @@ from transformers.modeling_attn_mask_utils import (
from ..kernels import *
from ._utils import *
from ._utils import __version__
from ..tokenizer_utils import *
if HAS_FLASH_ATTENTION:
from flash_attn import flash_attn_func
@@ -1014,8 +1015,8 @@ class FastLlamaModel:
# Counteract saved tokenizers
tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
tokenizer = load_correct_tokenizer(
tokenizer_name = tokenizer_name,
model_max_length = max_position_embeddings,
padding_side = "right",
token = token,

unsloth/models/mistral.py

@@ -362,7 +362,7 @@ class FastMistralModel(FastLlamaModel):
# Counteract saved tokenizers
tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
tokenizer = AutoTokenizer.from_pretrained(
tokenizer = load_correct_tokenizer(
tokenizer_name,
model_max_length = max_position_embeddings,
padding_side = "right",

unsloth/save.py

@@ -276,7 +276,8 @@ def unsloth_save_model(
old_username = None, private = private,
)
model.original_push_to_hub(
getattr(model, "original_push_to_hub", tokenizer.push_to_hub)\
(
repo_id = save_directory,
use_temp_dir = use_temp_dir,
commit_message = commit_message,
@@ -290,7 +291,8 @@
tags = tags,
)
if tokenizer is not None:
tokenizer.original_push_to_hub(
getattr(tokenizer, "original_push_to_hub", tokenizer.push_to_hub)\
(
repo_id = save_directory,
use_temp_dir = use_temp_dir,
commit_message = commit_message,
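
Both hunks wrap the patched original_push_to_hub in getattr with a fallback, so saving still works when patch_saving_functions has not stored the original method. The same defensive pattern in isolation (a sketch, not the exact call site):

# Prefer the pre-patch method if one was saved, else use the current one.
push_to_hub = getattr(tokenizer, "original_push_to_hub", tokenizer.push_to_hub)
push_to_hub(repo_id = "username/repo", private = True)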

unsloth/tokenizer_utils.py (new file, 414 lines)

@@ -0,0 +1,414 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers import PreTrainedTokenizerFast
import re
import os
from transformers.models.llama.modeling_llama import logger
__all__ = [
"load_correct_tokenizer",
"fix_sentencepiece_tokenizer",
"check_tokenizer",
]
IGNORED_TOKENIZER_CHECKING = frozenset((
"CodeLlamaTokenizerFast",
"CodeLlamaTokenizer",
))
def try_fix_tokenizer(tokenizer, prepend = True):
if hasattr(tokenizer, "_tokenizer"):
converted_tokenizer = tokenizer._tokenizer
else:
converted_tokenizer = convert_slow_tokenizer(tokenizer)
pass
tokenizer_string = converted_tokenizer.to_str()
# Llama does ▁apple. Sometimes this is wrong!!
prepend_text = '{"type":"Prepend","prepend":""},'
if not prepend and prepend_text in tokenizer_string:
tokenizer_string = tokenizer_string.replace(prepend_text, "", 1)
pass
dir_names = dir(tokenizer)
# Get eos_token, bos_token etc
token_names = [x for x in dir_names if x.endswith("_token") and x.count("_") == 1]
for token_name in token_names:
token = getattr(tokenizer, token_name, None)
if token is None: continue
token_id = getattr(tokenizer, token_name + "_id", None)
# Locate the token's id mapping in the string
find_text = f'"id":{token_id},"content":"'
start = tokenizer_string.find(find_text)
# Skip if this id never appears in the serialized tokenizer
if start == -1: continue
start += len(find_text)
end = tokenizer_string.find('",', start)
bad_token = tokenizer_string[start : end]
# Check if token is the actual same one - if not, edit it
if bad_token != token:
bad_text = f'{find_text}{bad_token}",'
good_text = f'{find_text}{token}",'
tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)
# And replace vocab section
bad_text = f'"{bad_token}":{token_id},'
good_text = f'"{token}":{token_id},'
tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)
pass
pass
fixed_tokenizer = converted_tokenizer.from_str(tokenizer_string)
return fixed_tokenizer
pass
def get_sorted_dict(dictionary):
sorted_keys = sorted(dictionary.values())
inverted_dictionary = { value : key for key, value in dictionary.items() }
sorted_dictionary = {}
for key in sorted_keys:
value = inverted_dictionary[key]
sorted_dictionary[value] = key
return sorted_dictionary
pass
def convert_to_fast_tokenizer(
slow_tokenizer,
temporary_location = "_unsloth_sentencepiece_temp",
):
is_fast = getattr(slow_tokenizer, "is_fast", False)
if is_fast: return slow_tokenizer
try:
tokenizer_name = slow_tokenizer.__class__.__name__
lowered_tokenizer_name = tokenizer_name.lower()
if lowered_tokenizer_name.endswith("tokenizer"):
class_name = lowered_tokenizer_name[:-len("tokenizer")]
FastTokenizer = eval(
f'__import__(f"transformers.models.{class_name}").{tokenizer_name}Fast'
)
else:
FastTokenizer = PreTrainedTokenizerFast
except:
FastTokenizer = PreTrainedTokenizerFast
pass
# Get all arguments (bos_token, etc)
docs = FastTokenizer.__doc__
docs = docs[docs.find("Args:"):]
args = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
args = [x for x in args if not x.endswith("_file")]
# Also some missing maybe!
docs = PreTrainedTokenizerFast.__doc__
docs = docs[docs.find("Args:"):]
args2 = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
args2 = [x for x in args2 if not x.endswith("_file")]
args = list(set(args + args2))
kwargs = {}
for arg in args: kwargs[arg] = getattr(slow_tokenizer, arg, None)
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = True)
fast_tokenizer = FastTokenizer( **kwargs )
# Check if they're similar!
sorted_slow_tokenizer = get_sorted_dict(slow_tokenizer.get_vocab())
sorted_fast_tokenizer = get_sorted_dict(fast_tokenizer.get_vocab())
check_vocab = (sorted_slow_tokenizer == sorted_fast_tokenizer)
check_special = (slow_tokenizer.all_special_tokens == fast_tokenizer.all_special_tokens)
# Failure so return slow_tokenizer
if not check_vocab or not check_special: return slow_tokenizer
# Now confirm if they match
if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
# Maybe remove prepending of ▁apple?
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = False)
fast_tokenizer = FastTokenizer( **kwargs )
if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
# Failure :(
return slow_tokenizer
pass
pass
# Also tokenizer.model is missing!
name = slow_tokenizer.name_or_path.replace("/", "_")
if not os.path.exists(temporary_location):
os.makedirs(temporary_location)
pass
new_location = f"{temporary_location}/{name}"
slow_tokenizer.save_pretrained(new_location)
fast_tokenizer.save_pretrained(new_location)
# Now load it!
fast_tokenizer = AutoTokenizer.from_pretrained(new_location)
if assert_same_tokenization(slow_tokenizer, fast_tokenizer):
return fast_tokenizer
return slow_tokenizer
pass
def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
# Get eos_token, bos_token etc
dir_names = dir(slow_tokenizer)
special_tokens = list(filter(None, (
getattr(slow_tokenizer, x) for x in dir_names
if x.endswith("_token") and x.count("_") == 1
)))
all_special_tokens = list(set(special_tokens + slow_tokenizer.all_special_tokens))
string = "\n".join(all_special_tokens) + \
"A quick brown fox jumps over the lazy dog!!\n\n" + \
"".join(all_special_tokens)
return slow_tokenizer(string).input_ids == fast_tokenizer(string).input_ids
pass
global sentencepiece_model_pb2
sentencepiece_model_pb2 = None
def fix_sentencepiece_tokenizer(
old_tokenizer,
new_tokenizer,
token_mapping,
temporary_location = "_unsloth_sentencepiece_temp",
):
# From https://github.com/google/sentencepiece/issues/121
# We need to manually edit the sentencepiece tokenizer!
global sentencepiece_model_pb2
if sentencepiece_model_pb2 is None:
try:
import sentencepiece.sentencepiece_model_pb2 as _sentencepiece_model_pb2
sentencepiece_model_pb2 = _sentencepiece_model_pb2
except:
if not os.path.exists(temporary_location):
os.system(f"git clone https://github.com/google/sentencepiece.git {temporary_location}")
os.system(f"cd {temporary_location}/src && protoc --python_out=. sentencepiece_model.proto")
pass
import sentencepiece.sentencepiece_model_pb2 as _sentencepiece_model_pb2
sentencepiece_model_pb2 = _sentencepiece_model_pb2
pass
if not os.path.exists(temporary_location):
os.makedirs(temporary_location)
pass
# First save the old tokenizer
old_tokenizer.save_pretrained(temporary_location)
from sentencepiece import SentencePieceProcessor
tokenizer_file = sentencepiece_model_pb2.ModelProto()
tokenizer_file.ParseFromString(open(f"{temporary_location}/tokenizer.model", "rb").read())
# Now save the new tokenizer
new_tokenizer.save_pretrained(temporary_location)
# Now correct the old tokenizer's .model file
for old_token, new_token in token_mapping.items():
ids = old_tokenizer([old_token], add_special_tokens = False).input_ids
ids = ids[0]
if (len(ids) != 1):
# Skip this token!
print(f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!")
continue
pass
ids = ids[0]
tokenizer_piece = tokenizer_file.pieces[ids]
assert(tokenizer_piece.piece == old_token)
tokenizer_piece.piece = new_token
pass
# And now write it
with open(f"{temporary_location}/tokenizer.model", "wb") as file:
file.write(tokenizer_file.SerializeToString())
pass
# And load it!
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(temporary_location, eos_token = new_tokenizer.eos_token)
return tokenizer
pass
def load_correct_tokenizer(
tokenizer_name,
model_max_length = None,
padding_side = "right",
token = None,
trust_remote_code = False,
):
slow_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
trust_remote_code = trust_remote_code,
use_fast = False,
)
fast_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
trust_remote_code = trust_remote_code,
)
fast_tokenizer.add_bos_token = slow_tokenizer.add_bos_token
fast_tokenizer.add_eos_token = slow_tokenizer.add_eos_token
# Confirm if slow and fast are equivalent!
if assert_same_tokenization(slow_tokenizer, fast_tokenizer):
return fast_tokenizer
else:
return convert_to_fast_tokenizer(slow_tokenizer)
pass
pass
def check_tokenizer(
model,
tokenizer,
model_name = "unsloth/llama-2-7b-bnb-4bit",
model_max_length = 4096,
padding_side = "right",
token = None,
_reload = True,
):
# Checks tokenizer for out of bounds ids.
# Mainly a fix for https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha
# where <sep> had token id=32002.
# See https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha/discussions/25
# Seems like the Fast tokenizer in Rust breaks things!
# We ignore some of them!
if tokenizer.__repr__().split("(", 1)[0] in IGNORED_TOKENIZER_CHECKING:
return tokenizer
pass
max_embedding_size = model.model.embed_tokens.weight.shape[0]
added_tokens_fast = tokenizer.added_tokens_decoder
added_tokens_fast = {index : str(value) for index, value in added_tokens_fast.items()}
sorted_keys = sorted(added_tokens_fast)
added_tokens_fast = {key : added_tokens_fast[key] for key in sorted_keys}
for j, index in enumerate(added_tokens_fast.keys()):
if index >= max_embedding_size:
bad_indices = list(added_tokens_fast.keys ())[j:]
bad_tokens = list(added_tokens_fast.values())[j:]
if not _reload:
# Try removing the token
added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
special_tokens = tokenizer.special_tokens_map
import itertools
special_tokens = frozenset(
itertools.chain.from_iterable(
[x] if type(x) is str else x for x in special_tokens.values()
)
)
can_be_removed1 = [x for x in bad_tokens if x not in special_tokens]
can_be_removed2 = [x for x in can_be_removed1 if x in tokenizer._added_tokens_encoder.keys()]
# Check if extra tokens can in fact be removed!
can_be_removed = \
(len(can_be_removed1) == len(bad_tokens)) and \
(len(can_be_removed2) == len(bad_tokens))
# Check if sep_token or other generic types
remove_generic = False
try_mapper = []
if not can_be_removed:
names = dir(tokenizer)
names = (x for x in names if x.endswith("_token") and x.count("_") == 1)
generic_tokens = [(x, getattr(tokenizer, x, None)) for x in names]
try_removal = []
# Use bad_token here so the `token` (auth) argument is not shadowed
for bad_token in bad_tokens:
for (name_token, check_token) in generic_tokens:
if check_token == bad_token:
try_removal.append(bad_token)
try_mapper.append(name_token)
pass
pass
pass
# Recheck!
can_be_removed = (len(try_removal) == len(bad_tokens))
if can_be_removed: remove_generic = True
can_be_removed1 = bad_tokens
pass
if can_be_removed:
# Yes it can be fixed!
for j, bad_token in enumerate(can_be_removed1):
remove_id = tokenizer._added_tokens_encoder[bad_token]
del tokenizer._added_tokens_decoder[remove_id]
del tokenizer._added_tokens_encoder[bad_token]
if remove_generic and (try_removal[j] == bad_token):
# Remove sep token for example
setattr(tokenizer, try_mapper[j], None)
setattr(tokenizer, try_mapper[j] + "_id", None)
pass
pass
# Confirm 1 more time!
if max(tokenizer.added_tokens_decoder.keys()) < max_embedding_size:
logger.warning_once(
f"Unsloth loaded a broken tokenizer `{model_name}`, but managed to repair it!\n"\
f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
"We removed these bad tokens. If you think this is incorrect, fix your tokenizer first."
)
return convert_to_fast_tokenizer(tokenizer)
pass
pass
# :( Failure
raise RuntimeError(
f"Unsloth tried to load `{model_name}`, but cannot succeed.\n"\
f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
f"Fix your tokenizer since it'll perform out of bounds memory accesses."
)
pass
# Try slow tokenizer which can fix things!
tokenizer = AutoTokenizer.from_pretrained(
model_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
use_fast = False,
)
return check_tokenizer(
model = model,
tokenizer = tokenizer,
model_name = model_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
_reload = False,
)
break
pass
pass
return convert_to_fast_tokenizer(tokenizer)
pass
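
End to end, the auto-healing flow is: load_correct_tokenizer accepts the fast tokenizer only when assert_same_tokenization confirms it round-trips identically to the slow one, and check_tokenizer removes or unsets out-of-bounds added tokens before converting back to a fast tokenizer. A hedged sketch of combining the two (model is a placeholder for an already-loaded causal LM):

tokenizer = load_correct_tokenizer("unsloth/llama-2-7b-bnb-4bit")
tokenizer = check_tokenizer(
    model = model,          # placeholder: needed to read embed_tokens' vocab size
    tokenizer = tokenizer,
    model_name = "unsloth/llama-2-7b-bnb-4bit",
)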