Fix import error when balance_serve is not available

This commit is contained in:
qiyuxinlin 2025-04-22 02:11:18 +00:00
parent 03a65d6bea
commit f5287e908a

View file

@ -12,7 +12,10 @@ import torch.nn as nn
import transformers import transformers
from transformers import Cache, PretrainedConfig from transformers import Cache, PretrainedConfig
from typing import List, Optional, Dict, Any, Tuple from typing import List, Optional, Dict, Any, Tuple
from ktransformers.server.balance_serve.settings import sched_ext try:
from ktransformers.server.balance_serve.settings import sched_ext
except:
print("no balance_serve")
class StaticCache(transformers.StaticCache): class StaticCache(transformers.StaticCache):
""" """
Static Cache class to be used with `torch.compile(model)`. Static Cache class to be used with `torch.compile(model)`.
@ -210,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module):
self.v_caches = [] self.v_caches = []
def load(self, inference_context: sched_ext.InferenceContext): def load(self, inference_context: "sched_ext.InferenceContext"):
for i in range(self.config.num_hidden_layers): for i in range(self.config.num_hidden_layers):
self.k_caches.append( self.k_caches.append(