Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-10 14:51:06 +00:00
Enable support for Intel XPU devices, add support for DeepSeek V2/V3 first
parent 333351c7c8
commit 142fb7ce6c
22 changed files with 673 additions and 81 deletions
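The diff below (one of the changed files) keys its new behavior off PyTorch's runtime XPU check. As a minimal, illustrative sketch, not part of the commit and assuming a PyTorch build with Intel XPU support (e.g. via intel-extension-for-pytorch):

    import torch

    # torch.xpu.is_available() is True only when the build has Intel XPU
    # support and an XPU device is present; the commit branches on this
    # check to choose between CUDA and XPU code paths.
    if torch.xpu.is_available():
        print(f"found {torch.xpu.device_count()} XPU device(s)")
    else:
        print("no XPU available, CUDA/CPU paths will be used")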
@@ -27,7 +27,8 @@ from ktransformers.operators import base_operator
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
-from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
+if not torch.xpu.is_available():
+    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 import socket
 
 warm_uped = False
@@ -59,6 +60,8 @@ def get_compute_capability(device:torch.device = None):
             return min_compute_capability_major
         else:
             return torch.cuda.get_device_properties(device)
+    else:
+        return 0
 
 def set_module(model, submodule_key, module):
     tokens = submodule_key.split('.')
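With the added fallback, get_compute_capability no longer assumes a CUDA device exists. A hypothetical caller, not taken from this commit, could use the new return value of 0 to disable CUDA-only paths:

    # Hypothetical usage: 0 now means "no CUDA device present", e.g. an
    # XPU-only host, so CUDA-specific kernels can be skipped up front.
    if get_compute_capability() == 0:
        enable_cuda_only_kernels = False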
@@ -97,7 +100,7 @@ def get_all_used_cuda_device(device_map:dict):
     all_device_list = list(all_device_list)
     return all_device_list
 
-def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = ""):
+def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
     prefix = prefix.replace("orig_module.", "")
     persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
     local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
@@ -118,7 +121,10 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.xpu.is_available():
+                torch.xpu.empty_cache()
             weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
             set_param(module, name, weights)
             del weights
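The same availability check recurs in several hunks below. For illustration only, a hypothetical helper (not part of the commit) expressing the cache-clearing dispatch added above:

    def empty_device_cache() -> None:
        # Hypothetical helper: release cached allocator memory on whichever
        # accelerator backend is present, mirroring the branch added above.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif torch.xpu.is_available():
            torch.xpu.empty_cache()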
@@ -126,12 +132,24 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
             #print(load_config.tensor_file_map.keys())
             raise Exception(f"can't find {translated_key} in GGUF file!")
 
-def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix=''):
+
+def sync_all_device(all_device_list):
+    for device in all_device_list:
+        if "cuda" in device.lower():
+            torch.cuda.synchronize(device)
+        elif "xpu" in device.lower():
+            torch.xpu.synchronize(device)
+        else:
+            raise RuntimeError("The device {} is not available".format(device))
+
+torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"}
+
+def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
     #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
-        load_cur_state_dict(module, gguf_loader, prefix)
+        load_cur_state_dict(module, gguf_loader, prefix, device=device)
         for name, child in module._modules.items():
-            load_weights(child, gguf_loader, prefix+name+".")
+            load_weights(child, gguf_loader, prefix+name+".", device=device)
     else:
         module.load()
 
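For illustration, the new sync_all_device helper walks a list of device strings and synchronizes the matching backend for each one; the list below is hypothetical, whereas in the real code path it comes from get_all_used_cuda_device(device_map):

    devices = ["cuda:0", "xpu:0"]   # hypothetical mixed-device list
    sync_all_device(devices)        # blocks until queued kernels finish on each device
    # A device string that is neither CUDA nor XPU raises RuntimeError.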
@@ -194,8 +212,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         torch._dynamo.config.suppress_errors = True
     batch_size, seq_length = inputs.shape
     device_map = model.gguf_loader.tensor_device_map
-    torch_device = get_device('blk.0.self_attn', device_map)
-    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+    torch_device = get_device('model.layers.0.self_attn', device_map)
+    torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
     inputs = inputs.to(torch_device)
     all_cuda_device = get_all_used_cuda_device(device_map)
 
@@ -208,7 +226,12 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             logits = cuda_graph_runner(cur_token, position_ids, cache_position)
         else:
             # custom_stream = torch.cuda.Stream()
-            torch.cuda.set_device(torch_device)
+            if torch.cuda.is_available():
+                torch.cuda.set_device(torch_device)
+            elif torch.xpu.is_available():
+                torch.xpu.set_device(torch_device)
+            else:
+                RuntimeError("The device: {torch_device} is not available")
             inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
             # with torch.cuda.stream(custom_stream):
             logits=model(inputs_embeds=inputs_embeds,
@@ -216,10 +239,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                         cache_position=cache_position,
                         past_key_values=past_key_values,
                         return_dict=False, use_cache=True)[0]
-        if past_key_values != None:
+        if past_key_values != None and isinstance(past_key_values, StaticCache):
             past_key_values.change_seq_length(1)
-        for device in all_cuda_device:
-            torch.cuda.synchronize(device)
+        sync_all_device(all_cuda_device)
         #print(logits)
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
@@ -245,11 +267,19 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
 
         return logits
 
-    torch.cuda.set_device(torch_device)
+    if torch.cuda.is_available():
+        torch.cuda.set_device(torch_device)
+    elif torch.xpu.is_available():
+        torch.xpu.set_device(torch_device)
+    else:
+        RuntimeError("The device: {torch_device} is not available")
     with torch.no_grad():
 
         stream = TextStreamer(tokenizer)
-        if mode != 'long_context':
+        if torch.xpu.is_available():
+            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache
+            past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
+        elif mode != 'long_context':
             past_key_values = StaticCache(
                 config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
             )