Mirror of https://github.com/onestardao/WFGY.git (synced 2026-04-28 03:29:51 +00:00)
Update wfgy_core.py
parent 9285750c2a
commit fea00a66fd

1 changed file with 5 additions and 5 deletions
wfgy_core.py (+5, -5)
@@ -4,16 +4,16 @@ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import InferenceClient
 
 class WFGYRunner:
-    def __init__(self, model_id="HuggingFaceH4/zephyr-7b-alpha", use_remote=True):
+    def __init__(self, model_id="tiiuae/falcon-7b-instruct", use_remote=True):
         self.use_remote = use_remote
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model_id = model_id
 
         if self.use_remote:
-            try:
-                self.client = InferenceClient(model=self.model_id, token=os.environ.get("HF_TOKEN"))
-            except Exception as e:
-                raise RuntimeError(f"Hugging Face remote mode failed: {e}")
+            token = os.environ.get("HF_TOKEN")
+            if not token:
+                raise RuntimeError("Missing HF_TOKEN environment variable.")
+            self.client = InferenceClient(model=self.model_id, token=token)
         else:
             self.tokenizer = AutoTokenizer.from_pretrained(model_id)
             self.model = AutoModelForCausalLM.from_pretrained(model_id)
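For context, a minimal usage sketch (not part of the commit) of how the new behaviour plays out: in remote mode the constructor now fails fast when HF_TOKEN is missing, instead of passing a None token through to InferenceClient and surfacing the failure later. The module name wfgy_core and the placeholder token value below are assumptions for illustration only.

import os
from wfgy_core import WFGYRunner  # module name assumed from the file in this diff

# Without a token, remote mode now raises at construction time.
os.environ.pop("HF_TOKEN", None)
try:
    WFGYRunner(use_remote=True)
except RuntimeError as err:
    print(err)  # "Missing HF_TOKEN environment variable."

# With a token set, the runner builds an InferenceClient for the new
# default model, "tiiuae/falcon-7b-instruct".
os.environ["HF_TOKEN"] = "hf_xxx"  # placeholder value, not a real token
runner = WFGYRunner(use_remote=True)
print(runner.model_id)  # tiiuae/falcon-7b-instruct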