mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 12:09:48 +00:00
249 lines
11 KiB
Python
249 lines
11 KiB
Python
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
'''
|
|
Description :
|
|
Author : Boxin Zhang, Azure-Tang
|
|
Version : 0.1.0
|
|
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
|
'''
|
|
import torch
|
|
from torch import nn
|
|
import itertools
|
|
import time
|
|
import enum
|
|
from ktransformers.util.custom_gguf import translate_name_to_gguf
|
|
from ktransformers.util.custom_gguf import GGUFLoader
|
|
from ktransformers.operators import base_operator
|
|
from ktransformers.models.custom_cache import StaticCache
|
|
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
|
from ktransformers.util.textstream import TextStreamer
|
|
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
|
|
|
|
warm_uped = False
|
|
|
|
def set_module(model, submodule_key, module):
|
|
tokens = submodule_key.split('.')
|
|
sub_tokens = tokens[:-1]
|
|
cur_mod = model
|
|
for s in sub_tokens:
|
|
if hasattr(cur_mod, s):
|
|
cur_mod = getattr(cur_mod, s)
|
|
else: # nn.ModuleList or nn.ModuleList
|
|
cur_mod=cur_mod[int(s)]
|
|
if hasattr(cur_mod, tokens[-1]):
|
|
setattr(cur_mod, tokens[-1], module)
|
|
else: # nn.ModuleList or nn.ModuleList
|
|
cur_mod[int(tokens[-1])] = module
|
|
|
|
def set_param(module: nn.Module, name: str, weights: torch.Tensor):
|
|
|
|
param=nn.parameter.Parameter(weights, requires_grad=False)
|
|
if isinstance(module, nn.Linear) and len(weights.shape)==1:
|
|
param.unsqueeze_(0)
|
|
setattr(module, name, param)
|
|
|
|
def get_device(gguf_module_key:str, device_map:dict):
|
|
if gguf_module_key in device_map:
|
|
return device_map[gguf_module_key]["generate_device"]
|
|
else:
|
|
return "cuda"
|
|
|
|
def get_all_used_cuda_device(device_map:dict):
|
|
all_device_list = set()
|
|
for key in device_map:
|
|
all_device_list.add(device_map[key]["generate_device"]) if "generate_device" in device_map[key] else None
|
|
all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None
|
|
if "cpu" in all_device_list:
|
|
all_device_list.remove("cpu")
|
|
all_device_list = list(all_device_list)
|
|
return all_device_list
|
|
|
|
def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
|
|
prefix = prefix.replace("orig_module.", "")
|
|
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
|
|
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
|
|
local_state = {k: v for k, v in local_name_params if v is not None}
|
|
for name, param in local_state.items():
|
|
key = prefix + name
|
|
translated_key = translate_name_to_gguf(key)
|
|
|
|
# TODO: Merge all loader.
|
|
# I know this is ugly but lets do it for now.
|
|
if gguf_loader.safetensor_loader is not None:
|
|
load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
|
|
tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
|
|
else:
|
|
load_dequantized_tensor = gguf_loader.load_gguf_tensor
|
|
tensor_file_map = gguf_loader.tensor_file_map
|
|
|
|
if translated_key in tensor_file_map:
|
|
target_dtype = torch.get_default_dtype()
|
|
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
|
|
print(f"loading {translated_key} to {device}")
|
|
torch.cuda.empty_cache() # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225"
|
|
# weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
|
|
weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
|
|
set_param(module, name, weights)
|
|
del weights
|
|
else:
|
|
#print(load_config.tensor_file_map.keys())
|
|
raise Exception(f"can't find {translated_key} in GGUF file!")
|
|
|
|
def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
|
|
#print(f"recursively loading weights {prefix}")
|
|
if not isinstance(module, base_operator.BaseInjectedModule):
|
|
load_cur_state_dict(module, gguf_loader, prefix)
|
|
for name, child in module._modules.items():
|
|
load_weights(child, gguf_loader, prefix+name+".")
|
|
else:
|
|
module.load()
|
|
|
|
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
|
|
mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
|
|
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
|
|
import os
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
torch._dynamo.config.suppress_errors = True
|
|
batch_size, seq_length = inputs.shape
|
|
device_map = model.gguf_loader.tensor_device_map
|
|
torch_device = get_device('blk.0.self_attn', device_map)
|
|
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
|
|
inputs = inputs.to(torch_device)
|
|
all_cuda_device = get_all_used_cuda_device(device_map)
|
|
|
|
tokens = []
|
|
|
|
def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
|
|
if cuda_graph_runner is None:
|
|
use_cuda_graph = False
|
|
if use_cuda_graph:
|
|
logits = cuda_graph_runner(cur_token, position_ids, cache_position)
|
|
else:
|
|
# custom_stream = torch.cuda.Stream()
|
|
torch.cuda.set_device(torch_device)
|
|
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
|
|
# with torch.cuda.stream(custom_stream):
|
|
logits=model(inputs_embeds=inputs_embeds,
|
|
position_ids=position_ids,
|
|
cache_position=cache_position,
|
|
past_key_values=past_key_values,
|
|
return_dict=False, use_cache=True)[0]
|
|
if past_key_values != None:
|
|
past_key_values.change_seq_length(1)
|
|
for device in all_cuda_device:
|
|
torch.cuda.synchronize(device)
|
|
#print(logits)
|
|
next_token_scores = logits_warper(inputs, logits[:, -1, :])
|
|
if generation_config.do_sample:
|
|
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
|
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
|
else:
|
|
next_token = torch.argmax(next_token_scores, dim=-1)
|
|
return next_token
|
|
|
|
torch.cuda.set_device(torch_device)
|
|
with torch.no_grad():
|
|
stream = TextStreamer(tokenizer)
|
|
if mode != 'long_context':
|
|
past_key_values = StaticCache(
|
|
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
|
|
)
|
|
else:
|
|
past_key_values = None
|
|
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
|
|
generated_ids = torch.zeros(
|
|
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
|
|
)
|
|
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
|
|
if past_key_values != None:
|
|
past_key_values.cur_idx=cache_position
|
|
start_time = time.time()
|
|
|
|
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
|
|
if mode == "long_context":
|
|
inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
|
|
else:
|
|
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
|
|
logits = model(
|
|
inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
|
|
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
|
|
generation_config, model_kwargs = model._prepare_generation_config(
|
|
None, max_length=max_new_tokens,
|
|
do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
|
|
)
|
|
try: # transformers==4.43
|
|
logits_warper = (
|
|
model._get_logits_warper(generation_config,device=inputs.device)
|
|
)
|
|
except:
|
|
logits_warper = (
|
|
model._get_logits_warper(generation_config)
|
|
)
|
|
next_token_scores = logits_warper(inputs, logits[:, -1, :])
|
|
if generation_config.do_sample:
|
|
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
|
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
|
else:
|
|
next_token = torch.argmax(next_token_scores, dim=-1)
|
|
first_token_time = time.time() - start_time
|
|
|
|
prefill_count = seq_length
|
|
prefill_time = first_token_time
|
|
if force_think:
|
|
print("<think>\n")
|
|
print(stream.put(next_token.item()), end="", flush=True)
|
|
generated_ids[:, seq_length] = next_token
|
|
tokens.append(int(next_token))
|
|
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
|
cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
|
|
position_ids = cache_position.unsqueeze(0)
|
|
seq_length += 1
|
|
|
|
cuda_graph_runner = None
|
|
|
|
start_time = time.time()
|
|
for i in range(1, max_new_tokens):
|
|
global warm_uped
|
|
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
|
|
warm_uped = True
|
|
cuda_graph_runner = CUDAGraphRunner()
|
|
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
|
|
if i > 1 and use_flashinfer_mla:
|
|
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
|
|
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
|
|
q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
|
|
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
|
|
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
|
generated_ids[:, cache_position] = next_token.int()
|
|
tokens.append(int(next_token))
|
|
seq_length += 1
|
|
|
|
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token) == '<|im_end|>':
|
|
print(stream.end(), end="", flush=True)
|
|
break
|
|
else:
|
|
print(stream.put(next_token.item()), end="", flush=True)
|
|
cache_position += 1
|
|
position_ids = cache_position.unsqueeze(0)
|
|
|
|
|
|
total_time = time.time() - start_time
|
|
tokens_generated = len(tokens)
|
|
tokens_per_second = tokens_generated / total_time
|
|
|
|
print("")
|
|
|
|
print(f"prompt eval count: {prefill_count} token(s)")
|
|
print(f"prompt eval duration: {prefill_time}s")
|
|
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
|
|
print(f"eval count: {tokens_generated} token(s)")
|
|
print(f"eval duration: {total_time}s")
|
|
print(f"eval rate: {tokens_per_second} tokens/s")
|
|
|
|
return tokens
|
|
|
|
class InferenceState(enum.Enum):
|
|
UNLOAD = 0
|
|
PREFILL = 1
|
|
GENERATE = 2
|
|
RESTORE = 3
|