Replace deprecated torch.backends.cuda.sdp_kernel() with torch.nn.attention.sdpa_kernel()
This commit is contained in:
parent 1548c99234
commit f74c2d1d17
1 changed file with 2 additions and 1 deletion
@@ -13,6 +13,7 @@ from transformers import (
 from ktransformers.server.config.config import Config
 from ktransformers.server.schemas.base import ObjectID
 from ktransformers.server.utils.multi_timer import Profiler
+from torch.nn.attention import SDPBackend
 import torch
 import sys, os
 from ..base import ThreadContext, BackendInterfaceBase
@@ -292,7 +293,7 @@ class TransformersInterface(BackendInterfaceBase):
     def generate(self):
         self.profiler.set_counter("decode", 0)
         for _ in range(1, self.args.max_new_tokens):
-            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+            with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                 next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id:
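
For context, here is a minimal, self-contained sketch of the PyTorch API this change migrates to. It is an illustration of the interface, not code from ktransformers: the tensor shapes, device handling, and backend order are assumptions chosen for the example. torch.nn.attention.sdpa_kernel takes the allowed SDPBackend members explicitly and scopes them to the with-block, replacing the boolean-flag style of the deprecated torch.backends.cuda.sdp_kernel context manager.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative inputs only; shapes are (batch, heads, seq_len, head_dim).
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q = torch.randn(1, 8, 128, 64, device=device, dtype=dtype)
k = torch.randn(1, 8, 128, 64, device=device, dtype=dtype)
v = torch.randn(1, 8, 128, 64, device=device, dtype=dtype)

# Deprecated form removed by this commit:
#   with torch.backends.cuda.sdp_kernel(enable_flash=False,
#                                       enable_mem_efficient=False,
#                                       enable_math=True):
#       out = F.scaled_dot_product_attention(q, k, v)

# Replacement form: list the backends allowed inside the block;
# PyTorch selects an implementation that supports the given inputs.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)

Note that the new call in the diff is not a pure rename: the deprecated call enabled only the math backend, while the replacement also permits the flash and memory-efficient attention kernels.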