mirror of https://github.com/kvcache-ai/ktransformers.git

Merge branch 'main' into main
commit 8db6a4d402
43 changed files with 1504 additions and 156 deletions
@@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 
 warm_uped = False
 
+def get_compute_capability(device:torch.device = None):
+    if torch.cuda.is_available():
+        if device is None:
+            num_gpus = torch.cuda.device_count()
+            min_compute_capability_major = 100
+            for gpu_id in range(num_gpus):
+                gpu_props = torch.cuda.get_device_properties(gpu_id)
+                min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
+            return min_compute_capability_major
+        else:
+            return torch.cuda.get_device_properties(device)
+
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
     sub_tokens = tokens[:-1]
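When no device is passed, the new helper reports the minimum compute-capability major version across all visible GPUs, which lets callers gate features such as FlashInfer MLA on the weakest card in a multi-GPU box; note that the explicit-device branch returns the whole torch.cuda.get_device_properties(device) object rather than just its .major field, so callers that compare against an integer rely on the no-device path. A minimal sketch of that kind of gate, with an illustrative threshold and function name that are not part of the diff:

import torch

# Illustrative gate (not from the diff): enable FlashInfer MLA only if every
# visible GPU is at least compute capability major >= 8.
def flashinfer_mla_supported(min_major: int = 8) -> bool:
    if not torch.cuda.is_available():
        return False
    majors = [torch.cuda.get_device_properties(i).major
              for i in range(torch.cuda.device_count())]
    return min(majors) >= min_major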
@@ -66,13 +78,22 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
     for name, param in local_state.items():
         key = prefix + name
         translated_key = translate_name_to_gguf(key)
-        if translated_key in gguf_loader.tensor_file_map:
+
+        # TODO: Merge all loader.
+        # I know this is ugly but lets do it for now.
+        if gguf_loader.safetensor_loader is not None:
+            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
+            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
+        else:
+            load_dequantized_tensor = gguf_loader.load_gguf_tensor
+            tensor_file_map = gguf_loader.tensor_file_map
+
+        if translated_key in tensor_file_map:
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
             torch.cuda.empty_cache()
-            # device = "cpu" if "embd" in translated_key else "cuda"
-            weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
+            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
         else:
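The new branch decides, per tensor, whether to read through a safetensors-backed loader or fall back to the GGUF loader, then uses a single load_dequantized_tensor callable and tensor_file_map for the rest of the loop. A small sketch of that dispatch, assuming only the attributes the diff itself touches (any loader internals beyond those are placeholders):

from typing import Callable, Dict, Tuple

import torch

def pick_tensor_loader(gguf_loader) -> Tuple[Callable[..., torch.Tensor], Dict[str, str]]:
    # Prefer the attached safetensors loader when present, otherwise fall back
    # to the GGUF path, so the calling code stays loader-agnostic.
    st = getattr(gguf_loader, "safetensor_loader", None)
    if st is not None:
        return st.load_dequantized_tensor, st.tensor_file_map
    return gguf_loader.load_gguf_tensor, gguf_loader.tensor_file_map

Keeping both names local to the loop, as the diff does, avoids threading the loader choice through every call site.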
@@ -154,6 +175,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
         else:
             inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+        if use_flashinfer_mla:
+            MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
+            MLAWrapperSingleton.need_plan_all()
+
         logits = model(
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
         )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
@@ -176,6 +201,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)
         first_token_time = time.time() - start_time
 
+        if use_flashinfer_mla:
+            MLAWrapperSingleton.reset_buffer()
+
         prefill_count = seq_length
         prefill_time = first_token_time
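Taken together, the two prefill-phase hunks give the FlashInfer MLA wrapper a clear lifecycle: size its buffer from the paged KV cache and request a full plan before the prefill forward pass, then reset the buffer once the first token has been sampled. A stand-in sketch of that call order (the wrapper below is a mock; what each call does internally is inferred from its name, not stated in the diff):

class FakeMLAWrapper:
    # Mock of MLAWrapperSingleton's interface, used only to show call ordering.
    def update_buffer(self, max_pages: int) -> None:
        print(f"allocate workspace for up to {max_pages} KV pages")

    def need_plan_all(self) -> None:
        print("mark every layer as needing a fresh attention plan")

    def reset_buffer(self) -> None:
        print("release the prefill-sized workspace")

use_flashinfer_mla = True
wrapper = FakeMLAWrapper()

if use_flashinfer_mla:
    wrapper.update_buffer(max_pages=128)    # before the prefill forward pass
    wrapper.need_plan_all()
# ... prefill forward pass and first-token sampling happen here ...
if use_flashinfer_mla:
    wrapper.reset_buffer()                  # after the first token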
@@ -193,15 +221,15 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
 
         start_time = time.time()
         for i in range(1, max_new_tokens):
-            if use_flashinfer_mla:
-                MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
-                    num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
-                    q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
             global warm_uped
             if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
                 warm_uped = True
                 cuda_graph_runner = CUDAGraphRunner()
                 cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+            if i > 1 and use_flashinfer_mla:
+                MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
+                    num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
+                    q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
             next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
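The decode-loop hunk moves the FlashInfer re-planning so it runs only for i > 1 and only after the CUDA-graph capture branch, and keeps the warm_uped flag that decides when the graph is captured: on the very first request the loop decodes step 1 eagerly as a warm-up and captures at step 2, while later requests capture immediately at step 1. A distilled sketch of that gating, with a toy runner standing in for CUDAGraphRunner (names and prints here are illustrative, not from the repo):

warm_uped = False

def should_capture(step: int, use_cuda_graph: bool = True) -> bool:
    # First request ever: run step 1 eagerly as a warm-up, capture the graph at step 2.
    # Later requests: kernels are already warm, so capture right away at step 1.
    return use_cuda_graph and ((warm_uped and step == 1) or (not warm_uped and step == 2))

def decode_loop(max_new_tokens: int) -> None:
    global warm_uped
    runner = None
    for i in range(1, max_new_tokens):
        if should_capture(i):
            warm_uped = True
            runner = "captured-graph"   # stands in for CUDAGraphRunner().capture(...)
        action = "replay graph" if runner else "eager decode"
        print(f"step {i}: {action}")

decode_loop(4)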