mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 14:51:06 +00:00
add balance-serve, support concurrence
This commit is contained in:
parent
8d0292aa44
commit
25cee5810e
196 changed files with 22077 additions and 565 deletions
|
@ -359,3 +359,56 @@ class DynamicNTKScalingRotaryEmbedding(
|
|||
self.orig_module.rope_type,
|
||||
self.orig_module.config,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class RotaryEmbeddingV4(BaseInjectedModule):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
# device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
prefill_device: str = "cuda",
|
||||
**kwargs,
|
||||
):
|
||||
BaseInjectedModule.__init__(
|
||||
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
|
||||
)
|
||||
self.generate_device = generate_device
|
||||
self.prefill_device = prefill_device
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 since bfloat16 loses precision on long contexts
|
||||
# See https://github.com/huggingface/transformers/pull/29285
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
def load(self):
|
||||
self._init(
|
||||
dim=self.config.qk_rope_head_dim,
|
||||
max_position_embeddings=self.config.max_position_embeddings,
|
||||
base=self.config.rope_theta,
|
||||
device=self.device,
|
||||
)
|
||||
def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
# self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
# For BC we register cos and sin cached
|
||||
self.max_seq_len_cached = max_position_embeddings
|
Loading…
Add table
Add a link
Reference in a new issue