move args out

This commit is contained in:
ita.zaporozhets@huggingface.co 2025-06-03 15:29:01 +00:00
parent 4f86b6e7ab
commit c2108e7256
3 changed files with 351 additions and 414 deletions

View file

@ -56,7 +56,7 @@ def sample_top_p(probs, p):
@torch.inference_mode()
def generate_nocache(
def generate(
prompts: list[str] | None,
*,
model: ByteLatentTransformer,
@ -186,9 +186,8 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"):
# Generate text
print("Generating text...")
prompts = [prompt]
outputs = generate_nocache(
prompts,
outputs = generate(
[prompt],
model=model,
tokenizer=tokenizer,
patcher=model.patcher, # Use the model's patcher
@ -197,7 +196,7 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"):
# Decode and print results
text_outputs = [tokenizer.decode(t) for t in outputs]
for p, t in zip(prompts, text_outputs):
for p, t in zip([prompt], text_outputs):
print(f'Prompt: "{p}"')
print(f'Completion: "{t}"')
print()