Update wfgy_core.py

This commit is contained in:
PSBigBig 2025-06-11 17:09:35 +08:00 committed by GitHub
parent dc5fe33251
commit 3ef3f795c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -10,12 +10,11 @@ class WFGYRunner:
self.model_id = model_id
if self.use_remote:
token_path = os.path.expanduser("~/.huggingface/token")
if not os.path.exists(token_path):
raise RuntimeError("Please run `huggingface-cli login` before using remote mode.")
with open(token_path, "r") as f:
hf_token = f.read().strip()
self.client = InferenceClient(model=self.model_id, token=hf_token)
# Use huggingface-cli login cache without crashing
try:
self.client = InferenceClient(model=self.model_id, token=True)
except Exception as e:
raise RuntimeError(f"Hugging Face login not detected or token invalid: {e}")
else:
self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
@@ -35,9 +34,17 @@ class WFGYRunner:
print(prompt)
if self.use_remote:
result = self.client.text_generation(prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature)
result = self.client.text_generation(
prompt=prompt,
max_new_tokens=max_new_tokens,
temperature=temperature
)
else:
result = self.pipe(prompt, max_new_tokens=max_new_tokens, temperature=temperature)[0]["generated_text"]
result = self.pipe(
prompt,
max_new_tokens=max_new_tokens,
temperature=temperature
)[0]["generated_text"]
print("\n=== Output ===")
print(result.strip())