done support deepseekv3

Azure 2025-02-04 15:53:38 +00:00
parent f748cd29f0
commit 907251c743
9 changed files with 1413 additions and 580 deletions


@@ -24,7 +24,7 @@ class KTransformersInterface(TransformersInterface):
         self.args = args
         torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=True)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
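This hunk adds trust_remote_code=True to the tokenizer load, matching the flag already passed to AutoConfig; DeepSeek-V3 checkpoints ship their own tokenizer/model code, and Transformers refuses to execute such repository code unless the flag is set. A minimal sketch of the pattern, using a hypothetical local checkpoint path that is not part of this commit:

from transformers import AutoConfig, AutoTokenizer

model_dir = "/models/DeepSeek-V3"  # hypothetical path, stands in for args.model_dir

# Repositories that define custom tokenizer/model classes can only be loaded
# when trust_remote_code=True is passed to both calls; otherwise from_pretrained
# raises an error asking for the flag.
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
print(type(tokenizer).__name__, config.architectures[0])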
@@ -99,7 +99,7 @@ class KTransformersInterface(TransformersInterface):
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
-                self.current_ids,
+                self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,
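This hunk moves current_ids onto torch_device before the forward pass, presumably because the model weights and the attention mask already live on that device while the ids may have been built on the CPU, which would otherwise raise a device-mismatch RuntimeError. A minimal sketch of the same fix, with hypothetical tensors standing in for the interface state:

import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

current_ids = torch.tensor([[42]])          # token ids start out on the CPU
mask = torch.ones((1, 8)).to(torch_device)  # mask already on the model's device

# .to() is a no-op when the tensor is already on torch_device, so the call is
# also safe on CPU-only setups.
current_ids = current_ids.to(torch_device)
assert current_ids.device == mask.device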