blt/setup/download_tokenizer.py
2024-12-12 15:32:30 -08:00

61 lines
1.8 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import argparse
import os
from typing import Optional
from requests.exceptions import HTTPError
# Short tokenizer name -> (Hugging Face Hub repo_id, filename inside that repo).
# Names found here are fetched by `main` via `hf_hub_download`; any other name
# is treated as a tiktoken encoding.
TOKENIZER = dict(
    llama2=("meta-llama/Llama-2-7b", "tokenizer.model"),
    llama3=("meta-llama/Meta-Llama-3-8B", "original/tokenizer.model"),
    gemma=("google/gemma-2-9b", "tokenizer.model"),
)
def main(tokenizer_name: str, path_to_save: str, api_key: Optional[str] = None):
    """Fetch the tokenizer called ``tokenizer_name`` into ``path_to_save``.

    Names that are keys of ``TOKENIZER`` are downloaded from the Hugging Face
    Hub with ``hf_hub_download``. Any other name is treated as a tiktoken
    encoding and loaded with ``tiktoken.get_encoding``.

    Args:
        tokenizer_name: A key of ``TOKENIZER`` (e.g. ``"llama3"``) or a
            tiktoken encoding name.
        path_to_save: Local directory for the Hub download. For tiktoken it is
            exported as ``TIKTOKEN_CACHE_DIR``, but only when that environment
            variable is not already set.
        api_key: Optional Hugging Face access token. An empty string is
            treated the same as ``None``.

    Raises:
        HTTPError: For Hub HTTP failures other than 401 Unauthorized (a 401
            prints a hint and returns instead).
    """
    if tokenizer_name in TOKENIZER:
        repo_id, filename = TOKENIZER[tokenizer_name]
        # Imported lazily: only this branch needs huggingface_hub.
        from huggingface_hub import hf_hub_download

        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=path_to_save,
                # NOTE(review): newer huggingface_hub releases deprecate this
                # argument; kept for compatibility with older versions.
                local_dir_use_symlinks=False,
                token=api_key or None,  # "" -> None: anonymous access
            )
        except HTTPError as e:
            # A plain requests HTTPError may carry no response object; check
            # before reading the status so the original error is re-raised
            # rather than masked by an AttributeError.
            status = e.response.status_code if e.response is not None else None
            if status != 401:
                raise
            # The CLI flag that feeds `api_key` is `--api_key`.
            print(
                "You need to pass a valid `--api_key=...` to download private "
                "checkpoints."
            )
    else:
        from tiktoken import get_encoding

        # Honor a cache directory the caller already configured; otherwise
        # point tiktoken's cache at the requested save path.
        if "TIKTOKEN_CACHE_DIR" not in os.environ:
            os.environ["TIKTOKEN_CACHE_DIR"] = path_to_save
        try:
            get_encoding(tokenizer_name)
        except ValueError:
            print(
                f"Tokenizer {tokenizer_name} not found. Please check the name and try again."
            )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Download a Hugging Face Hub tokenizer or cache a tiktoken encoding."
    )
    parser.add_argument(
        "tokenizer_name",
        type=str,
        help=f"One of {list(TOKENIZER)} or a tiktoken encoding name.",
    )
    # Required positional: argparse never applies a `default` here, so none
    # is given (the previous `default=8` was dead and of the wrong type).
    parser.add_argument(
        "tokenizer_dir",
        type=str,
        help="Directory to save the tokenizer into (also used as "
        "TIKTOKEN_CACHE_DIR when that variable is unset).",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        default="",
        help="Hugging Face access token for downloading Hub-hosted tokenizers.",
    )
    args = parser.parse_args()
    main(
        tokenizer_name=args.tokenizer_name,
        path_to_save=args.tokenizer_dir,
        api_key=args.api_key,
    )