mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-15 09:39:42 +00:00

add balance-serve, support concurrence

This commit is contained in:
parent 8d0292aa44
commit 25cee5810e

196 changed files with 22077 additions and 565 deletions
@@ -0,0 +1,13 @@
from .orchestrator import BatchedPenalizerOrchestrator
from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
from .penalizers.presence_penalty import BatchedPresencePenalizer
from .penalizers.repetition_penalty import BatchedRepetitionPenalizer

__all__ = [
    "BatchedFrequencyPenalizer",
    "BatchedMinNewTokensPenalizer",
    "BatchedPresencePenalizer",
    "BatchedRepetitionPenalizer",
    "BatchedPenalizerOrchestrator",
]
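The imports imply a package with an orchestrator module plus a penalizers/ subpackage. A minimal sketch of how these exports are wired together, assuming the package is importable; the Fake* classes are hypothetical stand-ins (real scheduler requests carry full sampling_params and a tokenizer):

import dataclasses

import torch

# assumes: from <this penaltylib package> import (
#     BatchedPenalizerOrchestrator, BatchedFrequencyPenalizer)

@dataclasses.dataclass
class FakeSamplingParams:
    frequency_penalty: float = 0.5

@dataclasses.dataclass
class FakeReq:
    origin_input_ids: list
    sampling_params: FakeSamplingParams

@dataclasses.dataclass
class FakeBatch:
    reqs: list

    def batch_size(self):
        return len(self.reqs)

orchestrator = BatchedPenalizerOrchestrator(
    vocab_size=8,
    batch=FakeBatch(reqs=[FakeReq([1, 2, 2], FakeSamplingParams())]),
    device="cpu",
    Penalizers={BatchedFrequencyPenalizer},
)
logits = torch.zeros(1, 8)
orchestrator.cumulate_output_tokens(output_ids=[[2]])  # token 2 emitted once
logits = orchestrator.apply(logits)                    # logits[0, 2] is now -0.5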
@@ -0,0 +1,376 @@
import abc
import dataclasses
import typing

import torch


@dataclasses.dataclass
class _ReqLike:
    origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]


@dataclasses.dataclass
class _BatchLike:
    reqs: typing.List[_ReqLike]

    def batch_size(self):
        return len(self.reqs)


class BatchedPenalizerOrchestrator:
    batch: _BatchLike
    device: str
    vocab_size: int
    penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]

    def __init__(
        self,
        vocab_size: int,
        batch: _BatchLike,
        device: str,
        Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
    ):
        self.vocab_size = vocab_size
        self.batch = batch
        self.device = device

        self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}

        is_required = False
        for penalizer in self.penalizers.values():
            pen_is_required = penalizer.prepare_if_required()
            is_required |= pen_is_required
        self.is_required = is_required

        if self.is_required:
            self.cumulate_input_tokens(
                input_ids=[req.origin_input_ids for req in self.reqs()]
            )

    def reqs(self):
        return self.batch.reqs

    def batch_size(self):
        return self.batch.batch_size()

    def cumulate_input_tokens(
        self,
        input_ids: typing.Union[
            typing.List[torch.Tensor], typing.List[typing.List[int]]
        ],
    ):
        """
        Feed the input tokens to the penalizers.

        Args:
            input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
        """
        token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)

        for penalizer in self.penalizers.values():
            penalizer.cumulate_input_tokens(input_ids=token_ids)

    def cumulate_output_tokens(
        self,
        output_ids: typing.Union[
            typing.List[torch.Tensor], typing.List[typing.List[int]]
        ],
    ):
        """
        Feed the output tokens to the penalizers.

        Args:
            output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
        """
        if not self.is_required:
            return

        token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)

        for penalizer in self.penalizers.values():
            penalizer.cumulate_output_tokens(output_ids=token_ids)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply the penalizers to the logits.
        Note that it may apply the penalizers in-place.

        Args:
            logits (torch.Tensor): The logits to apply the penalizers to.

        Returns:
            torch.Tensor: The logits after applying the penalizers.
        """
        if not self.is_required:
            # nothing to apply; hand the logits back unchanged
            return logits

        for penalizer in self.penalizers.values():
            logits = penalizer.apply(logits)

        return logits

    def filter(
        self,
        indices_to_keep: typing.List[int],
        indices_tensor_to_keep: torch.Tensor = None,
    ):
        """
        Filter the penalizers based on the indices to keep in the batch.

        Args:
            indices_to_keep (typing.List[int]): List of indices to keep in the batch.
            indices_tensor_to_keep (torch.Tensor, optional): Tensor of indices to keep in the batch. If not None, it is used instead of converting indices_to_keep to a tensor.
        """
        if not self.is_required:
            return

        empty_indices = len(indices_to_keep) == 0

        is_required = False
        for penalizer in self.penalizers.values():
            tmp_is_required = penalizer.is_required()
            is_required = is_required or tmp_is_required
            if not tmp_is_required or empty_indices:
                penalizer.teardown()
            else:
                # create the tensor index only when it is needed
                if indices_tensor_to_keep is None:
                    indices_tensor_to_keep = torch.tensor(
                        indices_to_keep, dtype=torch.int32, device=self.device
                    )

                penalizer.filter(
                    indices_to_keep=indices_to_keep,
                    indices_tensor_to_keep=indices_tensor_to_keep,
                )
        self.is_required = is_required

    def merge(self, their: "BatchedPenalizerOrchestrator"):
        """
        Merge the penalizers of another orchestrator into this one.

        Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
        Each unprepared penalizer has to be prepared (creating tensors, etc.) before merging.
        This step requires the original batch.reqs, before it gets merged with the other batch.reqs.

        Args:
            their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
        """
        if not self.is_required and not their.is_required:
            return

        self.is_required |= their.is_required
        for Penalizer, their_penalizer in their.penalizers.items():
            if Penalizer not in self.penalizers:
                raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")

            self.penalizers[Penalizer].merge(their_penalizer)


class _TokenIDs:
    """
    A class that wraps token IDs to provide additional utility functions to penalizers.

    Attributes:
        orchestrator (BatchedPenalizerOrchestrator): The orchestrator that these token IDs belong to.
        token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
        cached_counts (torch.Tensor): The cached occurrence count tensor.
    """

    orchestrator: BatchedPenalizerOrchestrator
    token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
    cached_counts: torch.Tensor = None

    def __init__(
        self,
        orchestrator: BatchedPenalizerOrchestrator,
        token_ids: typing.Union[
            typing.List[torch.Tensor], typing.List[typing.List[int]]
        ],
    ):
        self.orchestrator = orchestrator

        if not isinstance(token_ids[0], torch.Tensor):
            token_ids = [
                torch.tensor(
                    data=ids, dtype=torch.int64, device=self.orchestrator.device
                )
                for ids in token_ids
            ]

        self.token_ids = token_ids

    def occurrence_count(self) -> torch.Tensor:
        """
        Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.

        Returns:
            torch.Tensor: The occurrence count tensor.
        """
        if self.cached_counts is not None:
            return self.cached_counts

        token_ids = self.token_ids

        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.unsqueeze(1)

            # needs to be long to be used as an index in scatter_add
            if token_ids.dtype != torch.int64:
                token_ids = token_ids.to(torch.int64)

        padded_token_ids = torch.nn.utils.rnn.pad_sequence(
            sequences=token_ids,
            batch_first=True,
            padding_value=self.orchestrator.vocab_size,
        )

        self.cached_counts = torch.zeros(
            size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
            dtype=torch.int64,
            device=self.orchestrator.device,
        ).scatter_add_(
            dim=1,
            index=padded_token_ids,
            src=torch.ones_like(padded_token_ids),
        )[
            :, : self.orchestrator.vocab_size
        ]

        return self.cached_counts


class _BatchedPenalizer(abc.ABC):
    """
    An abstract class for a batched penalizer.
    """

    orchestrator: BatchedPenalizerOrchestrator
    _is_prepared: bool = False

    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
        self.orchestrator = orchestrator

    def is_prepared(self) -> bool:
        return self._is_prepared

    def is_required(self) -> bool:
        return self._is_required()

    def prepare(self):
        if not self.is_prepared():
            self._prepare()
            self._is_prepared = True

    def prepare_if_required(self):
        if self.is_required():
            self.prepare()
            return True
        else:
            return False

    def teardown(self):
        if self.is_prepared():
            self._teardown()
            self._is_prepared = False

    def cumulate_input_tokens(self, input_ids: _TokenIDs):
        if not self.is_prepared():
            return

        self._cumulate_input_tokens(input_ids=input_ids)

    def cumulate_output_tokens(self, output_ids: _TokenIDs):
        if not self.is_prepared():
            return

        self._cumulate_output_tokens(output_ids=output_ids)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.is_prepared():
            return logits

        return self._apply(logits=logits)

    def filter(
        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
    ):
        if not self.is_prepared():
            return

        self._filter(
            indices_to_keep=indices_to_keep,
            indices_tensor_to_keep=indices_tensor_to_keep,
        )

    def merge(self, their: "_BatchedPenalizer"):
        if not self.is_prepared() and not their.is_prepared():
            return

        self.prepare()
        their.prepare()
        self._merge(their)

    @abc.abstractmethod
    def _is_required(self) -> bool:
        """
        Check whether the penalizer needs to be prepared.
        """
        pass

    @abc.abstractmethod
    def _prepare(self):
        """
        Prepare the penalizer.
        Usually, this is where the penalizer initializes its tensors.
        """
        pass

    @abc.abstractmethod
    def _teardown(self):
        """
        Tear down the penalizer.
        Usually, this is where the penalizer frees its tensors.
        """
        pass

    @abc.abstractmethod
    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        """
        Cumulate the input tokens.
        The orchestrator calls this function to feed the input tokens to the penalizer.
        """
        pass

    @abc.abstractmethod
    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        """
        Cumulate the output tokens.
        The orchestrator calls this function to feed the output tokens to the penalizer.
        """
        pass

    @abc.abstractmethod
    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply the penalizer to the logits.
        Penalizers can modify the logits in-place if needed.
        """
        pass

    @abc.abstractmethod
    def _filter(
        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
    ):
        """
        Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
        """
        pass

    @abc.abstractmethod
    def _merge(self, their: "_BatchedPenalizer"):
        """
        Merge the penalizer with another penalizer.
        """
        pass
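The pad-then-scatter_add counting trick in _TokenIDs.occurrence_count above can be checked in isolation. A plain-torch sketch with an assumed vocab_size of 6: padding with vocab_size routes every pad position into a throwaway extra column, which the final slice drops.

import torch

vocab_size = 6
seqs = [torch.tensor([1, 1, 4]), torch.tensor([2])]

padded = torch.nn.utils.rnn.pad_sequence(
    seqs, batch_first=True, padding_value=vocab_size
)  # shape (2, 3); the short row is padded with 6

counts = torch.zeros(2, vocab_size + 1, dtype=torch.int64).scatter_add_(
    dim=1, index=padded, src=torch.ones_like(padded)
)[:, :vocab_size]  # drop the pad column

print(counts)
# tensor([[0, 2, 0, 0, 1, 0],
#         [0, 0, 1, 0, 0, 0]])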
@@ -0,0 +1,80 @@
import typing

import torch

from ..orchestrator import _BatchedPenalizer, _TokenIDs


class BatchedFrequencyPenalizer(_BatchedPenalizer):
    """
    Frequency penalizer penalizes tokens based on their frequency in the output.
    """

    frequency_penalties: torch.Tensor = None
    cumulated_frequency_penalties: torch.Tensor = None

    def _is_required(self) -> bool:
        return any(
            req.sampling_params.frequency_penalty != 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        self.cumulated_frequency_penalties = (
            torch.tensor(
                data=[0.0 for _ in self.orchestrator.reqs()],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .repeat(1, self.orchestrator.vocab_size)
        )

        self.frequency_penalties = (
            torch.tensor(
                data=[
                    req.sampling_params.frequency_penalty
                    for req in self.orchestrator.reqs()
                ],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .expand_as(self.cumulated_frequency_penalties)
        )

    def _teardown(self):
        del self.frequency_penalties
        del self.cumulated_frequency_penalties

        self.frequency_penalties = None
        self.cumulated_frequency_penalties = None

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        pass

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        self.cumulated_frequency_penalties += (
            self.frequency_penalties * output_ids.occurrence_count()
        )

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        logits -= self.cumulated_frequency_penalties
        return logits

    def _filter(
        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
    ):
        self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
        self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
            indices_tensor_to_keep
        ]

    def _merge(self, their: "BatchedFrequencyPenalizer"):
        self.frequency_penalties = torch.cat(
            [self.frequency_penalties, their.frequency_penalties], dim=0
        )
        self.cumulated_frequency_penalties = torch.cat(
            [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
            dim=0,
        )
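The net effect of the frequency penalizer is logits -= frequency_penalty * occurrence_count, accumulated across decoding steps. A small arithmetic check with assumed values:

import torch

penalty = torch.tensor([[0.3]])         # per-request frequency_penalty
counts = torch.tensor([[0, 2, 1]])      # token occurrence counts so far
cumulated = penalty * counts            # tensor([[0.0, 0.6, 0.3]])
logits = torch.zeros(1, 3) - cumulated  # each repeat costs another 0.3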
@@ -0,0 +1,108 @@
import typing

import torch

from ..orchestrator import _BatchedPenalizer, _TokenIDs


class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
    """
    Min new tokens penalizer penalizes tokens based on the length of the output.
    """

    min_new_tokens: torch.Tensor = None
    stop_token_penalties: torch.Tensor = None
    len_output_tokens: torch.Tensor = None

    def _is_required(self) -> bool:
        return any(
            req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        self.min_new_tokens = torch.tensor(
            data=[
                req.sampling_params.min_new_tokens for req in self.orchestrator.reqs()
            ],
            dtype=torch.int32,
            device=self.orchestrator.device,
        ).unsqueeze_(1)

        padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
            sequences=[
                torch.tensor(
                    data=(
                        list(
                            (req.sampling_params.stop_token_ids or set())
                            | (req.tokenizer.additional_stop_token_ids or set())
                            | {req.tokenizer.eos_token_id}
                        )
                    ),
                    dtype=torch.int64,
                    device=self.orchestrator.device,
                )
                for req in self.orchestrator.reqs()
            ],
            batch_first=True,
            padding_value=self.orchestrator.vocab_size,
        )
        self.stop_token_penalties = torch.zeros(
            size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
            dtype=torch.float32,
            device=self.orchestrator.device,
        ).scatter_add_(
            dim=1,
            index=padded_stop_token_ids,
            src=torch.full_like(
                input=padded_stop_token_ids,
                dtype=torch.float32,
                fill_value=float("-inf"),
                device=self.orchestrator.device,
            ),
        )[
            :, : self.orchestrator.vocab_size
        ]

        self.len_output_tokens = torch.zeros(
            size=(self.orchestrator.batch_size(), 1),
            dtype=torch.int32,
            device=self.orchestrator.device,
        )

    def _teardown(self):
        del self.min_new_tokens
        del self.stop_token_penalties
        del self.len_output_tokens

        self.min_new_tokens = None
        self.stop_token_penalties = None
        self.len_output_tokens = None

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        pass

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        self.len_output_tokens += 1

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
        logits[mask] += self.stop_token_penalties[mask]
        return logits

    def _filter(
        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
    ):
        self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
        self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
        self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]

    def _merge(self, their: "BatchedMinNewTokensPenalizer"):
        self.min_new_tokens = torch.cat(
            [self.min_new_tokens, their.min_new_tokens], dim=0
        )
        self.stop_token_penalties = torch.cat(
            [self.stop_token_penalties, their.stop_token_penalties], dim=0
        )
        self.len_output_tokens = torch.cat(
            [self.len_output_tokens, their.len_output_tokens], dim=0
        )
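The min-new-tokens penalizer precomputes a row of -inf entries for every stop token, then adds that row to the logits only while the output is still shorter than min_new_tokens. A sketch of one masked step with assumed values:

import torch

NEG_INF = float("-inf")
vocab_size = 4
stop_token_penalties = torch.tensor([[0.0, NEG_INF, 0.0, 0.0]])  # token 1 = EOS
min_new_tokens = torch.tensor([[3]])
len_output_tokens = torch.tensor([[1]])  # only 1 token generated so far

logits = torch.zeros(1, vocab_size)
mask = (len_output_tokens < min_new_tokens).expand_as(logits)
logits[mask] += stop_token_penalties[mask]
print(logits)  # tensor([[0., -inf, 0., 0.]]): EOS unreachable until step 3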
@@ -0,0 +1,79 @@
import typing

import torch

from ..orchestrator import _BatchedPenalizer, _TokenIDs


class BatchedPresencePenalizer(_BatchedPenalizer):
    """
    Presence penalizer penalizes tokens based on their presence in the output.
    """

    presence_penalties: torch.Tensor = None
    cumulated_presence_penalties: torch.Tensor = None

    def _is_required(self) -> bool:
        return any(
            req.sampling_params.presence_penalty != 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        self.cumulated_presence_penalties = (
            torch.tensor(
                data=[0.0 for _ in self.orchestrator.reqs()],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .repeat(1, self.orchestrator.vocab_size)
        )

        self.presence_penalties = (
            torch.tensor(
                data=[
                    req.sampling_params.presence_penalty
                    for req in self.orchestrator.reqs()
                ],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .expand_as(self.cumulated_presence_penalties)
        )

    def _teardown(self):
        del self.presence_penalties
        del self.cumulated_presence_penalties

        self.presence_penalties = None
        self.cumulated_presence_penalties = None

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        pass

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        mask = output_ids.occurrence_count() > 0
        self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        logits -= self.cumulated_presence_penalties
        return logits

    def _filter(
        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
    ):
        self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
        self.cumulated_presence_penalties = self.cumulated_presence_penalties[
            indices_tensor_to_keep
        ]

    def _merge(self, their: "BatchedPresencePenalizer"):
        self.presence_penalties = torch.cat(
            [self.presence_penalties, their.presence_penalties], dim=0
        )
        self.cumulated_presence_penalties = torch.cat(
            [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
            dim=0,
        )
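Unlike the frequency penalizer, presence is a one-time charge: the masked assignment writes presence_penalty once, however often the token recurs. A sketch with assumed values:

import torch

presence = torch.tensor([[0.4, 0.4, 0.4]])
cumulated = torch.zeros(1, 3)
counts = torch.tensor([[0, 5, 1]])  # token 1 seen five times, token 2 once
mask = counts > 0
cumulated[mask] = presence[mask]    # tensor([[0.0, 0.4, 0.4]]): flat, not 5x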
@@ -0,0 +1,83 @@
import typing

import torch

from ..orchestrator import _BatchedPenalizer, _TokenIDs


class BatchedRepetitionPenalizer(_BatchedPenalizer):
    """
    Repetition penalizer penalizes tokens based on their repetition in the input and output.
    """

    repetition_penalties: torch.Tensor = None
    cumulated_repetition_penalties: torch.Tensor = None

    def _is_required(self) -> bool:
        return any(
            req.sampling_params.repetition_penalty != 1.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        self.cumulated_repetition_penalties = (
            torch.tensor(
                data=[1.0 for _ in self.orchestrator.reqs()],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .repeat(1, self.orchestrator.vocab_size)
        )

        self.repetition_penalties = (
            torch.tensor(
                data=[
                    req.sampling_params.repetition_penalty
                    for req in self.orchestrator.reqs()
                ],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .expand_as(self.cumulated_repetition_penalties)
        )

    def _teardown(self):
        del self.repetition_penalties
        del self.cumulated_repetition_penalties

        self.repetition_penalties = None
        self.cumulated_repetition_penalties = None

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        mask = input_ids.occurrence_count() > 0
        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        mask = output_ids.occurrence_count() > 0
        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.where(
            logits > 0,
            logits / self.cumulated_repetition_penalties,
            logits * self.cumulated_repetition_penalties,
        )

    def _filter(
        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
    ):
        self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
        self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
            indices_tensor_to_keep
        ]

    def _merge(self, their: "BatchedRepetitionPenalizer"):
        self.repetition_penalties = torch.cat(
            [self.repetition_penalties, their.repetition_penalties], dim=0
        )
        self.cumulated_repetition_penalties = torch.cat(
            [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
            dim=0,
        )
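The torch.where in _apply above is the CTRL-style repetition rule: positive logits are divided by the penalty and negative ones multiplied, so both move away from reselection. A numeric check with an assumed penalty of 1.5:

import torch

penalties = torch.tensor([[1.5, 1.5]])
logits = torch.tensor([[2.0, -2.0]])  # one seen-token logit of each sign
out = torch.where(logits > 0, logits / penalties, logits * penalties)
print(out)  # tensor([[ 1.3333, -3.0000]]): both pushed away from selection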
ktransformers/server/balance_serve/inference/sampling/sampler.py (new file, 100 lines)
@@ -0,0 +1,100 @@
'''
Date: 2024-11-14 12:23:45
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:23
'''
import logging
import torch
from torch import nn
from transformers import GenerationConfig

from flashinfer.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_probs,
    top_k_top_p_sampling_from_logits,
    top_p_renorm_probs,
)

logger = logging.getLogger(__name__)

class SamplingOptions:
    # Batched sampling params
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor

    # All requests use greedy sampling
    is_all_greedy: bool

    # Dispatch in CUDA graph
    need_min_p_sampling: bool

    def __init__(self, bsz=1, device=torch.device('cuda'), pretrained_config: GenerationConfig = None, temperatures: torch.Tensor = None, top_ps: torch.Tensor = None):
        if pretrained_config is None and temperatures is None:
            # No config given: default to greedy decoding.
            self.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32)
            self.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32)
            self.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32)
            self.need_min_p_sampling = False
            self.is_all_greedy = True
        else:
            if temperatures is not None:
                self.temperatures = temperatures.unsqueeze(-1)
            else:
                self.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32)

            if top_ps is not None:
                self.top_ps = top_ps.unsqueeze(-1)
            else:
                self.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32)
            self.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32)
            # Note: min_ps is never populated here; a caller must set it before
            # enabling need_min_p_sampling.
            self.need_min_p_sampling = False
            self.is_all_greedy = False

class Sampler(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_config: SamplingOptions = None,
    ):
        if sampling_config is None:
            sampling_config = SamplingOptions()

        logits = logits.contiguous()
        origin_logits = logits.clone()
        if sampling_config.is_all_greedy:
            # Use torch.argmax if all requests use greedy sampling
            probs = logits
            batch_next_token_ids = torch.argmax(logits, -1)
        else:
            # Post process logits
            logits.div_(sampling_config.temperatures)
            max_top_k_round, batch_size = 32, logits.shape[0]
            if sampling_config.need_min_p_sampling:
                probs = torch.softmax(logits, dim=-1)
                logits = None
                del logits
                probs = top_k_renorm_probs(probs, sampling_config.top_ks)
                probs = top_p_renorm_probs(probs, sampling_config.top_ps)
                batch_next_token_ids = min_p_sampling_from_probs(
                    probs, sampling_config.min_ps
                )
                temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
                batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
            else:
                # TODO: use a different kernel when top_k or top_p is not needed
                # @TODO get probs
                probs = logits
                batch_next_token_ids = top_k_top_p_sampling_from_logits(
                    logits,
                    sampling_config.top_ks,
                    sampling_config.top_ps,
                    filter_apply_order="joint",
                )
                temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
                batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)

        return batch_next_token_ids.to(torch.int32), probs
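A usage sketch for this sampler, assuming flashinfer is installed, a CUDA device is available, and the module path follows the file path shown above; the GenerationConfig values are illustrative:

import torch
from transformers import GenerationConfig

# assumed import path, derived from the file path above:
# from ktransformers.server.balance_serve.inference.sampling.sampler import (
#     Sampler, SamplingOptions)

config = GenerationConfig(temperature=0.7, top_p=0.9, top_k=40)
opts = SamplingOptions(bsz=2, device=torch.device("cuda"), pretrained_config=config)

sampler = Sampler()
logits = torch.randn(2, 32000, device="cuda")
next_ids, probs = sampler(logits, opts)  # int32 token ids of shape (2,)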