fix no balance_serve import error

This commit is contained in:
qiyuxinlin 2025-04-22 02:11:18 +00:00
parent 03a65d6bea
commit f5287e908a

View file

@ -12,7 +12,10 @@ import torch.nn as nn
import transformers
from transformers import Cache, PretrainedConfig
from typing import List, Optional, Dict, Any, Tuple
from ktransformers.server.balance_serve.settings import sched_ext
try:
from ktransformers.server.balance_serve.settings import sched_ext
except:
print("no balance_serve")
class StaticCache(transformers.StaticCache):
"""
Static Cache class to be used with `torch.compile(model)`.
@ -210,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module):
self.v_caches = []
def load(self, inference_context: sched_ext.InferenceContext):
def load(self, inference_context: "sched_ext.InferenceContext"):
for i in range(self.config.num_hidden_layers):
self.k_caches.append(