Mirror of https://github.com/kvcache-ai/ktransformers.git

[ADD] support multi-gpu qlen>1 q5_k

commit f5f79f5c0e (parent f293803156)
63 changed files with 3271 additions and 1285 deletions
@@ -39,6 +39,22 @@ def set_param(module: nn.Module, name: str, weights: torch.Tensor):
         param.unsqueeze_(0)
     setattr(module, name, param)
 
+def get_device(gguf_module_key:str, device_map:dict):
+    if gguf_module_key in device_map:
+        return device_map[gguf_module_key]["generate_device"]
+    else:
+        return "cuda"
+
+def get_all_used_cuda_device(device_map:dict):
+    all_device_list = set()
+    for key in device_map:
+        all_device_list.add(device_map[key]["generate_device"]) if "generate_device" in device_map[key] else None
+        all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None
+    if "cpu" in all_device_list:
+        all_device_list.remove("cpu")
+    all_device_list = list(all_device_list)
+    return all_device_list
+
 def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
     prefix = prefix.replace("orig_module.", "")
     persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
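Reviewer note: the two helpers added above are the core of the multi-GPU routing in this commit. get_device resolves a module key against the optimize-rule device map and falls back to "cuda"; get_all_used_cuda_device collects every generate/prefill device and drops "cpu". A minimal sketch of how they behave, using a hand-written device_map whose keys and devices are purely illustrative:

    device_map = {
        "blk.0.self_attn":  {"generate_device": "cuda:0", "prefill_device": "cuda:0"},
        "blk.30.self_attn": {"generate_device": "cuda:1", "prefill_device": "cuda:1"},
        "blk.0.mlp":        {"generate_device": "cpu"},
    }

    get_device("blk.0.self_attn", device_map)   # -> "cuda:0"
    get_device("not.in.the.map", device_map)    # -> "cuda" (default fallback)
    get_all_used_cuda_device(device_map)        # -> ["cuda:0", "cuda:1"] (order unspecified, "cpu" filtered out)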
@@ -47,18 +63,19 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
     for name, param in local_state.items():
         key = prefix + name
         translated_key = translate_name_to_gguf(key)
         print("default loading weights", key, translated_key)
         if translated_key in gguf_loader.tensor_file_map:
             target_dtype = torch.get_default_dtype()
-            device = "cpu" if "embd" in translated_key else "cuda"
+            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
+            print(f"loading {translated_key} to {device}")
+            # device = "cpu" if "embd" in translated_key else "cuda"
             weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
             set_param(module, name, weights)
             del weights
         else:
             #print(load_config.tensor_file_map.keys())
-            raise Exception(f"can't fand {translated_key} in GGUF file!")
+            raise Exception(f"can't find {translated_key} in GGUF file!")
 
-def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix='', return_when_injected:bool = False, only_load_injected:bool = False):
-    # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}")
+def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
     if not isinstance(module, base_operator.BaseInjectedModule):
         load_cur_state_dict(module, gguf_loader, prefix)
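Reviewer note: weight placement in load_cur_state_dict now comes from gguf_loader.tensor_device_map instead of the old "embd goes to CPU, everything else to cuda" heuristic (kept as a comment). The lookup key is the translated GGUF tensor name with its last dot-component stripped, i.e. the owning module rather than the individual weight. Roughly, with an illustrative tensor name:

    translated_key = "blk.0.attn_q.weight"                      # example GGUF tensor name
    module_key = translated_key[:translated_key.rfind(".")]     # -> "blk.0.attn_q"
    device = get_device(module_key, gguf_loader.tensor_device_map)
    # -> that module's generate_device, or "cuda" if the key is absent from the map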
@@ -66,27 +83,36 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix='', return_whe
             load_weights(child, gguf_loader, prefix+name+".")
     else:
         module.load()
 
-def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
+def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     torch._dynamo.config.suppress_errors = True
     batch_size, seq_length = inputs.shape
-    torch_device = inputs.device
+    device_map = model.config.gguf_loader.tensor_device_map
+    torch_device = get_device('blk.0.self_attn', device_map)
+    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+    inputs = inputs.to(torch_device)
+    all_cuda_device = get_all_used_cuda_device(device_map)
 
     tokens = []
 
-    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values):
-        logits = cuda_graph_runner(cur_token, position_ids, cache_position)
+    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+        if use_cuda_graph:
+            logits = cuda_graph_runner(cur_token, position_ids, cache_position)
+        else:
+            # custom_stream = torch.cuda.Stream()
+            torch.cuda.set_device(torch_device)
+            inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
+            # with torch.cuda.stream(custom_stream):
+            logits=model(inputs_embeds=inputs_embeds,
+                         position_ids=position_ids,
+                         cache_position=cache_position,
+                         past_key_values=past_key_values,
+                         return_dict=False, use_cache=True)[0]
+            past_key_values.change_seq_length(1)
+            """
+            with torch.cuda.stream(custom_stream):
+                logits=model(cur_token,
+                             position_ids=position_ids,
+                             cache_position=cache_position,
+                             past_key_values=past_key_values,
+                             return_dict=False, use_cache=True)[0]
+            #"""
-        torch.cuda.synchronize()
+        for device in all_cuda_device:
+            torch.cuda.synchronize(device)
         #print(logits)
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
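Reviewer note: two changes in this hunk. decode_one_tokens gains an eager fallback for use_cuda_graph=False: the token is embedded on CPU, the model runs on torch_device, and the static cache length is advanced manually via past_key_values.change_seq_length(1). Separately, the post-step synchronization now waits on every CUDA device named in the map rather than only the current one, which matters once layers are split across GPUs. The synchronization change in isolation (device names illustrative):

    # before: only the current device was synchronized
    # torch.cuda.synchronize()

    # after: wait on each GPU that holds part of the model before sampling,
    # e.g. all_cuda_device == ["cuda:0", "cuda:1"] on a two-GPU split
    for device in all_cuda_device:
        torch.cuda.synchronize(device)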
@@ -95,11 +121,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)
         return next_token
 
+    torch.cuda.set_device(torch_device)
     with torch.no_grad():
         stream = TextStreamer(tokenizer)
         past_key_values = StaticCache(
-            config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = torch_device, dtype = model.dtype
+            config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
         )
         cache_position = torch.arange(seq_length, device=torch_device)
         generated_ids = torch.zeros(
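Reviewer note: the StaticCache is now built with device = device_map rather than a single torch_device, presumably so each layer's KV buffers can be allocated on the GPU that runs that layer; this assumes the project's own StaticCache accepts a per-module device mapping (the stock transformers StaticCache of this era expects one device). The call then receives the same map used elsewhere in this file, roughly:

    past_key_values = StaticCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=seq_length + max_new_tokens,
        device=device_map,   # per-module mapping, e.g. {"blk.0.self_attn": {"generate_device": "cuda:0"}, ...}
        dtype=model.dtype,
    )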
@@ -108,23 +135,22 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
         past_key_values.cur_idx=cache_position
         start_time = time.time()
-        #custom_stream = torch.cuda.Stream()
 
-        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to("cuda")
+        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
         logits = model(
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
-        )[0][:,-1,:].unsqueeze(0).clone()
+        )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
         generation_config, model_kwargs = model._prepare_generation_config(
             None, max_length=max_new_tokens,
             do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
         )
         try: # transformers==4.43
             logits_warper = (
-                model._get_logits_warper(generation_config,device=inputs.device) if generation_config.do_sample else None
+                model._get_logits_warper(generation_config,device=inputs.device)
             )
         except:
             logits_warper = (
-                model._get_logits_warper(generation_config) if generation_config.do_sample else None
+                model._get_logits_warper(generation_config)
             )
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
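Reviewer note: the prefill path now embeds on CPU and moves the result to torch_device (previously hard-coded "cuda"), and the prefill logits are likewise moved to torch_device, so prefill starts on whichever GPU hosts the first layers. The logits_warper is also built unconditionally: next_token_scores = logits_warper(inputs, logits[:, -1, :]) runs on every step, so the old "... if generation_config.do_sample else None" form would have failed with do_sample=False (None is not callable). A sketch of that difference, keeping the diff's try/except around the transformers==4.43 signature:

    # before (greedy decoding would hit a None warper):
    # logits_warper = model._get_logits_warper(generation_config, device=inputs.device) \
    #                     if generation_config.do_sample else None

    # after: always build the warper; the greedy branch simply argmaxes the warped scores
    try:                  # transformers==4.43 accepts a device argument
        logits_warper = model._get_logits_warper(generation_config, device=inputs.device)
    except Exception:     # older signatures without the device kwarg
        logits_warper = model._get_logits_warper(generation_config)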
@@ -136,7 +162,6 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
 
         prefill_count = seq_length
         prefill_time = first_token_time
-
         print(stream.put(next_token.item()), end="", flush=True)
         generated_ids[:, seq_length] = next_token
         tokens.append(next_token)
@@ -144,12 +169,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
         cache_position = torch.tensor([seq_length], device=torch_device)
         position_ids = cache_position.unsqueeze(0)
         seq_length += 1
 
-        cuda_graph_runner = CUDAGraphRunner()
-        cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, return_dict=False, use_cache=True)
+        if use_cuda_graph:
+            cuda_graph_runner = CUDAGraphRunner()
+            cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+        else:
+            cuda_graph_runner = None
+
         start_time = time.time()
         for _ in range(1, max_new_tokens):
-            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values)
+            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
             tokens.append(next_token.int())
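Reviewer note: graph capture is now optional, and CUDAGraphRunner.capture takes torch_device so the graph is captured on the GPU chosen from the device map. With use_cuda_graph=False the runner is simply None and every step takes the eager branch inside decode_one_tokens; the returned token is also moved back to torch_device, presumably because the last layers (and thus the sampled token) may live on a different GPU, and the torch.cat with inputs needs both tensors on one device. In short:

    if use_cuda_graph:
        cuda_graph_runner = CUDAGraphRunner()
        # capture once on torch_device, then replay the same graph every decode step
        cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position,
                                  past_key_values, torch_device, return_dict=False, use_cache=True)
    else:
        cuda_graph_runner = None   # decode_one_tokens falls back to an eager forward pass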
@@ -162,6 +191,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
             print(stream.put(next_token.item()), end="", flush=True)
             cache_position += 1
             position_ids = cache_position.unsqueeze(0)
+
 
         total_time = time.time() - start_time
         tokens_generated = len(tokens)
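Reviewer note: since prefill_and_generate now exposes the flag directly, disabling graph capture is a one-argument change at the call site. A hypothetical invocation (model and tokenizer setup omitted; prompt and token budget are illustrative):

    inputs = tokenizer("Hello", return_tensors="pt").input_ids   # shape (1, seq_len)

    # default: capture a CUDA graph for the decode loop
    prefill_and_generate(model, tokenizer, inputs, max_new_tokens=128)

    # multi-GPU debugging, or environments where graph capture is unavailable
    prefill_and_generate(model, tokenizer, inputs, max_new_tokens=128, use_cuda_graph=False)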