'''
Date: 2024-11-14 12:23:45
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:23
'''
import logging
import torch
from torch import nn
from transformers import GenerationConfig

from flashinfer.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_probs,
    top_k_top_p_sampling_from_logits,
    top_p_renorm_probs,
)
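# The flashinfer sampling kernels used below (behavior as of recent flashinfer
# releases; see flashinfer's documentation for exact semantics):
#   top_k_renorm_probs(probs, top_k)   - zero out everything outside the top-k
#                                        tokens per row and renormalize.
#   top_p_renorm_probs(probs, top_p)   - keep the smallest set of tokens whose
#                                        cumulative probability reaches top_p,
#                                        then renormalize.
#   min_p_sampling_from_probs(probs, min_p)
#                                      - sample after dropping tokens whose
#                                        probability is below min_p * max prob.
#   top_k_top_p_sampling_from_logits(logits, top_k, top_p, filter_apply_order=...)
#                                      - fused top-k/top-p filtering and sampling
#                                        directly from the logits.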

logger = logging.getLogger(__name__)

class SamplingOptions:
    # Batched sampling params
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor

    # All requests use greedy sampling
    is_all_greedy: bool

    # Dispatch in CUDA graph
    need_min_p_sampling: bool
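    # Note: the annotations above are class-level type hints only; __init__ below
    # assigns the concrete tensors. min_ps is annotated but never assigned, and
    # need_min_p_sampling is always set to False, so the min-p path in
    # Sampler.forward is not reachable with options built by this constructor.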
    def __init__(self, bsz=1, device=torch.device('cuda'), pretrained_config: GenerationConfig = None,
                 temperatures: torch.Tensor = None, top_ps: torch.Tensor = None):
        if pretrained_config is None and temperatures is None:
            # No config provided: default to greedy decoding. temperature is 0 and
            # is_all_greedy is True, so Sampler.forward takes the argmax path and
            # the top_p/top_k tensors below are not actually used.
            self.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32)
            self.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32)
            self.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32)
            self.need_min_p_sampling = False
            self.is_all_greedy = True
        else:
            # Use per-request tensors when given; otherwise broadcast the values
            # from the pretrained GenerationConfig across the batch.
            if temperatures is not None:
                self.temperatures = temperatures.unsqueeze(-1)
            else:
                self.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32)

            if top_ps is not None:
                self.top_ps = top_ps.unsqueeze(-1)
            else:
                self.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32)
            self.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32)
            self.need_min_p_sampling = False
            self.is_all_greedy = False

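# Usage sketch (illustrative only, not part of the module's API surface):
#   greedy_opts = SamplingOptions()  # batch of 1, all-greedy defaults
#   gen_cfg = GenerationConfig(temperature=0.7, top_p=0.9, top_k=50)
#   sampled_opts = SamplingOptions(bsz=4, pretrained_config=gen_cfg)
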
class Sampler(nn.Module):
    def __init__(self, device=torch.device('cuda')):
        super().__init__()
        self.device = device

    def forward(
        self,
        logits: torch.Tensor,
        sampling_config: SamplingOptions = None,
    ):
        if sampling_config is None:
            sampling_config = SamplingOptions()

        # Ensure all tensors are on the same device as the incoming logits
        device = logits.device
        logits = logits.contiguous().to(device)
        sampling_config.temperatures = sampling_config.temperatures.to(device)

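        # Dispatch: if every request is greedy, take the argmax directly. Otherwise,
        # temperature-scale the logits and sample with the flashinfer kernels; any
        # individual request with temperature 0 is overwritten afterwards with the
        # argmax of its original (unscaled) logits.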
        origin_logits = logits.clone()
        if sampling_config.is_all_greedy:
            # Use torch.argmax if all requests use greedy sampling
            probs = logits
            batch_next_token_ids = torch.argmax(logits, -1)
        else:
            # Post process logits: requests with temperature 0 are divided by 1.0
            # here and handled by the greedy fallback below.
            safe_temperatures = sampling_config.temperatures.masked_fill(sampling_config.temperatures == 0, 1.0)
            logits.div_(safe_temperatures)
            max_top_k_round, batch_size = 32, logits.shape[0]  # currently unused
            if sampling_config.need_min_p_sampling:
                probs = torch.softmax(logits, dim=-1)
                logits = None
                del logits
                # Apply top-k and top-p filtering by renormalizing the probabilities,
                # then sample with min-p filtering.
                probs = top_k_renorm_probs(probs, sampling_config.top_ks)
                probs = top_p_renorm_probs(probs, sampling_config.top_ps)
                batch_next_token_ids = min_p_sampling_from_probs(
                    probs, sampling_config.min_ps
                )
                torch.cuda.synchronize()
                # Requests with temperature 0 fall back to greedy argmax on the original logits
                temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
                if temperature_0_idx.numel() > 0:
                    batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
            else:
                # TODO: use a different kernel when top_k or top_p is not needed
                # @TODO get probs
                # Note: here 'probs' is the temperature-scaled logits, not a
                # normalized distribution (see the TODO above).
                probs = logits
                batch_next_token_ids = top_k_top_p_sampling_from_logits(
                    logits,
                    sampling_config.top_ks,
                    sampling_config.top_ps,
                    filter_apply_order="joint",
                )
                torch.cuda.synchronize()
                # Requests with temperature 0 fall back to greedy argmax on the original logits
                temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
                if temperature_0_idx.numel() > 0:
                    batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)

        return batch_next_token_ids.to(torch.int32), probs
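

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes a
    # CUDA device and an installed flashinfer build; the batch size and vocab
    # size are arbitrary illustrative values.
    dummy_logits = torch.randn(2, 32000, device='cuda', dtype=torch.float32)
    sampler = Sampler()
    next_ids, probs = sampler(dummy_logits)  # default SamplingOptions -> greedy argmax
    print(next_ids.shape, next_ids.dtype)    # expected: torch.Size([2]) torch.int32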