Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 13:55:27 +00:00)
support KExpertsMarlin backend
This commit is contained in:
parent 0262f954c7
commit c4d9bc6670

5 changed files with 214 additions and 46 deletions
@@ -24,8 +24,8 @@ class KTransformersInterface(TransformersInterface):
         self.args = args
         torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=True)
-        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
+        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
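The only behavioral change in this hunk is that trust_remote_code is no longer hard-coded to True but taken from the server arguments, so operators must opt in before custom model code from the checkpoint is executed. A minimal sketch of how such a flag might be threaded through to the Hugging Face loaders, assuming an argparse-style entry point (the argument names and defaults here are illustrative, not ktransformers' actual CLI):

import argparse

from transformers import AutoConfig, AutoTokenizer

# Hypothetical CLI wiring; the real server builds its args object elsewhere.
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", required=True)
parser.add_argument("--device", default="cuda")
parser.add_argument("--trust_remote_code", action="store_true",
                    help="Allow executing custom modeling code shipped with the checkpoint.")
args = parser.parse_args()

# The flag now flows into both loaders instead of being forced to True.
tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)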
@@ -71,30 +71,31 @@ class KTransformersInterface(TransformersInterface):
         self.streamer = TextStreamer(self.tokenizer)

     def decode_one_tokens(self):
-        if not hasattr(self, "cuda_graph_runner"):
-            device_map = self.model.gguf_loader.tensor_device_map
-            torch_device = get_device("blk.0.self_attn", device_map)
-            torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-            self.cuda_graph_runner = CUDAGraphRunner()
-            self.cuda_graph_runner.capture(
-                self.model,
-                self.current_ids,
-                self.active_cache_position.unsqueeze(0),
-                self.active_cache_position,
-                self.cache,
-                main_device=torch_device,
-                return_dict=False,
-                use_cache=True,
-            )
+        device_map = self.model.gguf_loader.tensor_device_map
+        torch_device = get_device("blk.0.self_attn", device_map)
+        torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+        if self.args.use_cuda_graph:
+            if not hasattr(self, "cuda_graph_runner"):
+                self.cuda_graph_runner = CUDAGraphRunner()
+                self.cuda_graph_runner.capture(
+                    self.model,
+                    self.current_ids,
+                    self.active_cache_position.unsqueeze(0),
+                    self.active_cache_position,
+                    self.cache,
+                    main_device=torch_device,
+                    return_dict=False,
+                    use_cache=True,
+                )

         if hasattr(self, "cuda_graph_runner"):
             logits = self.cuda_graph_runner(
                 self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
             )
             self.cache.change_seq_length(1)
             torch.cuda.synchronize()
             logits = logits[0, -1, :]
             return self.logits_to_token(logits)

         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
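In the rewritten decode_one_tokens, the graph is captured only on the first call and only when args.use_cuda_graph is set; later calls just replay it with the new token id and cache position, and when the flag is off the method falls through to the eager static-cache path below. The capture/replay pattern that a runner like CUDAGraphRunner wraps can be sketched with PyTorch's public CUDA graph API; this is an illustration of the technique with a simplified forward signature, not the project's actual runner:

import torch

def make_graph_runner(model, static_ids, static_pos):
    """Capture a single-token decode step once, then replay it cheaply.

    Assumes the forward pass has static shapes and writes into preallocated
    (static) cache tensors; the call signature model(ids, pos) is illustrative.
    """
    # Warm up on a side stream so lazy initialization (handles, autotuning,
    # workspace allocation) does not get baked into the captured graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_ids, static_pos)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_logits = model(static_ids, static_pos)

    def replay(new_ids, new_pos):
        # Replay reuses the captured kernels and memory, so new inputs must be
        # copied into the exact tensors that were captured.
        static_ids.copy_(new_ids)
        static_pos.copy_(new_pos)
        graph.replay()
        return static_logits  # overwritten in place by the replayed kernels

    return replay

In the diff, CUDAGraphRunner.capture(...) plays the role of the warm-up-and-capture step above, and calling the runner corresponds to the replay.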