mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 15:29:39 +00:00)

support npu

commit 7d51a13c9b (parent dd0e41b3b8): 34 changed files with 14004 additions and 5626 deletions
@@ -1,4 +1,19 @@
 import torch
 from torch import nn
+
+try:
+    import torch_npu
+    from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel
+    from ktransformers.util.utils import get_device, get_all_used_cuda_device
+    from ktransformers.util import utils
+
+    use_torch_npu = torch_npu.npu.is_available()
+except Exception:  # torch_npu not installed or NPU runtime unavailable
+    use_torch_npu = False
+
 import os
+
+
+from typing import Optional, List
+import asyncio
 from transformers import AutoTokenizer, AutoConfig, GenerationConfig
@@ -19,6 +34,9 @@ from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage
+
+

 warm_uped = False

 class KTransformersThreadContext(TransformersThreadContext):
@@ -26,8 +44,15 @@ class KTransformersThreadContext(TransformersThreadContext):


 class KTransformersInterface(TransformersInterface):
-    def __init__(self, args: ConfigArgs = default_args):
-        self.args = args
+    def __init__(self, args: ConfigArgs = default_args, input_args=None):
+        if use_torch_npu:
+            self.args = input_args
+            self.local_rank, self.world_size = setup_model_parallel(tp=self.args.tp)
+            if utils.CUR_DEVICE is None:
+                utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}"
+            self.args.device = utils.CUR_DEVICE
+        else:
+            self.args = args
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
@@ -47,7 +72,10 @@
         with torch.device("meta"):
             self.model = custom_models[config.architectures[0]](config)

-        if default_args.optimize_config_path is None:
+        if use_torch_npu and input_args.optimize_config_path is not None:
+            optimize_config_path = input_args.optimize_config_path
+        elif default_args.optimize_config_path is None:
             optimize_config_path = default_optimize_rules[config.architectures[0]]
         else:
             optimize_config_path = args.optimize_config_path
@@ -60,7 +88,14 @@
                 "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                 " belong to current model):"
             )
-        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
+
+        if use_torch_npu:
+            optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config, q4_gguf_path=input_args.q4_gguf_path)
+            # absorb the attention weights ahead of time
+            get_absort_weight(self.model, config)
+            self.model.eval()
+        else:
+            optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
         self.model.generation_config = generation_config
         self.device_map = self.model.gguf_loader.tensor_device_map
         # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
@@ -77,9 +112,92 @@
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
+
+        if use_torch_npu:
+            self.top_p = torch.tensor([[self.model.generation_config.top_p]], dtype=torch.float16, device=self.args.device)
+            self.top_k = torch.tensor([[self.model.generation_config.top_k]], dtype=torch.int32, device=self.args.device)
+            self.temperature = torch.tensor([[self.model.generation_config.temperature]], dtype=torch.float16, device=self.args.device)
+            self.next_token_fake = torch.tensor([[1]], dtype=torch.int32, device=self.args.device)
+            self.next_token_probs = torch.tensor([[1.0]], dtype=torch.float16, device=self.args.device)
+            self._infer_lock = asyncio.Lock()

         self._infer_lock = asyncio.Lock()

+    def decode_logits_to_token(self, logits: torch.Tensor):
+        if self.model.generation_config.do_sample:
+            logits = logits / self.temperature
+            torch.manual_seed(0)
+            probs = logits.view(1, self.model.config.vocab_size)
+            sm = nn.Softmax(dim=-1)
+            probs = sm(probs).half().npu()
+            next_token = self.next_token_fake
+            torch_npu._npu_topk_topp_sampling(probs, self.top_k, self.top_p, next_token, self.next_token_probs)
+            last = next_token.squeeze(-1)
+        else:
+            logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
+            probs = torch.nn.functional.softmax(logits, dim=-1)
+            _, last = torch.topk(probs, k=1, dim=-1)
+            last = last.item()
+        self.ever_generated_ids.add(last)
+        return last
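For readers without Ascend hardware: below is a minimal pure-PyTorch sketch of the sampling step that the fused `torch_npu._npu_topk_topp_sampling` call performs above, assuming it implements standard top-k filtering followed by top-p (nucleus) renormalization; the kernel's exact tie-breaking and seeding may differ.

```python
import torch

def topk_topp_sample(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Reference top-k/top-p sampling over a (1, vocab) probability row."""
    # Restrict to the top_k most likely tokens (torch.topk returns them sorted).
    topk_probs, topk_idx = torch.topk(probs.float(), k=top_k, dim=-1)
    # Nucleus step: zero out tokens outside the smallest prefix whose mass
    # exceeds top_p, always keeping at least the most likely token.
    cumulative = torch.cumsum(topk_probs, dim=-1)
    topk_probs = topk_probs.masked_fill(cumulative - topk_probs > top_p, 0.0)
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
    # Sample inside the filtered set and map back to vocabulary ids.
    choice = torch.multinomial(topk_probs, num_samples=1)
    return topk_idx.gather(-1, choice)
```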
+    def decode_one_tokens_npu(self):
+        global warm_uped
+
+        device_map = self.model.gguf_loader.tensor_device_map
+        torch_device = get_device("blk.0.self_attn", device_map)
+        torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+        torch.cuda.set_device(torch_device)
+        if warm_uped and self.args.use_cuda_graph:
+            from ktransformers.util.npu_graph_runner import get_or_create_runner, check_runner
+            if check_runner(self.args.device):
+                npu_graph_runner = get_or_create_runner(self.args.device)
+                npu_graph_runner.init(self.args.batch_size, self.seq_length)
+                self.cuda_graph_runner = npu_graph_runner
+                utils._USE_NPU_GRAPH = True
+                self.cuda_graph_runner.capture(
+                    self.model,
+                    self.current_ids,
+                    self.active_cache_position.unsqueeze(0),
+                    self.active_cache_position,
+                    self.cache,
+                    main_device=self.args.device,
+                    return_dict=False,
+                    use_cache=True,
+                )
+
+        if hasattr(self, "cuda_graph_runner"):
+            inputs_embeds = self.model.model.embed_tokens(self.current_ids.to("cpu")).to(self.args.device)
+            logits = self.cuda_graph_runner(
+                inputs_embeds, self.active_cache_position.unsqueeze(0), self.active_cache_position
+            )
+            self.cache.change_seq_length(1)
+            torch.cuda.synchronize()
+            logits = logits[0, -1, :]
+            return self.decode_logits_to_token(logits)
+
+        if self.args.use_cuda_graph:
+            warm_uped = True
+
+        if self.use_static_cache:
+            logits = self.model(
+                self.current_ids.to(torch_device),
+                cache_position=self.active_cache_position,
+                past_key_values=self.cache,
+                return_dict=False,
+                use_cache=True,
+            )[0]
+        else:
+            logits = self.model(self.current_ids, return_dict=False)[0]
+        self.cache.change_seq_length(1)
+        logits = logits[0, -1, :]
+
+        return self.decode_logits_to_token(logits)
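The `warm_uped` gate above is the usual capture-once/replay structure: run the decode step eagerly until allocations are stable, capture it into a device graph, then replay the graph with updated input buffers. For comparison, here is the same idea expressed with stock PyTorch CUDA graphs (the standard `torch.cuda.CUDAGraph` recipe, not the NPU runner's actual API):

```python
import torch

def make_graphed_decode_step(model, static_input: torch.Tensor):
    # Warm up eagerly on a side stream so capture sees stable allocations.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture one decode step into a graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    def step(new_input: torch.Tensor):
        static_input.copy_(new_input)  # refresh the captured input buffer in place
        graph.replay()                 # replay the recorded kernels
        return static_output

    return step
```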

     def decode_one_tokens(self):
+        if use_torch_npu:
+            return self.decode_one_tokens_npu()
+
         global warm_uped

         device_map = self.model.gguf_loader.tensor_device_map
@@ -127,9 +245,145 @@
         return self.logits_to_token(logits)

+    @torch.no_grad
+    def prefill_npu(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
+        input_ids_length = input_ids.shape[-1]
+        if input_ids_length >= self.args.cache_lens:
+            logger.warning(f"input_ids_length {input_ids_length} >= cache_lens {self.args.cache_lens}")
+            self.seq_length = input_ids_length
+            return
+        logger.debug(f"input_ids: {input_ids.shape}")
+        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
+        device = "cuda:0" if device == "cuda" else device
+        device = self.args.device
+        if is_new:
+            self.ever_generated_ids.clear()
+            same_prefix = 0
+            flat_input_ids = input_ids.flatten()
+
+            if getattr(self, 'generated_ids', None) is None:
+                self.generated_ids = torch.zeros(
+                    self.args.batch_size,
+                    input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                    dtype=torch.int,
+                    device=self.args.device,
+                )
+                self.seq_length = 1
+
+            # flat_prev_ids = self.generated_ids.flatten()
+            # for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+            #     if flat_input_ids[i] == flat_prev_ids[i]:
+            #         same_prefix += 1
+            #     else:
+            #         break
+
+            logger.debug(f"same prefix len: {same_prefix}")
+            self.cache.remove_suffix(same_prefix)
+            self.seq_length = same_prefix
+            self.cache.position[0] = same_prefix
+            self.generated_ids = self.generated_ids[..., :same_prefix]
+            input_ids = input_ids[..., same_prefix:]
+            input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        else:
+            logger.warning("seq_length is bigger than cache_lens; exiting")
+            exit(0)
+
+        logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
+        cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
+        self.cache.position[0] = self.seq_length + 1
+        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
+
+        if not (type(self) is TransformersInterface):
+            input_ids = input_ids.to("cpu")
+
+        def chunk_prefill(input_ids, cache_position):
+            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+            torch.cuda.set_device(device)
+            if flashinfer_enabled:
+                MLAWrapperSingleton.need_plan_all()
+            if self.use_static_cache:
+                logits = self.model(
+                    inputs_embeds=inputs_embeds,
+                    cache_position=cache_position,
+                    past_key_values=self.cache,
+                    return_dict=False,
+                    use_cache=True,
+                )[0]
+            else:
+                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+
+            return logits
+
+        logits = None
+        def prefill_wrapper(prof=None):
+            nonlocal logits
+            chunk_start = 0
+            while chunk_start < input_ids_length:
+                chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)
+                if self.cache is not None:
+                    self.cache.cur_idx = cache_position[chunk_start:chunk_end]
+                logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
+                chunk_start += self.args.chunk_size
+                if prof is not None:
+                    prof.step()
+            if prof is not None:
+                prof.stop()
+            if logits is None:
+                raise ValueError('logits cannot be None')
+
+        global WARM_UP_SKIP_CNT
+        prof_prefill = os.environ.get("PROF_PREFILL", "0")
+        if prof_prefill == "1":
+            experimental_config = torch_npu.profiler._ExperimentalConfig(
+                aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
+                profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
+            )
+            with torch_npu.profiler.profile(
+                    activities=[
+                        torch_npu.profiler.ProfilerActivity.CPU,
+                        torch_npu.profiler.ProfilerActivity.NPU
+                    ],
+                    schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
+                    on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof_lm_head"),
+                    record_shapes=True,
+                    profile_memory=True,
+                    with_stack=False,
+                    with_flops=False,
+                    with_modules=False,
+                    experimental_config=experimental_config) as prof:
+                prefill_wrapper(prof)
+        else:
+            prefill_wrapper()
+
+        if flashinfer_enabled:
+            MLAWrapperSingleton.reset_buffer()
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
+        next_token = self.logits_to_token(logits[0, -1, :])
+        yield self.append_new_tokens(next_token)
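The `prefill_wrapper` loop above walks the prompt in `chunk_size` windows so embeddings and attention run over bounded slices while the KV cache grows incrementally; only the last chunk's logits are kept, and profiling of the loop is toggled by setting `PROF_PREFILL=1`. A tiny worked example of the chunk boundaries it produces:

```python
def chunk_bounds(length: int, chunk_size: int):
    """Yield (start, end) windows exactly as the prefill loop slices them."""
    start = 0
    while start < length:
        yield start, min(start + chunk_size, length)
        start += chunk_size

# A 5000-token prompt with chunk_size=2048 is prefilled in three passes:
assert list(chunk_bounds(5000, 2048)) == [(0, 2048), (2048, 4096), (4096, 5000)]
```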
+
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
+        if use_torch_npu:
+            yield from self.prefill_npu(input_ids, is_new, temperature, top_p, max_tokens, max_completion_tokens)
+            return
+
         input_ids_length = input_ids.shape[-1]
         if max_tokens is not None:
             max_completion_tokens = max_tokens
@@ -144,6 +398,8 @@
         logger.debug(f"input_ids: {input_ids.shape}")
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         device = "cuda:0" if device == "cuda" else device
+        if use_torch_npu:
+            device = self.args.device

         if is_new:
             self.ever_generated_ids.clear()
@@ -159,16 +415,19 @@
             )
             self.seq_length = 1

-            flat_prev_ids = self.generated_ids.flatten()
-            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
-                if flat_input_ids[i] == flat_prev_ids[i]:
-                    same_prefix += 1
-                else:
-                    break
+            if not use_torch_npu:
+                flat_prev_ids = self.generated_ids.flatten()
+                for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+                    if flat_input_ids[i] == flat_prev_ids[i]:
+                        same_prefix += 1
+                    else:
+                        break

             logger.debug(f"same prefix len: {same_prefix}")
             self.cache.remove_suffix(same_prefix)
             self.seq_length = same_prefix
+            if use_torch_npu:
+                self.cache.position[0] = same_prefix
             self.generated_ids = self.generated_ids[..., :same_prefix]
             input_ids = input_ids[..., same_prefix:]
             input_ids_length = input_ids.shape[-1]
@@ -193,6 +452,8 @@

         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
+        if use_torch_npu:
+            self.cache.position[0] = self.seq_length + 1
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

         if not (type(self) is TransformersInterface):
@@ -248,4 +509,18 @@
                 decode_time = self.profiler.get_timer_sec('decode'),
                 prefill_count = self.profiler.get_counter('prefill'),
                 decode_count = self.profiler.get_counter('decode'),
             )
         )
+
+    def sync_inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None) -> str:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        try:
+            async def run_async():
+                # drain the async generator; tokens are streamed as a side effect
+                async for chunk in self.inference(local_messages, thread_id, temperature, top_p):
+                    pass
+                return ""
+            return loop.run_until_complete(run_async())
+        finally:
+            loop.close()
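`sync_inference` exists so synchronous callers can drive the async generator on a private event loop; note that tokens are printed as a side effect and the return value is currently an empty string. A usage sketch (interface construction and message shape are illustrative):

```python
# assuming `interface` is an initialized KTransformersInterface
reply = interface.sync_inference(
    [{"role": "user", "content": "Hello"}],
    thread_id="thread-0",
    temperature=0.6,
    top_p=0.95,
)
assert reply == ""  # output was streamed/printed during the call
```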
@@ -32,6 +32,20 @@ from ktransformers.server.config.log import logger
 from ..args import ConfigArgs, default_args
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+
+
+try:
+    import torch_npu
+    from ktransformers.util import utils
+
+    use_torch_npu = torch_npu.npu.is_available()
+except Exception:  # torch_npu not installed or NPU runtime unavailable
+    use_torch_npu = False
+
+
+import torch.distributed as dist
+

 # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
 class TextStreamer:
@@ -191,11 +205,19 @@
         # input_ids = self.tokenizer.apply_chat_template(
         #     new_messages, return_tensors="pt", add_generation_prompt=True
         # ).to(self.args.device)
-        input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
-        # drop <think> token in chat template
-        if input_str.endswith('<think>\n'):
-            input_str = input_str[:-len('<think>\n')]
-        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
+        if not use_torch_npu:
+            input_str: str = self.tokenizer.apply_chat_template(new_messages, tokenize=False, add_generation_prompt=True)
+            # drop the <think> token appended by the chat template
+            if input_str.endswith('<think>\n'):
+                input_str = input_str[:-len('<think>\n')]
+            input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
+        else:
+            logger.debug(f"new_messages: {new_messages}")
+            input_ids = self.tokenizer.apply_chat_template(
+                new_messages, add_generation_prompt=True, return_tensors="pt"
+            )

         if (self.last_request_id is not None) and self.last_request_id == thread_id:
             x = self.generated_ids[:, :self.seq_length]
             y = input_ids[:, :self.seq_length]
@@ -212,6 +234,8 @@
     def append_new_tokens(self, new_tokens: int) -> Optional[str]:
         self.generated_ids[0, self.seq_length] = new_tokens
         self.seq_length += 1
+        if use_torch_npu:
+            self.cache.position[0] = self.seq_length
         return self.streamer.put(new_tokens)

     @staticmethod
@@ -273,14 +297,21 @@
         top_p = self.model.generation_config.top_p
         if top_p == 0:
             top_p = 0.0001
-        generation_config, model_kwargs = self.model._prepare_generation_config(
-            None, max_length=self.args.max_new_tokens,
-            do_sample=True,
-            top_k=self.args.top_k,
-            top_p=top_p,
-            temperature=temperature,
-            repetition_penalty=self.args.repetition_penalty # change this to modify generate config
-        )
+        if use_torch_npu:
+            generation_config, model_kwargs = self.model._prepare_generation_config(
+                None, do_sample=True,
+                top_p=top_p, temperature=temperature
+            )
+        else:
+            generation_config, model_kwargs = self.model._prepare_generation_config(
+                None, max_length=self.args.max_new_tokens,
+                do_sample=True,
+                top_k=self.args.top_k,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=self.args.repetition_penalty  # change this to modify the generation config
+            )
         self.inputs = inputs

         self.logits_warper = self.tf_logits_warper(generation_config)
@@ -372,7 +403,10 @@
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

-        device = input_ids.device
+        if use_torch_npu:
+            device = self.args.device
+        else:
+            device = input_ids.device
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
         inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
@@ -420,7 +454,12 @@
         else:  # for-else: runs when generation stops at max new tokens
             yield self.streamer.end(), None
             yield "", "length"

+        if use_torch_npu and self.args.use_cuda_graph:
+            utils._USE_NPU_GRAPH = False
+            from ktransformers.util.npu_graph_runner import get_or_create_runner
+            npu_graph_runner = get_or_create_runner(self.args.device)
+            npu_graph_runner.destroy()
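`check_runner`/`get_or_create_runner` imply a per-device registry of graph runners that is torn down here once a request completes. A minimal sketch of that registry pattern, with `NPUGraphRunner` as a stub for the real class in `ktransformers.util.npu_graph_runner` (whose internals may differ):

```python
class NPUGraphRunner:  # stub standing in for the real runner class
    def __init__(self, device: str):
        self.device = device

    def destroy(self):
        # the real runner frees its captured graph/buffers and deregisters itself
        _RUNNERS.pop(self.device, None)

_RUNNERS: dict = {}

def check_runner(device: str) -> bool:
    """True when no runner exists yet for this device, i.e. capture is needed."""
    return device not in _RUNNERS

def get_or_create_runner(device: str) -> NPUGraphRunner:
    """Lazily create one runner per device string, e.g. "npu:0"."""
    if device not in _RUNNERS:
        _RUNNERS[device] = NPUGraphRunner(device)
    return _RUNNERS[device]
```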

     def check_is_new(self, thread_id: str):
@@ -436,7 +475,87 @@
         self.last_request_id = thread_id
         return True

+    async def inference_npu(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        self.streamer.reset()
+        self.profiler.create_and_start_timer("tokenize")
+        rank = torch.distributed.get_rank()
+        tp_size = utils.get_tensor_parallel_size()
+        world_size = torch.distributed.get_world_size()
+        if isinstance(local_messages, List):
+            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
+        elif isinstance(local_messages, str):
+            # local_messages = local_messages[0]['content']
+            input_ids = self.tokenize_prompt(local_messages)
+            # input_ids = torch.tensor([[6366]], device=input_ids.device)
+        else:
+            raise ValueError("local_messages should be List or str")
+
+        if tp_size == world_size and tp_size > 1:
+            torch.distributed.barrier()
+            input_size = torch.tensor([input_ids.size(1)], dtype=torch.int64, device=self.args.device)
+            all_input_sizes = [torch.zeros_like(input_size) for _ in range(world_size)]
+            dist.all_gather(all_input_sizes, input_size)
+
+            max_input_size = max([size.item() for size in all_input_sizes])
+            padded_input_ids = torch.zeros(1, max_input_size, dtype=input_ids.dtype, device=self.args.device)
+            padded_input_ids[0, :input_ids.size(1)] = input_ids[0]
+
+            all_padded_inputs = [torch.zeros_like(padded_input_ids) for _ in range(world_size)]
+            dist.all_gather(all_padded_inputs, padded_input_ids)
+
+            original_size = all_input_sizes[0].item()
+            input_ids = all_padded_inputs[0][:, :original_size]
+
+        if Config().user_force_think:
+            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n", add_special_tokens=False)], device=input_ids.device)
+            if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]):
+                input_ids = torch.cat(
+                    [input_ids, token_thinks], dim=1
+                )
+
+        self.profiler.pause_timer("tokenize")
+
+        self.profiler.create_and_start_timer("prefill")
+
+        if Config().user_force_think:
+            think = '<think>\n'
+            if tp_size == world_size and rank != 0:
+                pass
+            else:
+                print(think, end="", flush=True)
+                yield think, None
+
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
+            # output the think token after prefill is done
+            if t is not None:
+                print(t, end="", flush=True)
+                yield t, None
+        self.profiler.pause_timer("prefill")
+
+        self.profiler.create_and_start_timer("decode")
+        for t, finish_reason in self.generate():
+            if t is not None:
+                if tp_size == world_size and rank != 0:
+                    pass
+                else:
+                    print(t, end="", flush=True)
+                    yield t, finish_reason
+
+        if tp_size == world_size and rank != 0:
+            pass
+        else:
+            self.profiler.pause_timer("decode")
+            self.report_last_time_performance()
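When tensor parallelism spans every rank, `inference_npu` synchronizes prompts in two `all_gather` phases: first the per-rank lengths, then the right-padded token tensors, after which every rank truncates to rank 0's true length and decodes the same prompt. The same exchange in isolation, as a minimal sketch (assuming the process group is already initialized):

```python
import torch
import torch.distributed as dist

def adopt_rank0_prompt(input_ids: torch.Tensor, device) -> torch.Tensor:
    world_size = dist.get_world_size()
    # Phase 1: agree on the longest prompt length across ranks.
    size = torch.tensor([input_ids.size(1)], dtype=torch.int64, device=device)
    all_sizes = [torch.zeros_like(size) for _ in range(world_size)]
    dist.all_gather(all_sizes, size)
    max_len = max(s.item() for s in all_sizes)
    # Phase 2: right-pad so every rank contributes an equally shaped tensor.
    padded = torch.zeros(1, max_len, dtype=input_ids.dtype, device=device)
    padded[0, :input_ids.size(1)] = input_ids[0]
    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)
    # Every rank keeps rank 0's prompt, truncated to its true length.
    return gathered[0][:, :all_sizes[0].item()]
```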
     async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
+        if use_torch_npu:
+            async for tok in self.inference_npu(local_messages, thread_id, temperature, top_p):
+                yield tok
+            return
+
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):