mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 06:14:58 +00:00)

commit 25cee5810e — add balance-serve, support concurrence
parent 8d0292aa44
196 changed files with 22077 additions and 565 deletions
@@ -18,9 +18,26 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
+from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
+import socket

 warm_uped = False

+def get_free_ports(n: int, continue_prot: list):
+    sockets = []
+    ports = []
+    for _ in range(n):
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.bind(("", 0))
+        port = s.getsockname()[1]
+        if port in continue_prot:
+            s.close()
+            continue
+        ports.append(port)
+        sockets.append(s)
+    for s in sockets:
+        s.close()
+    return ports
+
 def get_compute_capability(device:torch.device = None):
     if torch.cuda.is_available():
         if device is None:
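For context, a minimal usage sketch of the new get_free_ports helper. The import path (the util module edited in this hunk) and the reserved-ports list below are assumptions for illustration; the commit itself only defines the function.

# Sketch only: assumes get_free_ports is importable from the module edited
# in this hunk; adjust the path if the file layout differs.
from ktransformers.util.utils import get_free_ports

# Ports already claimed by other components (illustrative values).
reserved = [8080, 29500]

# Ask for three ports the OS reports as free and that are not in `reserved`.
# Note the helper can return fewer than `n` ports: when a candidate collides
# with the reserved list it skips that draw instead of retrying it.
ports = get_free_ports(3, reserved)
print(ports)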
@@ -110,7 +127,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
         module.load()

 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
-                         mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False,
+                         mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
                          num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
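The rename also changes call sites that passed the old keyword. A hedged sketch of a caller after this commit; `model`, `tokenizer`, and `input_ids` are placeholders, not objects set up in this diff.

# Sketch only: the keyword is now `chunk_size` (previously `chunk_prefill_size`).
generated = prefill_and_generate(
    model,
    tokenizer,
    input_ids,
    max_new_tokens=512,
    use_cuda_graph=True,
    chunk_size=8192,   # was chunk_prefill_size=8192 before this commit
)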
@@ -202,11 +219,11 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud

     chunk_start = 0
     while chunk_start < seq_length:
-        chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
+        chunk_end = min(chunk_start + chunk_size, seq_length)
         if past_key_values != None:
             past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
         logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
-        chunk_start += chunk_prefill_size
+        chunk_start += chunk_size

     next_token_scores = logits_warper(inputs, logits[:, -1, :])
     if generation_config.do_sample:
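To see what the windowing arithmetic in this hunk does in isolation, here is a self-contained sketch of the same loop over a dummy prompt. The chunk_prefill call is replaced by a stub and the lengths are illustrative; only the index arithmetic mirrors the diff.

# Standalone illustration of the chunked-prefill windowing used above.
# `run_chunk` stands in for the real chunk_prefill call.
def run_chunk(start: int, end: int) -> None:
    print(f"prefill tokens [{start}, {end})")

seq_length = 20000   # illustrative prompt length
chunk_size = 16384   # default from the new signature

chunk_start = 0
while chunk_start < seq_length:
    # The last window is clamped to the sequence end, so no token is
    # processed twice and none is skipped.
    chunk_end = min(chunk_start + chunk_size, seq_length)
    run_chunk(chunk_start, chunk_end)
    chunk_start += chunk_size
# Prints two windows: [0, 16384) and [16384, 20000).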