Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 13:55:27 +00:00
[feature] release 0.1.3
commit 4d1d561d28 (parent 67f8b370c3)
58 changed files with 11709 additions and 374 deletions
@@ -84,7 +84,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
     else:
         module.load()
 
-def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True):
+def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
+                         mode = 'normal'):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     torch._dynamo.config.suppress_errors = True
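The new `mode` parameter defaults to `'normal'`, so existing callers are unaffected; passing `'long_context'` opts in to the cache-free path introduced in the hunks below. A minimal sketch of the two call styles (the surrounding model/tokenizer setup is assumed, not shown in this diff):

    # Default behaviour, identical to pre-0.1.3 calls:
    generated = prefill_and_generate(model, tokenizer, input_ids, max_new_tokens=512)

    # Hypothetical call site opting in to the new long-context path:
    generated = prefill_and_generate(model, tokenizer, input_ids, max_new_tokens=512,
                                     mode='long_context')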
@@ -110,7 +111,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                            cache_position=cache_position,
                            past_key_values=past_key_values,
                            return_dict=False, use_cache=True)[0]
-        past_key_values.change_seq_length(1)
+        if past_key_values != None:
+            past_key_values.change_seq_length(1)
         for device in all_cuda_device:
             torch.cuda.synchronize(device)
         #print(logits)
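Since `past_key_values` can now be `None` in long-context mode (see the next hunk), the previously unconditional `change_seq_length(1)` call would raise an `AttributeError`; the guard added here skips it. The diff compares with `!= None`; a sketch of the PEP 8-idiomatic equivalent:

    if past_key_values is not None:   # `is not None` preferred over `!= None`
        past_key_values.change_seq_length(1)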
@@ -125,18 +127,26 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     torch.cuda.set_device(torch_device)
     with torch.no_grad():
         stream = TextStreamer(tokenizer)
-        past_key_values = StaticCache(
-            config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
-        )
+        if mode != 'long_context':
+            past_key_values = StaticCache(
+                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
+            )
+        else:
+            past_key_values = None
         cache_position = torch.arange(seq_length, device=torch_device)
         generated_ids = torch.zeros(
             batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
         )
         generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
-        past_key_values.cur_idx=cache_position
+        if past_key_values != None:
+            past_key_values.cur_idx=cache_position
         start_time = time.time()
 
-        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+        if mode == "long_context":
+            inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
+        else:
+            inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
         logits = model(
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
         )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
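In `'long_context'` mode the transformers `StaticCache` is no longer pre-allocated and the input embeddings stay on CPU, presumably so the long-context path can manage KV state and device placement itself (an assumption; that path is not shown in this diff). A condensed sketch of the resulting branch structure, regrouped here for readability rather than copied from the code:

    if mode != 'long_context':
        # Pre-allocate a fixed-size KV cache sized for prompt + generation.
        past_key_values = StaticCache(
            config=model.config, max_batch_size=1,
            max_cache_len=seq_length + max_new_tokens,
            device=device_map, dtype=model.dtype,
        )
        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
    else:
        past_key_values = None                                      # KV state handled elsewhere (assumed)
        inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))  # embeddings remain on CPU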
@@ -184,7 +194,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         tokens.append(next_token.int())
         seq_length += 1
 
-        if next_token[0].item() == tokenizer.eos_token_id:
+        if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
             print(stream.end(), end="", flush=True)
             break
         else:
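The extra stop condition covers ChatML-style models whose turn terminator `<|im_end|>` may not be registered as `tokenizer.eos_token_id`. A sketch generalizing the check to a configurable stop set (`STOP_STRINGS` and `is_stop` are hypothetical names, not part of this diff):

    STOP_STRINGS = {'<|im_end|>'}

    def is_stop(next_token, tokenizer):
        # Stop on the model's EOS id or on any configured stop string.
        return (next_token[0].item() == tokenizer.eos_token_id
                or tokenizer.decode(next_token) in STOP_STRINGS)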