Mirror of https://github.com/kvcache-ai/ktransformers.git
support deepseekv3; runnable but has a precision problem
Commit 476b1d8dc6 (parent de7e892f72)
13 changed files with 2178 additions and 24 deletions
```diff
@@ -46,17 +46,26 @@ class KTransformersInterface(TransformersInterface):
         )
         optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
 
-        device_map = self.model.gguf_loader.tensor_device_map
-        logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}")
+        self.device_map = self.model.gguf_loader.tensor_device_map
+        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(
             config=self.model.config,
             max_batch_size=args.batch_size,
             max_cache_len=args.cache_lens,
-            device=device_map,
+            device=self.device_map,
             dtype=self.model.dtype,
         )
-        logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}")
-        self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
+        try:
+            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except:
+            gen_config = GenerationConfig(
+                max_length=128,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True
+            )
+            self.model.generation_config = gen_config
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
```
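The new try/except guards against model directories that ship without a generation_config.json: when GenerationConfig.from_pretrained fails, the commit falls back to hard-coded sampling defaults. A minimal standalone sketch of the same pattern follows; the bare except: is copied from the diff, and catching OSError (which from_pretrained typically raises for a missing local config) would avoid masking unrelated errors:

```python
from transformers import GenerationConfig

def load_generation_config(model_dir: str) -> GenerationConfig:
    # Prefer the generation config shipped with the checkpoint.
    try:
        return GenerationConfig.from_pretrained(model_dir)
    except OSError:
        # No generation_config.json in model_dir: fall back to the
        # same conservative sampling defaults the commit hard-codes.
        return GenerationConfig(
            max_length=128,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
```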
```diff
@@ -102,3 +111,63 @@ class KTransformersInterface(TransformersInterface):
         logits = logits[0, -1, :]
 
         return self.logits_to_token(logits)
+
+
+
+    @torch.no_grad
+    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+        input_ids_length = input_ids.shape[-1]
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+
+        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
+
+        if is_new:
+            self.cache.reset()
+            self.ever_generated_ids.clear()
+            former_seq_length = 0
+            self.seq_length = input_ids_length
+            self.generated_ids = torch.zeros(
+                self.args.batch_size,
+                self.seq_length + self.args.max_new_tokens + 1,
+                dtype=torch.int,
+                device=self.args.device,
+            )
+        else:
+            logger.debug(f"generate_ids: {self.generated_ids.shape}")
+            former_seq_length = self.seq_length
+            self.seq_length += input_ids_length
+            expected_length = self.seq_length + self.args.max_new_tokens + 1
+            delta_length = expected_length - self.generated_ids.shape[-1]
+            if delta_length > 0:
+                new_generate_ids = torch.zeros(
+                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+                )
+                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
+        cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
+        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
+
+        mask = torch.ones((1, self.seq_length)).to(device)
+        if not (type(self) is TransformersInterface):
+            input_ids = input_ids.to("cpu")
+        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+        if self.use_static_cache:
+            logits = self.model(
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                past_key_values=self.cache,
+                return_dict=False,
+                use_cache=True,
+                attention_mask=mask,
+            )[0]
+        else:
+            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+
+        next_token = self.logits_to_token(logits[0, -1, :])
+        yield self.append_new_tokens(next_token)
+
+    @property
+    def active_cache_position(self):
+        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
+        return torch.tensor([self.seq_length - 1], device=device)
```
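Two pieces of bookkeeping in the new prefill() are worth spelling out: generated_ids is pre-allocated to seq_length + max_new_tokens + 1 slots and only reallocated when a follow-up prompt would overflow that reservation, and cache_position records exactly which slots the new prompt tokens occupy. A self-contained sketch of that logic with the surrounding class stripped away (GenerationBuffer is a hypothetical name, not part of the commit):

```python
import torch

class GenerationBuffer:
    """Grow-on-demand token buffer mirroring prefill()'s bookkeeping."""

    def __init__(self, batch_size: int, max_new_tokens: int, device: str = "cpu"):
        self.batch_size = batch_size
        self.max_new_tokens = max_new_tokens
        self.device = device
        self.seq_length = 0
        self.generated_ids = torch.zeros(batch_size, 0, dtype=torch.int, device=device)

    def append_prompt(self, input_ids: torch.Tensor) -> torch.Tensor:
        former_seq_length = self.seq_length
        self.seq_length += input_ids.shape[-1]
        # Reserve room for the prompt plus every token we might still generate.
        expected_length = self.seq_length + self.max_new_tokens + 1
        delta_length = expected_length - self.generated_ids.shape[-1]
        if delta_length > 0:  # grow only when the reservation is exhausted
            pad = torch.zeros(self.batch_size, delta_length, dtype=torch.int, device=self.device)
            self.generated_ids = torch.cat([self.generated_ids, pad], dim=-1)
        # The slots the new prompt occupies; prefill() passes this to the model
        # as cache_position so the StaticCache writes to the right offsets.
        cache_position = torch.arange(former_seq_length, self.seq_length, device=self.device)
        self.generated_ids[:, cache_position] = input_ids.to(self.device).to(torch.int)
        return cache_position

buf = GenerationBuffer(batch_size=1, max_new_tokens=4)
print(buf.append_prompt(torch.tensor([[5, 6, 7]])))  # tensor([0, 1, 2])
print(buf.generated_ids.shape)                       # torch.Size([1, 8])
```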
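Finally, both prefill() and the new active_cache_position property resolve their working device with the same nested .get() chain over self.device_map. A tiny sketch of that lookup, with an assumed map shape (the real map is produced by optimize_and_load_gguf and keyed by GGUF tensor names):

```python
# Assumed shape of tensor_device_map; the real one comes from the
# optimize rules applied by optimize_and_load_gguf().
device_map = {"blk.0.self_attn": {"generate_device": "cuda:0"}}

# Same lookup as the diff: fall through missing keys to a default device.
device = device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
print(device)  # cuda:0
```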