commit e204a0bb6b
Author: Jesse
Date: 2025-08-05 15:24:17 +08:00 (committed by GitHub)
2 changed files with 25 additions and 11 deletions


@@ -52,8 +52,9 @@ class SamplingOptions():
         self.is_all_greedy = False
 
 class Sampler(nn.Module):
-    def __init__(self):
+    def __init__(self, device=torch.device('cuda')):
         super().__init__()
+        self.device = device
 
     def forward(
         self,
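
The constructor now takes an explicit device instead of hard-coding CUDA downstream. A minimal sketch of what the new signature enables (the real Sampler's forward() is elided; only the constructor shape is taken from the diff):

import torch
import torch.nn as nn

# Minimal stand-in mirroring the new constructor; the real Sampler's
# forward() does the actual sampling and is elided here.
class Sampler(nn.Module):
    def __init__(self, device=torch.device('cuda')):
        super().__init__()
        self.device = device  # remembered for later tensor placement

# The default preserves the old CUDA behavior; other devices are now expressible:
cpu_sampler = Sampler(device=torch.device('cpu'))
print(cpu_sampler.device)  # cpu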
@@ -63,7 +64,11 @@ class Sampler(nn.Module):
         if sampling_config == None:
             sampling_config = SamplingOptions()
-        logits = logits.contiguous()
+        # Ensure all tensors are on the same device
+        device = logits.device
+        logits = logits.contiguous().to(device)
+        sampling_config.temperatures = sampling_config.temperatures.to(device)
+
         origin_logits = logits.clone()
 
         if sampling_config.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
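
The device-alignment change derives the target device from the logits tensor itself, so a temperatures tensor that arrives on the CPU (or on another GPU) is moved over rather than triggering a cross-device error. A self-contained sketch of the pattern, runnable on CPU:

import torch

# Derive the device from logits and pull per-request parameters onto it,
# mirroring the lines added above.
def align_sampling_tensors(logits, temperatures):
    device = logits.device
    logits = logits.contiguous().to(device)   # no-op if already resident
    temperatures = temperatures.to(device)    # copies across devices if needed
    return logits, temperatures

logits = torch.randn(4, 32000)                     # e.g. batch of 4, vocab of 32k
temperatures = torch.tensor([0.0, 0.7, 1.0, 1.3])  # one temperature per request
logits, temperatures = align_sampling_tensors(logits, temperatures)
assert logits.device == temperatures.device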
@@ -71,7 +76,8 @@ class Sampler(nn.Module):
             batch_next_token_ids = torch.argmax(logits, -1)
         else:
             # Post process logits
-            logits.div_(sampling_config.temperatures)
+            safe_temperatures = sampling_config.temperatures.masked_fill(sampling_config.temperatures == 0, 1.0)
+            logits.div_(safe_temperatures)
             max_top_k_round, batch_size = 32, logits.shape[0]
             if sampling_config.need_min_p_sampling:
                 probs = torch.softmax(logits, dim=-1)
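
A temperature of 0 conventionally means greedy decoding, but dividing logits by 0 in place produces inf/nan values that poison the subsequent softmax. The masked_fill swap makes the division a no-op for those rows; their sampled tokens are then overwritten with an argmax over origin_logits further down. A small demonstration (temperatures shaped as a column so it broadcasts over the vocab dimension, as the batched code presumably does):

import torch

temperatures = torch.tensor([[0.0], [0.5], [1.0]])   # per-request, broadcast over vocab
logits = torch.tensor([[2.0, 1.0, 0.0],
                       [2.0, 1.0, 0.0],
                       [2.0, 1.0, 0.0]])

safe_temperatures = temperatures.masked_fill(temperatures == 0, 1.0)
scaled = logits / safe_temperatures
assert torch.isfinite(scaled).all()   # logits / temperatures would yield inf/nan here
print(scaled)  # row 0 untouched (greedy row), row 1 sharpened by 2x, row 2 untouched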
@@ -82,8 +88,10 @@ class Sampler(nn.Module):
                 batch_next_token_ids = min_p_sampling_from_probs(
                     probs, sampling_config.min_ps
                 )
+                torch.cuda.synchronize()
                 temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
-                batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
+                if temperature_0_idx.numel() > 0:
+                    batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
             else:
                 # TODO: use different kernel when don't need top_k or top_p
                 # @TODO get probs
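
Indexing with an empty index tensor is legal in PyTorch (the assignment is simply a no-op), so the new numel() guard is presumably about skipping the extra argmax kernel launch on the common path where no request runs at temperature 0, not about avoiding an error. The guarded override in isolation:

import torch

batch_next_token_ids = torch.tensor([5, 9, 2], dtype=torch.int32)  # sampled ids
origin_logits = torch.tensor([[0.1, 2.0, 0.3],
                              [0.4, 0.5, 0.6],
                              [3.0, 0.2, 0.1]])
temperatures = torch.tensor([0.0, 0.8, 0.0])

temperature_0_idx = torch.where(temperatures == 0)[0]   # tensor([0, 2])
if temperature_0_idx.numel() > 0:   # skip the argmax entirely when empty
    batch_next_token_ids[temperature_0_idx] = torch.argmax(
        origin_logits[temperature_0_idx], -1
    ).to(torch.int32)
print(batch_next_token_ids)   # tensor([1, 9, 0], dtype=torch.int32)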
@@ -94,7 +102,9 @@ class Sampler(nn.Module):
                     sampling_config.top_ps,
                     filter_apply_order="joint",
                 )
+                torch.cuda.synchronize()
                 temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
-                batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
+                if temperature_0_idx.numel() > 0:
+                    batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
 
         return batch_next_token_ids.to(torch.int32), probs
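
Putting the pieces together, here is a self-contained sketch of the zero-temperature path end to end. Plain torch.multinomial stands in for the flashinfer min-p and top-k/top-p kernels, and the added torch.cuda.synchronize() calls, which presumably ensure the sampling kernel has finished before its output is patched, are omitted because this sketch runs on CPU:

import torch

def sample_with_greedy_override(logits, temperatures):
    # Mirror of the diff's control flow: keep the raw logits, divide by a
    # zero-safe temperature, sample, then force argmax for temperature-0 rows.
    origin_logits = logits.clone()
    safe_t = temperatures.masked_fill(temperatures == 0, 1.0)
    probs = torch.softmax(logits / safe_t.unsqueeze(-1), dim=-1)
    next_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

    temperature_0_idx = torch.where(temperatures == 0)[0]
    if temperature_0_idx.numel() > 0:
        next_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1)
    return next_ids.to(torch.int32), probs

logits = torch.randn(4, 100)
temperatures = torch.tensor([0.0, 0.7, 0.0, 1.2])
ids, probs = sample_with_greedy_override(logits, temperatures)
# Rows 0 and 2 are deterministic argmax picks regardless of the random draw.
assert (ids[[0, 2]] == logits[[0, 2]].argmax(-1).to(torch.int32)).all()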