Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 15:29:39 +00:00)
roll back ktransformers backend, add max_tokens, max_completion_tokens param
This commit is contained in:
parent a1162eea01
commit 03a65d6bea
10 changed files with 144 additions and 161 deletions
@@ -129,8 +129,14 @@ class KTransformersInterface(TransformersInterface):
 
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_new_tokens = self.args.max_new_tokens
+        else:
+            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if(input_ids_length >= self.args.cache_lens):
             logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
             self.seq_length = input_ids_length
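Note on the added parameters: they mirror the OpenAI-style request fields, with max_tokens taking precedence over max_completion_tokens, and whichever value survives is clamped to the server-wide args.max_new_tokens. A minimal standalone sketch of that resolution logic (the resolve_max_new_tokens helper name is illustrative, not part of this commit):

from typing import Optional

def resolve_max_new_tokens(
    server_max_new_tokens: int,
    max_tokens: Optional[int] = None,
    max_completion_tokens: Optional[int] = None,
) -> int:
    """Mirror the precedence used in prefill(): max_tokens overrides
    max_completion_tokens, and the result never exceeds the server limit."""
    if max_tokens is not None:
        max_completion_tokens = max_tokens
    if max_completion_tokens is None:
        return server_max_new_tokens
    return min(server_max_new_tokens, max_completion_tokens)

# a request asking for 4096 completion tokens against a server limit of 2048
assert resolve_max_new_tokens(2048, max_tokens=4096) == 2048
assert resolve_max_new_tokens(2048) == 2048
assert resolve_max_new_tokens(2048, max_completion_tokens=512) == 512

Typing the new parameters as Optional[float] matches the commit as written; an int annotation would arguably be more precise, since the value only ever feeds min() against integer token counts.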
@@ -147,7 +153,7 @@ class KTransformersInterface(TransformersInterface):
         if getattr(self, 'generated_ids', None) is None:
             self.generated_ids = torch.zeros(
                 self.args.batch_size,
-                input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                input_ids.shape[-1] + max_new_tokens + 1,
                 dtype=torch.int,
                 device=self.args.device,
             )
@@ -174,7 +180,7 @@ class KTransformersInterface(TransformersInterface):
 
             former_seq_length = self.seq_length
             self.seq_length += input_ids_length
-            expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+            expected_length = min(self.seq_length + max_new_tokens + 1, self.args.cache_lens)
             delta_length = expected_length - self.generated_ids.shape[-1]
             if delta_length > 0:
                 new_generate_ids = torch.zeros(
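These two hunks switch the generated_ids sizing from the global args.max_new_tokens to the per-request max_new_tokens computed in prefill(). A rough sketch of the buffer-growth arithmetic with made-up numbers (none of these values come from the commit):

import torch

# illustrative values only
batch_size, input_len, max_new_tokens, cache_lens = 1, 100, 512, 4096

# initial buffer: prompt + per-request completion budget + 1 slot
generated_ids = torch.zeros(batch_size, input_len + max_new_tokens + 1, dtype=torch.int)

# a follow-up turn appends more prompt tokens to the same sequence
seq_length = input_len
seq_length += 50  # new input_ids_length

# the buffer must cover the new sequence plus the budget, capped by the KV-cache length
expected_length = min(seq_length + max_new_tokens + 1, cache_lens)
delta_length = expected_length - generated_ids.shape[-1]
if delta_length > 0:
    generated_ids = torch.cat(
        [generated_ids, torch.zeros(batch_size, delta_length, dtype=torch.int)], dim=-1
    )
print(generated_ids.shape)  # torch.Size([1, 663])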
@@ -222,16 +228,17 @@ class KTransformersInterface(TransformersInterface):
             MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
+        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1
         yield self.append_new_tokens(next_token)
 
     @property
     def active_cache_position(self):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
 
-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens):
                 yield v
 
             # return this inference raw usage
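With inference() now forwarding the new fields to the parent class (the tools parameter is dropped here as part of the backend rollback named in the commit title), clients can cap completion length per request. A hedged usage sketch, assuming the server exposes its OpenAI-compatible /v1/chat/completions route; the host, port, and model name below are placeholders, not values taken from this commit:

import requests

# placeholder endpoint and model name; adjust to your ktransformers deployment
url = "http://localhost:10002/v1/chat/completions"
payload = {
    "model": "DeepSeek-V3",
    "messages": [{"role": "user", "content": "Summarize KV-cache reuse in two sentences."}],
    "temperature": 0.6,
    "top_p": 0.95,
    # per the diff, max_tokens takes precedence over max_completion_tokens
    # and is further capped by the server-side max_new_tokens setting
    "max_tokens": 256,
}
resp = requests.post(url, json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])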