mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
server mix mla
This commit is contained in:
parent
038bc30888
commit
c176e516b5
2 changed files with 23 additions and 15 deletions
|
@ -15,7 +15,7 @@ from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
|
|||
from ktransformers.local_chat import custom_models, default_optimize_rules
|
||||
from ktransformers.util.utils import get_device
|
||||
|
||||
|
||||
warm_uped = False
|
||||
class KTransformersThreadContext(TransformersThreadContext):
|
||||
pass
|
||||
|
||||
|
@ -73,11 +73,13 @@ class KTransformersInterface(TransformersInterface):
|
|||
|
||||
self._infer_lock = asyncio.Lock()
|
||||
|
||||
def decode_one_tokens(self):
|
||||
def decode_one_tokens(self, i):
|
||||
device_map = self.model.gguf_loader.tensor_device_map
|
||||
torch_device = get_device("blk.0.self_attn", device_map)
|
||||
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
|
||||
if self.args.use_cuda_graph:
|
||||
global warm_uped
|
||||
if self.args.use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
|
||||
warm_uped = True
|
||||
if not hasattr(self, "cuda_graph_runner"):
|
||||
self.cuda_graph_runner = CUDAGraphRunner()
|
||||
self.cuda_graph_runner.capture(
|
||||
|
@ -91,14 +93,14 @@ class KTransformersInterface(TransformersInterface):
|
|||
use_cache=True,
|
||||
)
|
||||
|
||||
if hasattr(self, "cuda_graph_runner"):
|
||||
logits = self.cuda_graph_runner(
|
||||
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
|
||||
)
|
||||
self.cache.change_seq_length(1)
|
||||
torch.cuda.synchronize()
|
||||
logits = logits[0, -1, :]
|
||||
return self.logits_to_token(logits)
|
||||
if hasattr(self, "cuda_graph_runner"):
|
||||
logits = self.cuda_graph_runner(
|
||||
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
|
||||
)
|
||||
self.cache.change_seq_length(1)
|
||||
torch.cuda.synchronize()
|
||||
logits = logits[0, -1, :]
|
||||
return self.logits_to_token(logits)
|
||||
|
||||
if self.use_static_cache:
|
||||
mask = torch.ones((1, self.seq_length)).to(torch_device)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue