Mirror of https://github.com/facebookresearch/blt.git, synced 2025-09-10 06:14:35 +00:00
commit c2108e7256 (parent 4f86b6e7ab)

    move args out

3 changed files with 351 additions and 414 deletions
@@ -56,7 +56,7 @@ def sample_top_p(probs, p):
 @torch.inference_mode()
-def generate_nocache(
+def generate(
     prompts: list[str] | None,
     *,
     model: ByteLatentTransformer,
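The bare * in the renamed signature makes model, tokenizer, and patcher keyword-only. A minimal, self-contained Python sketch of that pattern; the stub body is illustrative only, not the real BLT implementation:

def generate(prompts, *, model, tokenizer, patcher):
    # Stub body for illustration; the real generate() runs the model.
    return [p + " ..." for p in prompts]

print(generate(["my name is"], model=None, tokenizer=None, patcher=None))
# generate(["my name is"], None, None, None) would raise a TypeError,
# because the bare * forbids passing these arguments positionally.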
@@ -186,9 +186,8 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"):
     # Generate text
     print("Generating text...")
-    prompts = [prompt]
-    outputs = generate_nocache(
-        prompts,
+    outputs = generate(
+        [prompt],
         model=model,
         tokenizer=tokenizer,
         patcher=model.patcher,  # Use the model's patcher
@@ -197,7 +196,7 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"):
     # Decode and print results
     text_outputs = [tokenizer.decode(t) for t in outputs]
-    for p, t in zip(prompts, text_outputs):
+    for p, t in zip([prompt], text_outputs):
         print(f'Prompt: "{p}"')
         print(f'Completion: "{t}"')
         print()
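Taken together, the two call-site hunks inline the single-element prompts list at both of its uses in main(). A runnable sketch of the resulting flow, with hypothetical stand-ins for the model, tokenizer, and patcher so it executes without BLT weights:

# Hypothetical stand-ins; the real objects come from the BLT setup in main().
class StubTokenizer:
    def decode(self, ids):
        return "".join(chr(i) for i in ids)

def generate(prompts, *, model, tokenizer, patcher):
    # Pretend completion: byte ids spelling " john" for each prompt.
    return [[32, 106, 111, 104, 110] for _ in prompts]

prompt = "my name is"
tokenizer = StubTokenizer()
outputs = generate([prompt], model=None, tokenizer=tokenizer, patcher=None)

# Decode and print results, mirroring the hunks above.
text_outputs = [tokenizer.decode(t) for t in outputs]
for p, t in zip([prompt], text_outputs):
    print(f'Prompt: "{p}"')
    print(f'Completion: "{t}"')
    print()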