mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-10 22:34:37 +00:00
Initial commit
This commit is contained in:
commit
bcc039bb75
86 changed files with 12203 additions and 0 deletions
60
setup/download_tokenizer.py
Normal file
60
setup/download_tokenizer.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
import argparse
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
# Tokenizers that are fetched from the Hugging Face Hub.
# Maps the short name accepted on the command line to a
# (repo_id, filename) pair: the hub repository and the path of the
# tokenizer file inside that repository.
TOKENIZER = {
    "llama2": ("meta-llama/Llama-2-7b", "tokenizer.model"),
    "llama3": ("meta-llama/Meta-Llama-3-8B", "original/tokenizer.model"),
    "gemma": ("google/gemma-2-9b", "tokenizer.model"),
}
|
||||
|
||||
|
||||
def main(tokenizer_name: str, path_to_save: str, api_key: Optional[str] = None):
    """Download the tokenizer called ``tokenizer_name`` into ``path_to_save``.

    Names that are keys of ``TOKENIZER`` are downloaded from the Hugging Face
    Hub with ``hf_hub_download``. Any other name is passed to tiktoken's
    ``get_encoding``.

    Args:
        tokenizer_name: A key of ``TOKENIZER`` or a tiktoken encoding name.
        path_to_save: Directory handed to ``hf_hub_download`` as ``local_dir``,
            or used as ``TIKTOKEN_CACHE_DIR`` when that variable is not
            already set in the environment.
        api_key: Hugging Face access token. ``None`` and the empty string are
            both forwarded as "no token".

    Raises:
        HTTPError: Any hub HTTP failure other than a 401 response. A 401 is
            reported with a printed hint instead of an exception.
    """
    if tokenizer_name in TOKENIZER:
        repo_id, filename = TOKENIZER[tokenizer_name]

        # Imported lazily so the tiktoken path works without huggingface_hub.
        from huggingface_hub import hf_hub_download

        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=path_to_save,
                local_dir_use_symlinks=False,
                # The CLI defaults --api_key to "", which must not be sent as a token.
                token=api_key if api_key else None,
            )
        except HTTPError as e:
            # `response` can be None on an HTTPError; guard so the original
            # error is re-raised rather than masked by an AttributeError.
            response = getattr(e, "response", None)
            if response is not None and response.status_code == 401:
                print(
                    "You need to pass a valid `--api_key=...` to download private checkpoints."
                )
            else:
                raise
    else:
        # Imported lazily so the Hugging Face path works without tiktoken.
        from tiktoken import get_encoding

        # Respect a cache directory the user already configured.
        os.environ.setdefault("TIKTOKEN_CACHE_DIR", path_to_save)
        try:
            get_encoding(tokenizer_name)
        except ValueError:
            print(
                f"Tokenizer {tokenizer_name} not found. Please check the name and try again."
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and forward them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "tokenizer_name",
        type=str,
        help="Key of TOKENIZER, or a tiktoken encoding name.",
    )
    # Required positional: argparse never consults a default here, so the
    # previous integer default (8) was dead and misleading; it is dropped.
    parser.add_argument(
        "tokenizer_dir",
        type=str,
        help="Directory in which to store the tokenizer.",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        default="",
        help="Hugging Face access token (needed for gated repositories).",
    )
    args = parser.parse_args()

    main(
        tokenizer_name=args.tokenizer_name,
        path_to_save=args.tokenizer_dir,
        api_key=args.api_key,
    )
|
Loading…
Add table
Add a link
Reference in a new issue