Mirror of https://github.com/unslothai/unsloth.git (synced 2026-04-28 03:19:57 +00:00)
Auto Healing Tokenizer (#283)
* Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * llama * Update llama.py * gemma * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update save.py * RoPE * Update llama.py * Update llama.py * Update llama.py * Update gemma.py * correct_dtype * Update gemma.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Chat Templates * Update README.md * Update README.md * Update llama.py * DoRA * Update _utils.py * Update chat_templates.py * Update llama.py * Hotfix - fix DoRA, Gemma prompt template (#202) (#203) * Update save.py * saving * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update __init__.py * Update save.py * Update save.py * Update save.py * save * trainer * spaces * original * Gemma * Update pyproject.toml * Update mapper.py * Update fast_lora.py * FastGemmaModel * model_type * Update llama.py * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update llama.py * Update fast_lora.py * Update llama.py * Update llama.py * Update cross_entropy_loss.py * Update llama.py * Update llama.py * gemma * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update fast_lora.py * Update fast_lora.py * Fast CE Loss * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * CE * Update llama.py * Update llama.py * Update cross_entropy_loss.py * Update geglu.py * Update cross_entropy_loss.py * revert * Update llama.py * Update llama.py * norm * Update gemma.py * Update gemma.py * position_ids * Update gemma.py * Update gemma.py * pos * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update cross_entropy_loss.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * revert * 
revert * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update cross_entropy_loss.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * rope * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * llama * Update llama.py * gemma * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update save.py * RoPE * Update llama.py * Update llama.py * Update llama.py * Update gemma.py * correct_dtype * Update gemma.py * Update cross_entropy_loss.py * Update cross_entropy_loss.py * Chat Templates * Update README.md * Update README.md * Update llama.py * DoRA * Update _utils.py * Update chat_templates.py * Update pyproject.toml * Small fixes * Update pyproject.toml * Approx gelu * Update geglu.py * Approx gelu * Update llama.py * Update __init__.py * Update __init__.py * Update _utils.py * Update geglu.py * Update gemma.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Update gemma.py * Fix Gemma merging * Update rms_layernorm.py * Update gemma.py * Update pyproject.toml * Layernorms * Gemma precision * Update gemma.py * sqrt * Update gemma.py * Update save.py * RoPE and Gemma precision * Update rms_layernorm.py * Fix warning * Update chat_templates.py * Update chat_templates.py * Update save.py * Update save.py * Update save.py * Update chat_templates.py 
* Update llama.py * model_name * Update loader.py * Tokenizer overwritten * Update llama.py * Update llama.py * Update llama.py * Update save.py * Accuracy * Revert * Update save.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update fast_lora.py * Update chat_templates.py * Update save.py * Update save.py * Update llama.py * Update llama.py * Account for DoRA * Update llama.py * Update save.py * GGUF incorrect * Update save.py * Update pyproject.toml * kaggle new * Update pyproject.toml * Update pyproject.toml * upcasting * Fix Colab * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update rope_embedding.py * Update rope_embedding.py * Fix bugs * Update fast_lora.py * Update fast_lora.py * Update README.md * Update README.md * GGUF * Update save.py * Update save.py * Update save.py * Update save.py * Update README.md * Update README.md * Bugs * Update fast_lora.py * Update pyproject.toml * Update fast_lora.py * Update __init__.py * Update fast_lora.py * dtype * Update llama.py * Update llama.py * Update llama.py * dtype * Update mistral.py * trust_remote_code * lm_head * Update llama.py * save_pretrained_settings * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * state_dict * Update save.py * whoami * Update llama.py * Update save.py * Update llama.py * Patch tokenizer * Update chat_templates.py * Heal tokenizers * Update chat_templates.py * Update mapper.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update chat_templates.py * tokenizer patching * patch_tokenizer * Update chat_templates.py * Update tokenizer_utils.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update tokenizer_utils.py * Edit
Parent bb45fdabb6, commit 74f79684da
9 changed files with 464 additions and 345 deletions
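The change set below replaces direct AutoTokenizer loading with an auto-healing path. A minimal sketch of the intended call, using the load_correct_tokenizer helper added in unsloth/tokenizer_utils.py further down (the model name and lengths are only examples taken from defaults that appear in the diff):

# Illustrative sketch only; mirrors the flow added in unsloth/tokenizer_utils.py below.
from unsloth.tokenizer_utils import load_correct_tokenizer

# Loads both the slow and fast tokenizers, checks that they tokenize special tokens
# and plain text identically, and rebuilds or repairs the fast tokenizer if they disagree.
tokenizer = load_correct_tokenizer(
    "unsloth/llama-2-7b-bnb-4bit",   # example repo name; any HF tokenizer repo works the same way
    model_max_length = 4096,
    padding_side = "right",
)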
.gitignore (vendored): 160 deletions

@@ -1,160 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
LICENSE: 2 changes

@@ -186,7 +186,7 @@
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]
   Copyright [2024-] [Unsloth AI, Daniel Han-Chen & Michael Han-Chen]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
unsloth/__init__.py

@@ -113,3 +113,4 @@ pass
from .models import *
from .save import *
from .chat_templates import *
from .tokenizer_utils import *
unsloth/chat_templates.py

@@ -15,15 +15,19 @@
__all__ = [
    "get_chat_template",
    "test_chat_templates",
    "fix_sentencepiece_tokenizer",
]

from transformers import StoppingCriteria, StoppingCriteriaList
from torch import LongTensor, FloatTensor
from transformers.models.llama.modeling_llama import logger
from .models._utils import patch_tokenizer
from .save import patch_saving_functions
import os
import shutil
from .tokenizer_utils import (
    load_correct_tokenizer,
    fix_sentencepiece_tokenizer,
)
from .models._utils import patch_tokenizer

CHAT_TEMPLATES = {}

@@ -251,84 +255,23 @@ gemma_chatml_eos_token = (
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token,)


def fix_sentencepiece_tokenizer(
    old_tokenizer,
    new_tokenizer,
    token_mapping,
    temporary_location = "_unsloth_sentencepiece_temp",
):
    # From https://github.com/google/sentencepiece/issues/121
    # We need to manually edit the sentencepiece tokenizer!
    try:
        import sentencepiece.sentencepiece_model_pb2 as sentencepiece_model_pb2
    except:
        if not os.path.exists(temporary_location):
            os.system("git clone https://github.com/google/sentencepiece.git unsloth_sentencepiece_temp")
            os.system(f"cd {temporary_location}/src && protoc --python_out=. sentencepiece_model.proto")
            shutil.rmtree(temporary_location)
        pass
        import sentencepiece.sentencepiece_model_pb2 as sentencepiece_model_pb2
    pass

    if not os.path.exists(temporary_location):
        os.makedirs(temporary_location)
    pass

    # First save the old tokenizer
    old_tokenizer.save_pretrained(temporary_location)

    from sentencepiece import SentencePieceProcessor
    tokenizer_file = sentencepiece_model_pb2.ModelProto()
    tokenizer_file.ParseFromString(open(f"{temporary_location}/tokenizer.model", "rb").read())

    # Now save the new tokenizer
    new_tokenizer.save_pretrained(temporary_location)

    # Now correct the old tokenizer's .model file
    for old_token, new_token in token_mapping.items():
        ids = old_tokenizer([old_token], add_special_tokens = False).input_ids
        ids = ids[0]
        if (len(ids) != 1):
            # Skip this token!
            print(f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!")
            continue
        pass
        ids = ids[0]
        tokenizer_piece = tokenizer_file.pieces[ids]
        assert(tokenizer_piece.piece == old_token)
        tokenizer_piece.piece = new_token
    pass

    # And now write it
    with open(f"{temporary_location}/tokenizer.model", "wb") as file:
        file.write(tokenizer_file.SerializeToString())
    pass

    # And load it!
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(temporary_location, eos_token = new_tokenizer.eos_token)
    return tokenizer
pass


def get_chat_template(
    tokenizer,
    chat_template = "chatml",
    mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
    map_eos_token = True,
):
    assert(type(map_eos_token) is bool)
    old_tokenizer = tokenizer

    if map_eos_token is False:
        assert("Unsloth: Can only map new tokens to EOS for now. Adding new tokens is not yet supported.")
    pass

    IS_GEMMA = False
    if tokenizer.__class__.__name__.startswith("Gemma"):
        if chat_template == "chatml": chat_template = "gemma_chatml"
        IS_GEMMA = True
    pass

    # We first check if the tokenizer is a fast one. If not, we cannot convert this!
    is_fast_tokenizer = getattr(tokenizer, "is_fast", False)
    old_padding_side = tokenizer.padding_side

    if type(chat_template) in (list, tuple,):

@@ -348,9 +291,17 @@ def get_chat_template(
    assert(type(stop_word) is str)

    # token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
    # For Gemma :)
    if token_mapping is not None:
    # Check fast tokenizer
    if not is_fast_tokenizer:
        logger.warning_once(
            f"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
            "Please log a Github issue if you want this as a new feature!\n"\
            "Your chat template will still work, but it won't add or edit tokens."
        )

    elif token_mapping is not None:
        # token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
        # For Gemma :)

        string_vocab = tokenizer._tokenizer.to_str()

@@ -368,7 +319,7 @@ def get_chat_template(
        pass
    pass

    if not stop_word in token_mapping.values():
    if map_eos_token and (not stop_word in token_mapping.values()):
        # Do not map 107 = <|im_end|> and 1 = <|im_end|>. This will reduce the vocab size by 1
        logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
        string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)

@@ -376,14 +327,19 @@ def get_chat_template(
    if skipped != len(token_mapping):
        new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
        new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)

        if map_eos_token:
            new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)
        else:
            new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer)
        pass

        # Must fix the sentence piece tokenizer since there's no tokenizer.model file!
        tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
    else:
        pass

    elif stop_word != "eos_token":
    elif map_eos_token and (stop_word != "eos_token"):
        logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")

        # Replaces the old EOS token with a new one.

@@ -393,9 +349,14 @@ def get_chat_template(
        # This is a HACK!
        # Idea from https://huggingface.co/cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser
        string_vocab = tokenizer._tokenizer.to_str()
        string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
        old_eos_token = tokenizer.eos_token
        string_vocab = string_vocab.replace(old_eos_token, stop_word)
        new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
        tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)
        new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)

        # Must fix the sentence piece tokenizer since there's no tokenizer.model file!
        token_mapping = { old_eos_token : stop_word, }
        tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
    pass

    else:

@@ -433,7 +394,10 @@ def get_chat_template(
    if old_bos_token != new_bos_token: tokenizer.bos_token = old_bos_token
    if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token

    #stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
    # stopping_criteria = create_stopping_criteria(tokenizer, stop_word)

    # Patch saving functions
    tokenizer = patch_saving_functions(tokenizer)

    return tokenizer#, stopping_criteria
pass
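A hedged usage sketch of the reworked get_chat_template above; the model name is only an example, and the arguments shown are the defaults that appear in the diff:

from transformers import AutoTokenizer
from unsloth.chat_templates import get_chat_template

tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-bnb-4bit")  # example repo only
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "chatml",   # swapped to "gemma_chatml" automatically for Gemma tokenizers
    map_eos_token = True,       # map the template's stop token onto EOS instead of adding a new token
)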
unsloth/models/_utils.py

@@ -60,22 +60,16 @@ from xformers import __version__ as xformers_version
__all__ = [
    "prepare_model_for_kbit_training",
    "patch_tokenizer",
    "check_tokenizer",
    "xformers",
    "xformers_attention",
    "xformers_version",
    "__version__",
    "HAS_FLASH_ATTENTION",
    "platform_system",
    "patch_tokenizer",
]


IGNORED_TOKENIZER_CHECKING = frozenset((
    "CodeLlamaTokenizerFast",
    "CodeLlamaTokenizer",
))

def prepare_model_for_kbit_training(
    model : Any,
    use_gradient_checkpointing : bool = True,

@@ -144,103 +138,6 @@ def patch_tokenizer(model, tokenizer):
pass


def check_tokenizer(
    model,
    tokenizer,
    model_name = "unsloth/llama-2-7b-bnb-4bit",
    model_max_length = 4096,
    padding_side = "right",
    token = None,
    _reload = True,
):
    # Checks tokenizer for out of bounds ids.
    # Mainly a fix for https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha
    # where <sep> had token id=32002.
    # See https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha/discussions/25
    # Seems like the Fast tokenizer in Rust breaks things!

    # We ignore some of them!
    if tokenizer.__repr__().split("(", 1)[0] in IGNORED_TOKENIZER_CHECKING:
        return tokenizer
    pass

    max_embedding_size = model.model.embed_tokens.weight.shape[0]
    added_tokens_fast = tokenizer.added_tokens_decoder
    added_tokens_fast = {index : str(value) for index, value in added_tokens_fast.items()}
    sorted_keys = sorted(added_tokens_fast)
    added_tokens_fast = {key : added_tokens_fast[key] for key in sorted_keys}

    for j, index in enumerate(added_tokens_fast.keys()):
        if index >= max_embedding_size:
            bad_indices = list(added_tokens_fast.keys ())[j:]
            bad_tokens = list(added_tokens_fast.values())[j:]

            if not _reload:
                # Try removing the token
                added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
                special_tokens = tokenizer.special_tokens_map
                import itertools
                special_tokens = frozenset(
                    itertools.chain.from_iterable(
                        [x] if type(x) is str else x for x in special_tokens.values()
                    )
                )
                can_be_removed1 = [x for x in bad_tokens if x not in special_tokens]
                can_be_removed2 = [x for x in can_be_removed1 if x in tokenizer._added_tokens_encoder.keys()]

                # Check of extra tokens can in fact we removed!

                if (len(can_be_removed1) == len(bad_tokens)) and \
                    (len(can_be_removed2) == len(bad_tokens)):
                    # Yes it can be fixed!
                    for bad_token in can_be_removed1:
                        remove_id = tokenizer._added_tokens_encoder[bad_token]
                        del tokenizer._added_tokens_decoder[remove_id]
                        del tokenizer._added_tokens_encoder[bad_token]
                    pass
                    # Confirm 1 more time!
                    if max(tokenizer.added_tokens_decoder.keys()) < max_embedding_size:
                        logger.warning_once(
                            f"Unsloth loaded a broken tokenizer `{model_name}`, but managed to repair it!\n"\
                            f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
                            "We removed these bad tokens. If you think this is incorrect, fix your tokenizer first."
                        )
                        return tokenizer
                    pass
                pass

                # :( Failure
                raise RuntimeError(
                    f"Unsloth tried to load `{model_name}`, but cannot succeed.\n"\
                    f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
                    f"Fix your tokenizer since it'll perform out of bounds memory accesses."
                )
            pass

            # Try slow tokenizer which can fix things!
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                model_max_length = model_max_length,
                padding_side = padding_side,
                token = token,
                use_fast = False,
            )
            return check_tokenizer(
                model = model,
                tokenizer = tokenizer,
                model_name = model_name,
                model_max_length = model_max_length,
                padding_side = padding_side,
                token = token,
                _reload = False,
            )
            break
        pass
    pass
    return tokenizer
pass


# Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??
# For mixed precision, we need it to be in float32 not float16.
from peft.tuners.lora.layer import LoraLayer
unsloth/models/llama.py

@@ -26,6 +26,7 @@ from transformers.modeling_attn_mask_utils import (
from ..kernels import *
from ._utils import *
from ._utils import __version__
from ..tokenizer_utils import *
if HAS_FLASH_ATTENTION:
    from flash_attn import flash_attn_func

@@ -1014,8 +1015,8 @@ class FastLlamaModel:
        # Counteract saved tokenizers
        tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
        tokenizer = load_correct_tokenizer(
            tokenizer_name = tokenizer_name,
            model_max_length = max_position_embeddings,
            padding_side = "right",
            token = token,
unsloth/models/mistral.py

@@ -362,7 +362,7 @@ class FastMistralModel(FastLlamaModel):
        # Counteract saved tokenizers
        tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
        tokenizer = AutoTokenizer.from_pretrained(
        tokenizer = load_correct_tokenizer(
            tokenizer_name,
            model_max_length = max_position_embeddings,
            padding_side = "right",
unsloth/save.py

@@ -276,7 +276,8 @@ def unsloth_save_model(
        old_username = None, private = private,
    )

    model.original_push_to_hub(
    getattr(model, "original_push_to_hub", tokenizer.push_to_hub)\
    (
        repo_id = save_directory,
        use_temp_dir = use_temp_dir,
        commit_message = commit_message,

@@ -290,7 +291,8 @@ def unsloth_save_model(
        tags = tags,
    )
    if tokenizer is not None:
        tokenizer.original_push_to_hub(
        getattr(tokenizer, "original_push_to_hub", tokenizer.push_to_hub)\
        (
            repo_id = save_directory,
            use_temp_dir = use_temp_dir,
            commit_message = commit_message,
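The push_to_hub changes above rely on a small fallback pattern; a sketch of the idea only, with tokenizer, save_directory, use_temp_dir and commit_message assumed to be in scope as they are inside unsloth_save_model:

# Prefer the saved original method if patching installed one, otherwise fall back
# to the tokenizer's regular push_to_hub.
push = getattr(tokenizer, "original_push_to_hub", tokenizer.push_to_hub)
push(repo_id = save_directory, use_temp_dir = use_temp_dir, commit_message = commit_message)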
unsloth/tokenizer_utils.py (new file): 414 additions

@@ -0,0 +1,414 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers import PreTrainedTokenizerFast
import re
import os
from transformers.models.llama.modeling_llama import logger

__all__ = [
    "load_correct_tokenizer",
    "fix_sentencepiece_tokenizer",
    "check_tokenizer",
]


IGNORED_TOKENIZER_CHECKING = frozenset((
    "CodeLlamaTokenizerFast",
    "CodeLlamaTokenizer",
))


def try_fix_tokenizer(tokenizer, prepend = True):

    if hasattr(tokenizer, "_tokenizer"):
        converted_tokenizer = tokenizer._tokenizer
    else:
        converted_tokenizer = convert_slow_tokenizer(tokenizer)
    pass

    tokenizer_string = converted_tokenizer.to_str()

    # Llama does ▁apple. Sometimes this is wrong!!
    prepend_text = '{"type":"Prepend","prepend":"▁"},'
    if not prepend and prepend_text in tokenizer_string:
        tokenizer_string = tokenizer_string.replace(prepend_text, "", 1)
    pass

    dir_names = dir(tokenizer)
    # Get eos_token, bos_token etc
    token_names = [x for x in dir_names if x.endswith("_token") and x.count("_") == 1]

    for token_name in token_names:
        token = getattr(tokenizer, token_name, None)
        if token is None: continue
        token_id = getattr(tokenizer, token_name + "_id", None)

        # Locate the token's id mapping in the string
        find_text = f'"id":{token_id},"content":"'
        start = tokenizer_string.find(find_text) + len(find_text)
        if start == -1: continue
        end = tokenizer_string.find('",', start)

        bad_token = tokenizer_string[start : end]
        # Check if token is the actual same one - if not, edit it
        if bad_token != token:
            bad_text = f'{find_text}{bad_token}",'
            good_text = f'{find_text}{token}",'
            tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)

            # And replace vocab section
            bad_text = f'"{bad_token}":{token_id},'
            good_text = f'"{token}":{token_id},'
            tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)
        pass
    pass

    fixed_tokenizer = converted_tokenizer.from_str(tokenizer_string)
    return fixed_tokenizer
pass


def get_sorted_dict(dictionary):
    sorted_keys = sorted(dictionary.values())
    inverted_dictionary = { value : key for key, value in dictionary.items() }

    sorted_dictionary = {}
    for key in sorted_keys:
        value = inverted_dictionary[key]
        sorted_dictionary[value] = key
    return sorted_dictionary
pass


def convert_to_fast_tokenizer(
    slow_tokenizer,
    temporary_location = "_unsloth_sentencepiece_temp",
):
    is_fast = getattr(slow_tokenizer, "is_fast", False)
    if is_fast: return slow_tokenizer

    try:
        tokenizer_name = slow_tokenizer.__class__.__name__
        lowered_tokenizer_name = tokenizer_name.lower()
        if lowered_tokenizer_name.endswith("tokenizer"):
            class_name = lowered_tokenizer_name[:-len("tokenizer")]
            FastTokenizer = eval(
                f'__import__(f"transformers.models.{class_name}").{tokenizer_name}Fast'
            )
        else:
            FastTokenizer = PreTrainedTokenizerFast
    except:
        FastTokenizer = PreTrainedTokenizerFast
    pass

    # Get all arguments (bos_token, etc)
    docs = FastTokenizer.__doc__
    docs = docs[docs.find("Args:"):]
    args = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
    args = [x for x in args if not x.endswith("_file")]

    # Also some missing maybe!
    docs = PreTrainedTokenizerFast.__doc__
    docs = docs[docs.find("Args:"):]
    args2 = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
    args2 = [x for x in args2 if not x.endswith("_file")]
    args = list(set(args + args2))

    kwargs = {}
    for arg in args: kwargs[arg] = getattr(slow_tokenizer, arg, None)
    kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = True)
    fast_tokenizer = FastTokenizer( **kwargs )

    # Check if they're similar!
    sorted_slow_tokenizer = get_sorted_dict(slow_tokenizer.get_vocab())
    sorted_fast_tokenizer = get_sorted_dict(fast_tokenizer.get_vocab())

    check_vocab = (sorted_slow_tokenizer == sorted_fast_tokenizer)
    check_special = (slow_tokenizer.all_special_tokens == fast_tokenizer.all_special_tokens)

    # Failure so return slow_tokenizer
    if not check_vocab or not check_special: return slow_tokenizer

    # Now confirm if they match
    if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
        # Maybe remove prepending of __apple?
        kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = False)
        fast_tokenizer = FastTokenizer( **kwargs )
        if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
            # Failure :(
            return slow_tokenizer
        pass
    pass

    # Also tokenizer.model is missing!
    name = slow_tokenizer.name_or_path.replace("/", "_")
    if not os.path.exists(temporary_location):
        os.makedirs(temporary_location)
    pass
    new_location = f"{temporary_location}/{name}"
    slow_tokenizer.save_pretrained(new_location)
    fast_tokenizer.save_pretrained(new_location)

    # Now load it!
    fast_tokenizer = AutoTokenizer.from_pretrained(new_location)
    if assert_same_tokenization(slow_tokenizer, fast_tokenizer):
        return fast_tokenizer
    return slow_tokenizer
pass


def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
    # Get eos_token, bos_token etc
    dir_names = dir(slow_tokenizer)
    special_tokens = list(filter(None, (
        getattr(slow_tokenizer, x) for x in dir_names
        if x.endswith("_token") and x.count("_") == 1
    )))
    all_special_tokens = list(set(special_tokens + slow_tokenizer.all_special_tokens))
    string = "\n".join(all_special_tokens) + \
        "A quick brown fox jumps over the lazy dog!!\n\n" + \
        "".join(all_special_tokens)
    return slow_tokenizer(string).input_ids == fast_tokenizer(string).input_ids
pass


global sentencepiece_model_pb2
sentencepiece_model_pb2 = None

def fix_sentencepiece_tokenizer(
    old_tokenizer,
    new_tokenizer,
    token_mapping,
    temporary_location = "_unsloth_sentencepiece_temp",
):
    # From https://github.com/google/sentencepiece/issues/121
    # We need to manually edit the sentencepiece tokenizer!
    global sentencepiece_model_pb2
    if sentencepiece_model_pb2 is None:
        try:
            import sentencepiece.sentencepiece_model_pb2 as _sentencepiece_model_pb2
            sentencepiece_model_pb2 = _sentencepiece_model_pb2
        except:
            if not os.path.exists(temporary_location):
                os.system(f"git clone https://github.com/google/sentencepiece.git {temporary_location}")
                os.system(f"cd {temporary_location}/src && protoc --python_out=. sentencepiece_model.proto")
            pass
            import sentencepiece.sentencepiece_model_pb2 as _sentencepiece_model_pb2
            sentencepiece_model_pb2 = _sentencepiece_model_pb2
        pass

    if not os.path.exists(temporary_location):
        os.makedirs(temporary_location)
    pass

    # First save the old tokenizer
    old_tokenizer.save_pretrained(temporary_location)

    from sentencepiece import SentencePieceProcessor
    tokenizer_file = sentencepiece_model_pb2.ModelProto()
    tokenizer_file.ParseFromString(open(f"{temporary_location}/tokenizer.model", "rb").read())

    # Now save the new tokenizer
    new_tokenizer.save_pretrained(temporary_location)

    # Now correct the old tokenizer's .model file
    for old_token, new_token in token_mapping.items():
        ids = old_tokenizer([old_token], add_special_tokens = False).input_ids
        ids = ids[0]
        if (len(ids) != 1):
            # Skip this token!
            print(f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!")
            continue
        pass
        ids = ids[0]
        tokenizer_piece = tokenizer_file.pieces[ids]
        assert(tokenizer_piece.piece == old_token)
        tokenizer_piece.piece = new_token
    pass

    # And now write it
    with open(f"{temporary_location}/tokenizer.model", "wb") as file:
        file.write(tokenizer_file.SerializeToString())
    pass

    # And load it!
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(temporary_location, eos_token = new_tokenizer.eos_token)
    return tokenizer
pass


def load_correct_tokenizer(
    tokenizer_name,
    model_max_length = None,
    padding_side = "right",
    token = None,
    trust_remote_code = False,
):
    slow_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name,
        model_max_length = model_max_length,
        padding_side = padding_side,
        token = token,
        trust_remote_code = trust_remote_code,
        use_fast = False,
    )
    fast_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name,
        model_max_length = model_max_length,
        padding_side = padding_side,
        token = token,
        trust_remote_code = trust_remote_code,
    )
    fast_tokenizer.add_bos_token = slow_tokenizer.add_bos_token
    fast_tokenizer.add_eos_token = slow_tokenizer.add_eos_token

    # Confirm if slow and fast are equivalent!
    if assert_same_tokenization(slow_tokenizer, fast_tokenizer):
        return fast_tokenizer
    else:
        return convert_to_fast_tokenizer(slow_tokenizer)
    pass
pass


def check_tokenizer(
    model,
    tokenizer,
    model_name = "unsloth/llama-2-7b-bnb-4bit",
    model_max_length = 4096,
    padding_side = "right",
    token = None,
    _reload = True,
):
    # Checks tokenizer for out of bounds ids.
    # Mainly a fix for https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha
    # where <sep> had token id=32002.
    # See https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha/discussions/25
    # Seems like the Fast tokenizer in Rust breaks things!

    # We ignore some of them!
    if tokenizer.__repr__().split("(", 1)[0] in IGNORED_TOKENIZER_CHECKING:
        return tokenizer
    pass

    max_embedding_size = model.model.embed_tokens.weight.shape[0]
    added_tokens_fast = tokenizer.added_tokens_decoder
    added_tokens_fast = {index : str(value) for index, value in added_tokens_fast.items()}
    sorted_keys = sorted(added_tokens_fast)
    added_tokens_fast = {key : added_tokens_fast[key] for key in sorted_keys}

    for j, index in enumerate(added_tokens_fast.keys()):
        if index >= max_embedding_size:
            bad_indices = list(added_tokens_fast.keys ())[j:]
            bad_tokens = list(added_tokens_fast.values())[j:]
            if not _reload:
                # Try removing the token
                added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
                special_tokens = tokenizer.special_tokens_map
                import itertools
                special_tokens = frozenset(
                    itertools.chain.from_iterable(
                        [x] if type(x) is str else x for x in special_tokens.values()
                    )
                )
                can_be_removed1 = [x for x in bad_tokens if x not in special_tokens]
                can_be_removed2 = [x for x in can_be_removed1 if x in tokenizer._added_tokens_encoder.keys()]

                # Check of extra tokens can in fact we removed!
                can_be_removed = \
                    (len(can_be_removed1) == len(bad_tokens)) and \
                    (len(can_be_removed2) == len(bad_tokens))

                # Check if sep_token or other generic types
                remove_generic = False
                try_mapper = []
                if not can_be_removed:
                    names = dir(tokenizer)
                    names = (x for x in names if x.endswith("_token") and x.count("_") == 1)
                    generic_tokens = [(x, getattr(tokenizer, x, None)) for x in names]

                    try_removal = []
                    for token in bad_tokens:
                        for (name_token, check_token) in generic_tokens:
                            if check_token == token:
                                try_removal.append(token)
                                try_mapper.append(name_token)
                            pass
                        pass
                    pass

                    # Recheck!
                    can_be_removed = (len(try_removal) == len(bad_tokens))
                    if can_be_removed: remove_generic = True
                    can_be_removed1 = bad_tokens
                pass

                if can_be_removed:
                    # Yes it can be fixed!
                    for j, bad_token in enumerate(can_be_removed1):
                        remove_id = tokenizer._added_tokens_encoder[bad_token]
                        del tokenizer._added_tokens_decoder[remove_id]
                        del tokenizer._added_tokens_encoder[bad_token]

                        if remove_generic and (try_removal[j] == bad_token):
                            # Remove sep token for example
                            setattr(tokenizer, try_mapper[j], None)
                            setattr(tokenizer, try_mapper[j] + "_id", None)
                        pass
                    pass
                    # Confirm 1 more time!
                    if max(tokenizer.added_tokens_decoder.keys()) < max_embedding_size:
                        logger.warning_once(
                            f"Unsloth loaded a broken tokenizer `{model_name}`, but managed to repair it!\n"\
                            f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
                            "We removed these bad tokens. If you think this is incorrect, fix your tokenizer first."
                        )
                        return convert_to_fast_tokenizer(tokenizer)
                    pass
                pass

                # :( Failure
                raise RuntimeError(
                    f"Unsloth tried to load `{model_name}`, but cannot succeed.\n"\
                    f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
                    f"Fix your tokenizer since it'll perform out of bounds memory accesses."
                )
            pass

            # Try slow tokenizer which can fix things!
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                model_max_length = model_max_length,
                padding_side = padding_side,
                token = token,
                use_fast = False,
            )
            return check_tokenizer(
                model = model,
                tokenizer = tokenizer,
                model_name = model_name,
                model_max_length = model_max_length,
                padding_side = padding_side,
                token = token,
                _reload = False,
            )
            break
        pass
    pass
    return convert_to_fast_tokenizer(tokenizer)
pass
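For reference, a minimal sketch of the check that drives the healing above, assuming a model repo that ships both slow and fast tokenizers; the repo name is only an example:

from transformers import AutoTokenizer
from unsloth.tokenizer_utils import assert_same_tokenization, convert_to_fast_tokenizer

slow = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-bnb-4bit", use_fast = False)
fast = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-bnb-4bit")

if not assert_same_tokenization(slow, fast):
    # The fast (Rust) tokenizer disagrees with the slow one: rebuild it from the slow tokenizer.
    fast = convert_to_fast_tokenizer(slow)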