Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-10 23:34:35 +00:00

Commit 3f9bbf1181 (parent f3d842a0ca): support qwen3, dont speak human language

30 changed files with 3696 additions and 290 deletions
@@ -20,6 +20,7 @@ class ArgumentParser:
         parser.add_argument(
             "--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
         )
+        parser.add_argument("--architectures", type=str, default=self.cfg.model_name)
         parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
         parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
         parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
@@ -93,6 +94,7 @@ class ArgumentParser:
         parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
         parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
         parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)
+        # parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=False)

         # web config
         parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
@@ -137,7 +139,7 @@ class ArgumentParser:
         self.cfg.server_port = args.port
         self.cfg.user_force_think = args.force_think

-        args.gpu_memory_size = args.cache_lens*2*576*61
+        args.gpu_memory_size = 4*1024*1024*1024 # TODO: set this to the actual GPU memory size
         self.cfg.gpu_memory_size = args.gpu_memory_size
         free_ports = get_free_ports(3, [args.port])
         args.sched_port = free_ports[0]
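A note on the two budget lines above: the removed expression was a rough KV-cache budget for DeepSeek-style MLA attention, while the new line simply reserves a flat 4 GiB, as the TODO admits. A minimal sketch of what cache_lens*2*576*61 appears to estimate, assuming 2 bytes per bf16 element, 576 compressed KV entries per token per layer (512 kv_lora_rank + 64 qk_rope_head_dim) and 61 hidden layers; the helper name and the 32k-token example are illustrative, not part of the commit:

def estimate_mla_kv_bytes(cache_lens, bytes_per_elem=2, kv_dim_per_layer=576, num_layers=61):
    # tokens * bytes per element * compressed KV width per layer * layer count
    return cache_lens * bytes_per_elem * kv_dim_per_layer * num_layers

print(estimate_mla_kv_bytes(32768) / 2**30)  # ~2.14 GiB for a 32k-token cache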
@@ -1,5 +1,5 @@
 from typing import Any, AsyncIterator, List, Optional, Set
-from ktransformers.models.custom_cache import KDeepSeekV3Cache
+from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache
 from transformers import (
     AutoTokenizer,
     AutoConfig,
@@ -22,6 +22,9 @@ from ktransformers.server.config.log import logger
 from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
 from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
+from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
+from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
+from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
 from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
 from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
 from ktransformers.server.balance_serve.inference.query_manager import QueryManager
@@ -53,8 +56,10 @@ ktransformer_rules_dir = (
     os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
 )
 default_optimize_rules = {
-    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
-    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
+    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "Moonlight-16B-A3B-serve.yaml",
+    # "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
+    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
+    "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
 }

@@ -105,7 +110,7 @@ class Engine:
     model_runner: ModelRunner
     sampler: Sampler
     query_manager: QueryManager
-    cache: KDeepSeekV3Cache
+    cache: KDeepSeekV3Cache | KGQACache
     def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
         self.args = args
@@ -117,17 +122,32 @@ class Engine:
         self.device = self.args.device
         self.sched_client = SchedulerClient(args.sched_port)
         self.updates = []
-        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-        self.cache = KDeepSeekV3Cache(config, self.args.page_size)

+        try:
+            config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        except:
+            if args.model_name == "Qwen3Moe":
+                config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+            else:
+                assert False, f"model {args.model_name} not supported"
+
         self.gen_queue = generated_token_queue

         with torch.device("meta"):
             if config.architectures[0] == "DeepseekV3ForCausalLM":
+                self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                 self.model = KDeepseekV3ForCausalLM(config, self.cache)
             elif config.architectures[0] == "DeepseekV2ForCausalLM":
+                self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                 self.model = KDeepseekV2ForCausalLM(config, self.cache)
                 # print(self.block_num)
+            elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
+                self.cache = KGQACache(config, self.args.page_size)
+                if config.architectures[0] == "Qwen2MoeForCausalLM":
+                    self.model = KQwen2MoeForCausalLM(config, self.cache)
+                else:
+                    self.model = KQwen3MoeForCausalLM(config, self.cache)

         context = zmq.Context()
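On the config fallback introduced above: checkpoints whose config the pinned transformers cannot resolve fall back to the bundled Qwen3MoeConfig. A stand-alone sketch of the same pattern for readers adapting it (the helper name is hypothetical, and the diff's bare except: is narrowed to Exception here so interrupts still propagate):

from transformers import AutoConfig
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig

def load_config(model_dir: str, model_name: str):
    # Prefer AutoConfig; fall back to the bundled Qwen3-MoE config class.
    try:
        return AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    except Exception:
        if model_name == "Qwen3Moe":
            return Qwen3MoeConfig.from_pretrained(model_dir, trust_remote_code=True)
        raise ValueError(f"model {model_name} not supported")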
@@ -176,9 +196,12 @@ class Engine:

         self.block_num = inference_context.k_cache[0].size(1)
         #@TODO add config
-        self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
+        if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
+            self.model.init_wrapper(self.args.use_cuda_graph, self.device, 1024 ,args.max_batch_size, self.block_num) # TODO: 1024 is a magic number(max_batch_tokens)
+        else:
+            self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

-        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size)
+        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
         self.sampler = Sampler()
         self.query_manager = QueryManager(device = self.device, page_size = args.page_size)
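About the 1024 passed to init_wrapper for the Qwen MoE models: per the inline TODO it is a provisional max_batch_tokens cap; the warmup path in ModelRunner later calls init_wrapper again per captured CUDA graph with that graph's actual num_tokens, which suggests the constant mainly sizes the initial wrapper allocation.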
@@ -231,7 +254,7 @@ class Engine:

             if self.batch is not None:
                 self.model_runner.sync()
-                print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
+                print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms, {1000/self.model_runner.model_time:.3f} tokens/s")
                 # if self.rank == 0:

                 generated_tokens, probs = self.sampling( self.model_runner.output)
@@ -281,4 +281,4 @@ class ForwardBatchOutput:
         self.generated_tokens_num = []
         self.top_ps = []
         self.temperatures = []
-        pass
+        self.num_batchs = 1
@@ -27,6 +27,8 @@ from ktransformers.server.balance_serve.inference.forward_batch import ForwardBa
 from ktransformers.server.config.config import Config
 from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
 from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
+from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
+from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
 from ktransformers.server.balance_serve.inference.query_manager import QueryManager
 from ktransformers.server.balance_serve.settings import sched_ext
@@ -40,11 +42,11 @@ def deduplicate_and_sort(lst):
 class ModelRunner:
     """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""

-    model: KDeepseekV3ForCausalLM
+    model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM
     input: ForwardBatchInput | list[ForwardBatchInput]
     output: ForwardBatchOutput

-    def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256):
+    def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256, block_num = 8):

         self.stream = torch.cuda.Stream(device=device)
         # 先注释掉 ("comment this out for now")
@ -58,120 +60,92 @@ class ModelRunner:
|
|||
self.use_cuda_graph = use_cuda_graph
|
||||
self.model_time = 0
|
||||
self.page_size = page_size
|
||||
self.block_num = block_num
|
||||
# GPU timing for model execution
|
||||
self.start_model_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_model_event = torch.cuda.Event(enable_timing=True)
|
||||
if isinstance(self.cuda_graphs, list):
|
||||
self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]
|
||||
self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
|
||||
self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
|
||||
else:
|
||||
self.graphs = torch.cuda.CUDAGraph()
|
||||
self.page_idx_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
|
||||
self.page_offset_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
|
||||
|
||||
self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]
|
||||
self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
|
||||
self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
|
||||
|
||||
self.num_mini_batches = num_mini_batches
|
||||
|
||||
self.max_chunk_size = max_chunk_size
|
||||
|
||||
self.bsz_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
|
||||
self.num_tokens_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
|
||||
|
||||
def model_attn_plan(self, batch, cuda_graph_idx=0):
|
||||
if isinstance(self.model, KDeepseekV3ForCausalLM):
|
||||
self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
|
||||
self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
|
||||
head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads,
|
||||
page_size=self.model.cache.page_size, causal=True,
|
||||
q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx)
|
||||
else:
|
||||
assert False, "model type not supported"
|
||||
|
||||
|
||||
def warmup(self):
|
||||
|
||||
def capture_graphs(cuda_graph_idx=-1):
|
||||
if cuda_graph_idx != -1:
|
||||
with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):
|
||||
self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)
|
||||
self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()
|
||||
else:
|
||||
with torch.cuda.graph(self.graphs, pool=self.graph_memory_pool, stream=self.stream):
|
||||
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
|
||||
self.graph_memory_pool = self.graphs.pool()
|
||||
def capture_graphs(cuda_graph_idx):
|
||||
with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):
|
||||
self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)
|
||||
self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()
|
||||
|
||||
if isinstance(self.cuda_graphs, list):
|
||||
self.input = []
|
||||
self.features_buf = []
|
||||
self.outputs_buf = []
|
||||
self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
|
||||
self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
|
||||
for i in range(len(self.cuda_graphs)):
|
||||
prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0 #@TODO only supprot 2 prefill batch
|
||||
self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens = self.cuda_graphs[i]))
|
||||
self.input = []
|
||||
self.features_buf = []
|
||||
self.outputs_buf = []
|
||||
self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
|
||||
self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
|
||||
for i in range(len(self.cuda_graphs)):
|
||||
prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0 #@TODO only supprot 2 prefill batch
|
||||
self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens=self.cuda_graphs[i]))
|
||||
|
||||
self.features_buf.append(self.model.batch_embeddings(self.input[i]))
|
||||
batch_size = self.input[i].minibatch.q_indptr.size(0)-1
|
||||
num_tokens = self.features_buf[i][0].size(0)
|
||||
print("capturing cuda graph", batch_size, num_tokens)
|
||||
self.bsz_tensor_buf[0] = batch_size
|
||||
self.num_tokens_tensor_buf[0] = num_tokens
|
||||
self.features_buf.append(self.model.batch_embeddings(self.input[i]))
|
||||
batch_size = self.input[i].minibatch.q_indptr.size(0)-1
|
||||
num_tokens = self.features_buf[i][0].size(0)
|
||||
print("capturing cuda graph", batch_size, num_tokens)
|
||||
|
||||
self.model.flash_infer_attn_plan(self.input[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
|
||||
page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf)
|
||||
if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
|
||||
self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens)
|
||||
|
||||
self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens])
|
||||
self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens])
|
||||
self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1)
|
||||
|
||||
self.outputs_buf.append(None)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
for warm_up_iters in range(11):
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i])
|
||||
torch.cuda.synchronize()
|
||||
self.bsz_tensor_buf[0] = batch_size
|
||||
self.num_tokens_tensor_buf[0] = num_tokens
|
||||
|
||||
capture_graphs(i)
|
||||
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.graphs[i].replay()
|
||||
|
||||
self.sync(calc_time=False)
|
||||
print(f"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.")
|
||||
else:
|
||||
self.input = ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches)
|
||||
|
||||
self.features_buf = self.model.batch_embeddings(self.input)
|
||||
batch_size = self.input.minibatch.q_indptr.size(0)-1
|
||||
num_tokens = self.features_buf[0].size(0)
|
||||
self.model_attn_plan(self.input[i], i)
|
||||
|
||||
page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf)
|
||||
|
||||
|
||||
self.bsz_tensor_buf = torch.tensor([batch_size], dtype=torch.int32, device=self.device)
|
||||
self.num_tokens_tensor_buf = torch.tensor([num_tokens], dtype=torch.int32, device=self.device)
|
||||
self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens])
|
||||
self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens])
|
||||
|
||||
|
||||
self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
|
||||
page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
|
||||
self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
|
||||
self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
|
||||
self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
|
||||
|
||||
|
||||
self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1)
|
||||
|
||||
self.outputs_buf.append(None)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
for warm_up_iters in range(11):
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
|
||||
self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i], cuda_graph_idx=i)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def capture_graphs():
|
||||
with torch.cuda.graph(self.graphs, stream=self.stream):
|
||||
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
|
||||
# self.graph_memory_pool = self.graphs.pool()
|
||||
self.outputs_buf[i].num_batchs = batch_size
|
||||
|
||||
|
||||
capture_graphs()
|
||||
capture_graphs(i)
|
||||
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.graphs.replay()
|
||||
self.graphs[i].replay()
|
||||
|
||||
self.sync(calc_time=False)
|
||||
print("warmup finished.")
|
||||
print(f"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.")
|
||||
|
||||
def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None):
|
||||
with torch.cuda.stream(self.stream):
|
||||
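The warmup above follows the standard PyTorch CUDA-graph recipe: run a few eager iterations so lazy allocations happen outside capture, capture the forward pass into a torch.cuda.CUDAGraph against fixed-address buffers (optionally sharing a memory pool across graphs), then replay. A generic, self-contained sketch of that pattern, with a toy Linear model standing in for the real one (nothing below is ktransformers API):

import torch

def cuda_graph_demo():
    assert torch.cuda.is_available(), "CUDA graphs need a GPU"
    device = "cuda"
    model = torch.nn.Linear(16, 16, device=device)
    static_in = torch.zeros(8, 16, device=device)        # fixed-address input buffer

    side = torch.cuda.Stream(device=device)
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):                        # eager warmup, outside capture
        for _ in range(3):
            static_out = model(static_in)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):                        # record the kernels, do not execute
        static_out = model(static_in)

    static_in.copy_(torch.randn(8, 16, device=device))   # refresh inputs in place
    graph.replay()                                       # rerun the recorded kernels
    torch.cuda.synchronize()
    return static_out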
@@ -189,107 +163,54 @@ class ModelRunner:

            if isinstance(self.cuda_graphs, list):
                # cuda graph idx equal to min idx i in self.cuda_graphs, that self.cuda_graphs[i] > num_tokens
                cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs))
                if cuda_graph_idx == len(self.cuda_graphs):
                    assert False, "num_tokens is too large"
            else:
                cuda_graph_idx = -1
            # cuda graph idx equal to min idx i in self.cuda_graphs, that self.cuda_graphs[i] > num_tokens
            cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs))
            if not self.use_cuda_graph:
                cuda_graph_idx = 0
            # if cuda_graph_idx == len(self.cuda_graphs):
            #     assert False, "num_tokens is too large"

            if self.use_cuda_graph:
                if cuda_graph_idx != -1:
                    self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)
                else:
                    self.input.fill(batch, query_manager, self.page_size)
                self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)
            else:
                self.input = ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)
                self.input = [ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)]

            if cuda_graph_idx != -1 and self.use_cuda_graph:
            if self.use_cuda_graph:
                self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)
            else:
                self.features = self.model.batch_embeddings(self.input, device=self.device)

            self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)

            self.bsz_tensor_buf.copy_(batch_size)
            self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device))

            if self.use_cuda_graph:
                if cuda_graph_idx != -1:
                    self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)
                else:
                    self.features_buf[0].copy_(self.features[0], non_blocking=True)
                """
                if num_tokens_0 > 64:
                    padded_num_tokens_0 = pad_num_tokens(num_tokens_0)
                    self.features_buf[0][num_tokens_0:padded_num_tokens_0] = 0
                """
                #self.input.forward_minibatchs[0].print()
                # print([[hash(k[i].float().cpu().numpy().tobytes()) for i in self.input.forward_minibatchs[0].kv_indices] for k in self.model.cache.k_caches])
                # print(f"overlap: {overlap}, is_compute_bound: {is_compute_bound}")
                self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)

            # self.model.flash_infer_attn_plan(self.input, self.bsz_tensors, self.num_tokens_tensors)

            """
            self.model_attn_plan(self.input[cuda_graph_idx], cuda_graph_idx)
            self.start_model_event.record(self.stream)
            page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf)
            if self.use_cuda_graph:
                print("before replay features_buf", self.features_buf[0])
                print("features_buf addr", self.features_buf[0].data_ptr())
                self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])
                self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])

                self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1)
                self.replay(cuda_graph_idx)
                self.output = ForwardBatchOutput()

                self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
                self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)

                self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone())
            else:
                print("before run features", self.features[0])
            """
            if cuda_graph_idx != -1 and self.use_cuda_graph:
                self.model.flash_infer_attn_plan(self.input[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                    num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                    head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
                    sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
            self.start_model_event.record(self.stream)
            page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf)
            if self.use_cuda_graph:
                self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])
                self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])
                self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
                self.replay(cuda_graph_idx)
                self.output = ForwardBatchOutput()

                self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
                self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)
                self.output = self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
                self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start]
                self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
                self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)
                self.end_model_event.record(self.stream)

                self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone())
            else:
                self.output = self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
                self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start]
                self.end_model_event.record(self.stream)
            else:
                self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                    num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                    head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
                    sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
                self.start_model_event.record(self.stream)
                page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
                if self.use_cuda_graph:
                    self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
                    self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
                    self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
                    self.replay(cuda_graph_idx)
                    self.output = ForwardBatchOutput()

                    self.output.top_ps.append(self.input.minibatch.top_ps)
                    self.output.temperatures.append(self.input.minibatch.temperatures)

                    self.output.logits.append(self.outputs_buf.logits[0][self.input.minibatch.logits_start].clone())
                else:
                    self.output = self.model(self.input, self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
                    self.output.logits[0] = self.output.logits[0][self.input.minibatch.logits_start]
                    self.output.top_ps.append(self.input.minibatch.top_ps)
                    self.output.temperatures.append(self.input.minibatch.temperatures)

            self.end_model_event.record(self.stream)

            if not self.use_cuda_graph:
                self.output.num_batchs = self.input.batch_size
            else:
                self.output.num_batchs = self.input[cuda_graph_idx].batch_size

    def replay(self, cuda_graph_idx=-1):
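The run() path above chooses which captured graph to replay by taking the smallest capacity in self.cuda_graphs that covers the current token count, and falls back to index 0 when CUDA graphs are disabled. The same selection isolated as a plain function, assuming cuda_graphs is an ascending list of captured token capacities (the function name and example values are illustrative only):

def pick_cuda_graph_idx(cuda_graphs, num_tokens, use_cuda_graph=True):
    if not use_cuda_graph:
        return 0  # eager path: the bucket choice is irrelevant
    # smallest i such that cuda_graphs[i] >= num_tokens
    idx = next((i for i, cap in enumerate(cuda_graphs) if cap >= num_tokens), len(cuda_graphs))
    assert idx < len(cuda_graphs), "num_tokens exceeds the largest captured graph"
    return idx

# pick_cuda_graph_idx([64, 256, 1024], 100) -> 1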
@@ -10,7 +10,7 @@ current_file_path = os.path.abspath(__file__)
 # sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
 import pickle
 import argparse
-from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings
+from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe

@@ -209,5 +209,10 @@ if __name__ == '__main__':
     args = parser.parse_args()
     with open(args.config, "rb") as f:
         main_args = pickle.load(f)
-    settings = create_sched_settings(main_args)
+    if main_args.architectures == "Qwen2MoeForCausalLM":
+        settings = create_sched_settings_qwen2moe(main_args)
+    elif main_args.architectures == "Qwen3MoeForCausalLM":
+        settings = create_sched_settings_qwen3moe(main_args)
+    else:
+        settings = create_sched_settings(main_args)
     start_server(settings, main_args)
@@ -11,6 +11,8 @@ from time import sleep
 import sched_ext
 from transformers import AutoConfig

+from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
+
 def create_sched_settings(args):
     default_sample_options = sched_ext.SampleOptions()
     model_name = os.path.basename(os.path.normpath(args.model_dir))
@@ -64,7 +66,111 @@ def create_sched_settings(args):
     return settings


+def create_sched_settings_qwen2moe(args):
+    default_sample_options = sched_ext.SampleOptions()
+    model_name = os.path.basename(os.path.normpath(args.model_dir))
+    input_model_settings = sched_ext.ModelSettings()
+    input_model_settings.model_path = args.model_dir
+    input_model_settings.params_count = int(0)
+    model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+    input_model_settings.layer_count = model_config.num_hidden_layers
+    input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config["num_key_value_heads"]
+    input_model_settings.k_head_dim = 128
+    input_model_settings.bytes_per_params = 2
+    input_model_settings.bytes_per_kv_cache_element = 2
+    settings = sched_ext.Settings()
+    settings.model_name = model_name
+    settings.quant_type = "BF16"
+    settings.model_settings = input_model_settings
+    settings.page_size = args.page_size
+    settings.gpu_device_count = 1 # tp
+    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
+    # settings.gpu_memory_size = args.cache_lens*576*2
+    settings.gpu_memory_size = args.gpu_memory_size
+    settings.memory_utilization_percentage = args.utilization_percentage
+    max_batch_size = args.max_batch_size
+    chunk_size = args.chunk_size
+
+    max_decode_batch_size = max_batch_size - 2
+
+    settings.max_batch_size = max_batch_size
+    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
+    settings.sample_options = default_sample_options
+    settings.sched_metrics_port = args.sched_metrics_port
+    settings.gpu_only = args.memory_gpu_only
+    settings.use_self_defined_head_dim = False
+    settings.self_defined_head_dim = 576
+    settings.full_kv_cache_on_each_gpu = True
+    settings.k_cache_on = True
+    settings.v_cache_on = True
+
+    settings.kvc2_root_path = '/mnt/data/persist-kvc'
+    settings.kvc2_config_path = args.kvc2_config_dir
+    settings.memory_pool_size_GB = args.cpu_memory_size_GB
+    settings.evict_count = 40
+    settings.kvc2_metrics_port = args.kvc2_metrics_port
+    settings.load_from_disk = False
+    settings.save_to_disk = True
+
+    settings.strategy_name = args.sched_strategy
+
+    settings.auto_derive()
+    return settings
+
+
+def create_sched_settings_qwen3moe(args):
+    default_sample_options = sched_ext.SampleOptions()
+    model_name = os.path.basename(os.path.normpath(args.model_dir))
+    input_model_settings = sched_ext.ModelSettings()
+    input_model_settings.model_path = args.model_dir
+    input_model_settings.params_count = int(0)
+    model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+    input_model_settings.layer_count = model_config.num_hidden_layers
+    input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config["num_key_value_heads"]
+    input_model_settings.k_head_dim = 128
+    input_model_settings.bytes_per_params = 2
+    input_model_settings.bytes_per_kv_cache_element = 2
+    settings = sched_ext.Settings()
+    settings.model_name = model_name
+    settings.quant_type = "BF16"
+    settings.model_settings = input_model_settings
+    settings.page_size = args.page_size
+    settings.gpu_device_count = 1 # tp
+    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
+    # settings.gpu_memory_size = args.cache_lens*576*2
+    settings.gpu_memory_size = args.gpu_memory_size
+    settings.memory_utilization_percentage = args.utilization_percentage
+    max_batch_size = args.max_batch_size
+    chunk_size = args.chunk_size
+
+    max_decode_batch_size = max_batch_size - 2
+
+    settings.max_batch_size = max_batch_size
+    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
+    settings.sample_options = default_sample_options
+    settings.sched_metrics_port = args.sched_metrics_port
+    settings.gpu_only = args.memory_gpu_only
+    settings.use_self_defined_head_dim = False
+    settings.self_defined_head_dim = 576
+    settings.full_kv_cache_on_each_gpu = True
+    settings.k_cache_on = True
+    settings.v_cache_on = True
+
+    settings.kvc2_root_path = '/mnt/data/persist-kvc'
+    settings.kvc2_config_path = args.kvc2_config_dir
+    settings.memory_pool_size_GB = args.cpu_memory_size_GB
+    settings.evict_count = 40
+    settings.kvc2_metrics_port = args.kvc2_metrics_port
+    settings.load_from_disk = False
+    settings.save_to_disk = True
+
+    settings.strategy_name = args.sched_strategy
+
+    settings.auto_derive()
+    return settings
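For a sense of what these scheduler settings budget: with a GQA cache both K and V are stored for every layer, so the per-token footprint is 2 (K and V) * num_k_heads * k_head_dim * bytes_per_kv_cache_element * layer_count. A small sketch with illustrative values (4 KV heads, head_dim 128, 48 layers, bf16); read the real numbers from the model config rather than trusting these assumptions:

def gqa_kv_bytes_per_token(num_kv_heads=4, head_dim=128, num_layers=48, bytes_per_elem=2):
    # K and V are both cached, hence the leading factor of 2
    return 2 * num_kv_heads * head_dim * bytes_per_elem * num_layers

print((1 << 30) // gqa_kv_bytes_per_token())  # ~10922 tokens of KV cache per GiB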
@@ -100,6 +100,7 @@ class Config(metaclass=Singleton):
         # to make sure it consistent with previous version
         self.model_path: str = self.model_dir
         self.model_name: str = self.model.get("name", "")
+        self.architectures: str = self.model.get("name", "")
         self.model_device: str = self.model.get("device", "cuda:0")
         self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
         self.use_cuda_graph = self.model.get("use_cuda_graph", True)
@@ -1,5 +1,5 @@
 torch >= 2.3.0
-transformers == 4.43.2
+transformers == 4.51.3
 fastapi >= 0.111.0
 langchain >= 0.2.0
 blessed >= 1.20.0
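The transformers pin moves from 4.43.2 to 4.51.3, presumably because the Qwen3-MoE configuration and model classes only exist in recent transformers releases; the bundled configuration_qwen3_moe fallback above covers environments where AutoConfig still cannot resolve the checkpoint.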