From 7adb7281f4a730f70c7dfc89394c15ddea9a96a2 Mon Sep 17 00:00:00 2001 From: Atream Date: Wed, 30 Apr 2025 03:37:43 +0000 Subject: [PATCH] fix-cache-lens --- ktransformers/server/args.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index b2a6769..1210e14 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -1,6 +1,7 @@ import argparse from ktransformers.server.backend.args import ConfigArgs, default_args from ktransformers.util.utils import get_free_ports +from transformers import AutoConfig class ArgumentParser: def __init__(self, cfg): @@ -138,7 +139,11 @@ class ArgumentParser: self.cfg.server_port = args.port self.cfg.user_force_think = args.force_think - args.gpu_memory_size = 4*1024*1024*1024 # TODO: set this to the actual GPU memory size + model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) + if args.architectures == "Qwen3MoeForCausalLM" or args.architectures == "Qwen2MoeForCausalLM" : + args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim + else: + args.gpu_memory_size = args.cache_lens*2*576*61 self.cfg.gpu_memory_size = args.gpu_memory_size free_ports = get_free_ports(3, [args.port]) args.sched_port = free_ports[0]