support KExpertsMarlin backend

This commit is contained in:
Azure 2025-02-07 05:57:40 +00:00
parent 0262f954c7
commit c4d9bc6670
5 changed files with 214 additions and 46 deletions

View file

@ -24,8 +24,8 @@ class KTransformersInterface(TransformersInterface):
self.args = args
torch.set_default_dtype(torch.bfloat16)
torch.set_grad_enabled(False)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=True)
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
if config.architectures[0] == "Qwen2MoeForCausalLM":
config._attn_implementation = "flash_attention_2"
@ -71,30 +71,31 @@ class KTransformersInterface(TransformersInterface):
self.streamer = TextStreamer(self.tokenizer)
def decode_one_tokens(self):
if not hasattr(self, "cuda_graph_runner"):
device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
self.cuda_graph_runner = CUDAGraphRunner()
self.cuda_graph_runner.capture(
self.model,
self.current_ids,
self.active_cache_position.unsqueeze(0),
self.active_cache_position,
self.cache,
main_device=torch_device,
return_dict=False,
use_cache=True,
)
device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
if self.args.use_cuda_graph:
if not hasattr(self, "cuda_graph_runner"):
self.cuda_graph_runner = CUDAGraphRunner()
self.cuda_graph_runner.capture(
self.model,
self.current_ids,
self.active_cache_position.unsqueeze(0),
self.active_cache_position,
self.cache,
main_device=torch_device,
return_dict=False,
use_cache=True,
)
if hasattr(self, "cuda_graph_runner"):
logits = self.cuda_graph_runner(
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
)
self.cache.change_seq_length(1)
torch.cuda.synchronize()
logits = logits[0, -1, :]
return self.logits_to_token(logits)
if hasattr(self, "cuda_graph_runner"):
logits = self.cuda_graph_runner(
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
)
self.cache.change_seq_length(1)
torch.cuda.synchronize()
logits = logits[0, -1, :]
return self.logits_to_token(logits)
if self.use_static_cache:
mask = torch.ones((1, self.seq_length)).to(torch_device)