Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 23:34:35 +00:00)
Commit 907251c743 (parent f748cd29f0): done support deepseekv3
9 changed files with 1413 additions and 580 deletions
@@ -24,7 +24,7 @@ class KTransformersInterface(TransformersInterface):
         self.args = args
         torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=True)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
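
Note on this hunk: DeepSeek-V3 ships its tokenizer and config as custom Python code inside the checkpoint repo, and transformers refuses to execute such code unless trust_remote_code=True is passed explicitly, which is why the tokenizer call gains the flag to match the AutoConfig call below it. A minimal sketch of the resulting loading path follows; the model_dir value is illustrative, and the device kwarg the commit forwards from args is omitted for brevity:

    from transformers import AutoConfig, AutoTokenizer

    model_dir = "/models/DeepSeek-V3"  # hypothetical checkpoint directory

    # Without trust_remote_code=True, both calls raise a ValueError for
    # checkpoints that define custom classes, as DeepSeek-V3 does.
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)

    # Qwen2-MoE checkpoints are steered onto the FlashAttention-2 path;
    # other architectures keep the default attention implementation.
    if config.architectures[0] == "Qwen2MoeForCausalLM":
        config._attn_implementation = "flash_attention_2"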
@@ -99,7 +99,7 @@ class KTransformersInterface(TransformersInterface):
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
-                self.current_ids,
+                self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,
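
Note on this hunk: the attention mask is built on torch_device, but self.current_ids could still be resident on CPU, and mixing devices in a single forward call fails at runtime; moving the ids to the model's device fixes that. A hedged sketch of the failure mode and the fix, with illustrative tensor names:

    import torch

    # Pick whatever device the model lives on; CUDA when available.
    torch_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    current_ids = torch.tensor([[1, 2, 3]])  # fresh tensors default to CPU
    mask = torch.ones((1, current_ids.shape[1]), device=torch_device)

    # Feeding a CPU tensor to CUDA-resident weights/masks raises
    # "Expected all tensors to be on the same device"; moving the ids
    # first, as the commit does, keeps every input co-located.
    current_ids = current_ids.to(torch_device)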