mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-12 08:09:42 +00:00
add balance-serve, support concurrency
This commit is contained in:
parent 8d0292aa44
commit 25cee5810e
196 changed files with 22077 additions and 565 deletions
142  ktransformers/server/balance_serve/inference/config.py  Normal file
@@ -0,0 +1,142 @@
'''
Date: 2024-11-07 07:30:16
LastEditors: djw
LastEditTime: 2024-11-15 14:23:26
'''
import json
from typing import Optional

import yaml


class ModelConfig:
    # Architecture defaults; overridden by the model's config.json in load_config().
    vocab_size: int = 32000
    n_layer: int = 1
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = 18944
    n_local_heads: int = 8
    head_dim: int = 128
    rope_base: float = 1000000.0
    norm_eps: float = 1e-6
    rope_scaling: Optional[dict] = None
    rms_norm_eps: float = 1e-6
    hidden_act: str = "silu"
    model_path: str
    gguf_path: str
    optimize_rule_path: str
    speculative_rule_path: str

    # quantize config
    quant_algorithm: Optional[str] = None
    quant_group_size: Optional[int] = None
    quant_num_bits: Optional[int] = None

    # Attribute name -> key in the Hugging Face-style config.json.
    json_key_map = {
        "vocab_size": "vocab_size",
        "n_layer": "num_hidden_layers",
        "n_head": "num_attention_heads",
        "dim": "hidden_size",
        "intermediate_size": "intermediate_size",
        "n_local_heads": "num_key_value_heads",
        "rope_base": "rope_theta",
        "norm_eps": "norm_eps",
        "rms_norm_eps": "rms_norm_eps",
        "hidden_act": "hidden_act",
    }

    def __init__(self, config):
        self.model_path = config["model"]["model_path"]
        self.gguf_path = config["model"]["gguf_path"]
        self.optimize_rule_path = config["model"]["optimize_rule_path"]
        # Speculative decoding and quantization sections are optional.
        if "speculative_rule_path" in config["model"]:
            self.speculative_rule_path = config["model"]["speculative_rule_path"]
            self.speculative_gguf_path = config["model"]["speculative_gguf_path"]
            self.speculative_model_path = config["model"]["speculative_model_path"]
        if "quant" in config["model"]:
            self.quant_algorithm = config["model"]["quant"]["algorithm"]
            self.quant_group_size = config["model"]["quant"]["group_size"]
            self.quant_num_bits = config["model"]["quant"]["num_bits"]
        self.load_config()
        # The serving config's layer count takes precedence over config.json.
        self.n_layer = config["model"]["n_layers"]

    def load_config(self):
        config_file = f"{self.model_path}/config.json"
        try:
            with open(config_file, "r") as f:
                config_data = json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(f"Configuration file not found at {config_file}")

        # Override class defaults with whatever config.json provides;
        # attributes missing from config.json keep their defaults.
        for attr, json_key in self.json_key_map.items():
            if json_key in config_data:
                setattr(self, attr, config_data[json_key])


class ParallelConfig:
    def __init__(self, config) -> None:
        self.pipeline_parallel_size = config["parallel"]["pp"]
        self.tensor_parallel_size = config["parallel"]["tp"]
        self.disable_custom_all_reduce = config["parallel"]["disable_custom_all_reduce"]
        self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size


class AttnConfig:
    # Paged KV-cache layout and batching limits for the attention kernels.
    page_size: int = 256
    block_num: int = 32
    max_batch_token: int = 256
    max_batch_size: int = 32

    def __init__(self, config):
        self.page_size = config["attn"]["page_size"]
        self.block_num = config["attn"]["block_num"]
        self.max_batch_token = config["attn"]["max_batch_token"]
        self.max_batch_size = config["attn"]["max_batch_size"]


class SamplerConfig:
    # Batched sampling params
    temperatures: float
    is_all_greedy: bool

    def __init__(self, config):
        self.temperatures = config["sample"]["temperature"]
        self.is_all_greedy = True


def load_yaml_config(file_path):
    with open(file_path, "r") as f:
        return yaml.safe_load(f)


class LLMConfig:
    # Bundles all sub-configs parsed from a single YAML file.
    model_config: ModelConfig
    parallel_config: ParallelConfig
    attn_config: AttnConfig
    sample_config: SamplerConfig
    config_file: str

    def __init__(self, config_file):
        self.config_file = config_file
        config = load_yaml_config(config_file)
        self.model_config = ModelConfig(config)
        self.parallel_config = ParallelConfig(config)
        self.attn_config = AttnConfig(config)
        self.sample_config = SamplerConfig(config)
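For reference, a minimal sketch of the config structure this module expects, assembled directly as a Python dict rather than loaded from YAML. The section and key names are exactly those the classes above read; the values and paths are illustrative assumptions, not defaults shipped with ktransformers.

# Minimal config sketch; all paths and values below are hypothetical.
config = {
    "model": {
        "model_path": "/models/my-model",          # assumed local HF checkpoint dir
        "gguf_path": "/models/my-model.gguf",      # assumed GGUF weights file
        "optimize_rule_path": "/rules/optimize.yaml",
        "n_layers": 28,
        # The "speculative_rule_path" and "quant" sections are optional
        # and omitted here.
    },
    "parallel": {"pp": 1, "tp": 1, "disable_custom_all_reduce": False},
    "attn": {"page_size": 256, "block_num": 32,
             "max_batch_token": 256, "max_batch_size": 32},
    "sample": {"temperature": 0.7},
}

parallel = ParallelConfig(config)    # world_size == pp * tp == 1
attn = AttnConfig(config)
sampler = SamplerConfig(config)      # is_all_greedy is always set to True
# ModelConfig(config) would additionally read {model_path}/config.json from disk,
# and LLMConfig("path/to/config.yaml") builds all four pieces from one YAML file
# of this shape.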