[feature] release 0.1.3

chenxl 2024-08-28 16:11:43 +00:00
parent 67f8b370c3
commit 4d1d561d28
58 changed files with 11709 additions and 374 deletions


@@ -84,7 +84,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
     else:
         module.load()
 
-def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True):
+def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
+                         mode = 'normal'):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     torch._dynamo.config.suppress_errors = True
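For context, a minimal sketch of how the widened signature might be called; the prompt text and the already-loaded model/tokenizer are illustrative assumptions, only the new mode keyword itself comes from this commit.

# Hypothetical call site; model and tokenizer are assumed to be set up
# elsewhere by ktransformers. Only the `mode` keyword is new in this commit.
input_ids = tokenizer("Summarize the following report:", return_tensors="pt").input_ids
prefill_and_generate(
    model, tokenizer, input_ids,
    max_new_tokens=512,
    use_cuda_graph=False,
    mode="long_context",   # new in 0.1.3; 'normal' keeps the previous behaviour
)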
@@ -110,7 +111,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                          cache_position=cache_position,
                          past_key_values=past_key_values,
                          return_dict=False, use_cache=True)[0]
-        past_key_values.change_seq_length(1)
+        if past_key_values != None:
+            past_key_values.change_seq_length(1)
         for device in all_cuda_device:
             torch.cuda.synchronize(device)
         #print(logits)
@@ -125,18 +127,26 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     torch.cuda.set_device(torch_device)
     with torch.no_grad():
         stream = TextStreamer(tokenizer)
-        past_key_values = StaticCache(
-            config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
-        )
+        if mode != 'long_context':
+            past_key_values = StaticCache(
+                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
+            )
+        else:
+            past_key_values = None
         cache_position = torch.arange(seq_length, device=torch_device)
         generated_ids = torch.zeros(
             batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
         )
         generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
-        past_key_values.cur_idx=cache_position
+        if past_key_values != None:
+            past_key_values.cur_idx=cache_position
         start_time = time.time()
-        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+        if mode == "long_context":
+            inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
+        else:
+            inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
         logits = model(
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
         )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
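The effect of this hunk, pulled out of the diff: in 'long_context' mode no StaticCache is preallocated and the prompt embeddings stay on CPU, so every later cache access has to tolerate past_key_values being None. A minimal standalone sketch of that allocation choice follows; the helper name make_cache is illustrative, not part of the commit, while the StaticCache arguments are taken verbatim from the diff.

from transformers import StaticCache

def make_cache(model, seq_length, max_new_tokens, device_map, mode="normal"):
    # Mirrors the branch above: preallocate a cache sized for the full prompt
    # plus the generation budget, except in long-context mode, where no cache
    # object is created and callers must handle None.
    if mode != "long_context":
        return StaticCache(
            config=model.config,
            max_batch_size=1,
            max_cache_len=seq_length + max_new_tokens,
            device=device_map,
            dtype=model.dtype,
        )
    return None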
@@ -184,7 +194,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             tokens.append(next_token.int())
             seq_length += 1
 
-            if next_token[0].item() == tokenizer.eos_token_id:
+            if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
                 print(stream.end(), end="", flush=True)
                 break
             else:
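The widened stop condition also ends generation on the chat-template terminator '<|im_end|>', not just the tokenizer's eos id. As a sketch, the same check can be written against token ids instead of decoding a string every step; the stop_ids set below is an assumption for illustration, not part of the commit.

# Resolve the textual stop token to an id once, instead of decoding per step.
stop_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|im_end|>")}

def should_stop(next_token):
    # next_token is a 1-element tensor holding the sampled token id, as in the loop above.
    return next_token[0].item() in stop_ids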