mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 23:34:35 +00:00
support npu
This commit is contained in:
parent
dd0e41b3b8
commit
7d51a13c9b
34 changed files with 14004 additions and 5626 deletions
|
@ -31,8 +31,35 @@ if not torch.xpu.is_available():
|
|||
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
|
||||
import socket
|
||||
|
||||
import os
|
||||
import re
|
||||
import torch.distributed as dist
|
||||
try:
|
||||
import torch_npu
|
||||
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size
|
||||
use_torch_npu = torch_npu.npu.is_available()
|
||||
except:
|
||||
use_torch_npu = False
|
||||
|
||||
|
||||
warm_uped = False
|
||||
|
||||
|
||||
W8A8_ENABLE = False
|
||||
Q4_GGUF_LODER = None
|
||||
USE_NPU_GRAPH = None
|
||||
WARM_UP_SKIP_CNT = [1, 1]
|
||||
_USE_NPU_GRAPH = False
|
||||
_MAX_DECODE_PROFILE = 3
|
||||
CUR_DEVICE = None
|
||||
_MAX_CHUNK_SIZE = int(max(os.getenv("_MAX_CHUNK_SIZE", 4096), 512))
|
||||
|
||||
|
||||
def get_use_npu_graph():
|
||||
assert _USE_NPU_GRAPH is not None, "use npu graph is not setting"
|
||||
return _USE_NPU_GRAPH
|
||||
|
||||
|
||||
def get_free_ports(n: int, continue_prot: list):
|
||||
sockets = []
|
||||
ports = []
|
||||
|
@ -50,6 +77,10 @@ def get_free_ports(n: int, continue_prot: list):
|
|||
return ports
|
||||
|
||||
def get_compute_capability(device:torch.device = None):
|
||||
|
||||
if use_torch_npu:
|
||||
return 0
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if device is None:
|
||||
num_gpus = torch.cuda.device_count()
|
||||
|
@ -97,9 +128,16 @@ def get_all_used_cuda_device(device_map:dict):
|
|||
all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None
|
||||
if "cpu" in all_device_list:
|
||||
all_device_list.remove("cpu")
|
||||
|
||||
if use_torch_npu:
|
||||
all_device_list = set([device.replace("cuda", "npu") for device in all_device_list])
|
||||
|
||||
all_device_list = list(all_device_list)
|
||||
return all_device_list
|
||||
|
||||
|
||||
|
||||
# TODO: support NPU
|
||||
def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
|
||||
prefix = prefix.replace("orig_module.", "")
|
||||
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
|
||||
|
@ -109,6 +147,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
|
|||
key = prefix + name
|
||||
translated_key = key
|
||||
|
||||
|
||||
# TODO: Merge all loader.
|
||||
# I know this is ugly but lets do it for now.
|
||||
if isinstance(gguf_loader, SafeTensorLoader):
|
||||
|
@ -120,7 +159,13 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
|
|||
if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key:
|
||||
target_dtype = torch.get_default_dtype()
|
||||
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
|
||||
print(f"loading {translated_key} to {device}")
|
||||
|
||||
|
||||
if use_torch_npu:
|
||||
device = "cpu" if "embd" in translated_key else CUR_DEVICE
|
||||
print(f"loading layer {translated_key} to {device}") if torch.distributed.get_rank() == 0 else None
|
||||
else:
|
||||
print(f"loading {translated_key} to {device}")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif torch.xpu.is_available():
|
||||
|
@ -149,6 +194,8 @@ def sync_all_device(all_device_list):
|
|||
torch.cuda.synchronize(device)
|
||||
elif "xpu" in device.lower():
|
||||
torch.xpu.synchronize(device)
|
||||
elif use_torch_npu:
|
||||
torch_npu.synchronize(device)
|
||||
else:
|
||||
raise RuntimeError("The device {} is not available".format(device))
|
||||
|
||||
|
@ -228,20 +275,68 @@ def tf_logits_warper(generation_config):
|
|||
|
||||
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
|
||||
mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
|
||||
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
|
||||
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None, static_cache = None):
|
||||
import os
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
torch._dynamo.config.suppress_errors = True
|
||||
batch_size, seq_length = inputs.shape
|
||||
device_map = model.gguf_loader.tensor_device_map
|
||||
torch_device = get_device('model.layers.0.self_attn', device_map)
|
||||
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
|
||||
|
||||
if use_torch_npu:
|
||||
vocabulary_size = model.config.vocab_size
|
||||
topp = torch.tensor([[model.generation_config.top_p]], dtype=torch.float16).npu()
|
||||
topk = torch.tensor([[model.generation_config.top_k]], dtype=torch.int32).npu()
|
||||
temperature = torch.tensor([[model.generation_config.temperature]], dtype=torch.float16).npu()
|
||||
next_token_fake = torch.tensor([[1]], dtype=torch.int32).npu()
|
||||
next_token_probs = torch.tensor([[1.0]], dtype=torch.float16).npu()
|
||||
torch_device = CUR_DEVICE
|
||||
else:
|
||||
torch_device = get_device('model.layers.0.self_attn', device_map)
|
||||
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
|
||||
inputs = inputs.to(torch_device)
|
||||
all_cuda_device = get_all_used_cuda_device(device_map)
|
||||
|
||||
tokens = []
|
||||
|
||||
def decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
|
||||
if cuda_graph_runner is None:
|
||||
use_cuda_graph = False
|
||||
inputs_embeds = model.model.embed_tokens(cur_token.to('cpu')).to(torch_device)
|
||||
if use_cuda_graph:
|
||||
logits = cuda_graph_runner(inputs_embeds, position_ids, cache_position)
|
||||
else:
|
||||
# custom_stream = torch.cuda.Stream()
|
||||
# torch.cuda.set_device(torch_device)
|
||||
torch_npu.npu.set_device(torch_device)
|
||||
# with torch.cuda.stream(custom_stream):
|
||||
logits=model(inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=False, use_cache=True)[0]
|
||||
if past_key_values != None:
|
||||
past_key_values.change_seq_length(1)
|
||||
all_cuda_device = ['npu:' + str(index) for index in range(torch.distributed.get_world_size())]
|
||||
for device in all_cuda_device:
|
||||
# torch.cuda.synchronize(device)
|
||||
torch_npu.npu.synchronize(device)
|
||||
if generation_config.do_sample:
|
||||
logits = logits / temperature
|
||||
torch.manual_seed(0)
|
||||
probs = logits.view(batch_size, vocabulary_size)
|
||||
sm = nn.Softmax(dim=-1)
|
||||
probs = sm(probs).half().npu()
|
||||
next_token = next_token_fake
|
||||
torch_npu._npu_topk_topp_sampling(probs, topk, topp, next_token, next_token_probs)
|
||||
next_token = next_token.squeeze(-1)
|
||||
else:
|
||||
next_token_scores = logits_warper(inputs, logits[:, -1, :])
|
||||
next_token = torch.argmax(next_token_scores, dim=-1)
|
||||
return next_token
|
||||
|
||||
def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
|
||||
if use_torch_npu:
|
||||
return decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)
|
||||
if cuda_graph_runner is None:
|
||||
use_cuda_graph = False
|
||||
if use_cuda_graph:
|
||||
|
@ -252,6 +347,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
torch.cuda.set_device(torch_device)
|
||||
elif torch.xpu.is_available():
|
||||
torch.xpu.set_device(torch_device)
|
||||
elif use_torch_npu:
|
||||
torch_npu.set_device(torch_device)
|
||||
else:
|
||||
raise RuntimeError(f"The device: {torch_device} is not available")
|
||||
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
|
||||
|
@ -279,6 +376,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
|
||||
else:
|
||||
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
|
||||
|
||||
if use_flashinfer_mla:
|
||||
MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
|
||||
MLAWrapperSingleton.need_plan_all()
|
||||
|
@ -288,11 +386,88 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof=None):
|
||||
global warm_uped
|
||||
global _USE_NPU_GRAPH
|
||||
if use_cuda_graph:
|
||||
from ktransformers.util.npu_graph_runner import get_or_create_runner
|
||||
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
|
||||
npu_graph_runner.init(batch_size, seq_length)
|
||||
with torch_npu.npu.stream(npu_graph_runner.main_stream):
|
||||
for i in range(1, max_new_tokens):
|
||||
if use_flashinfer_mla:
|
||||
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
|
||||
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
|
||||
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
|
||||
torch.bfloat16)
|
||||
if use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
|
||||
warm_uped = True
|
||||
_USE_NPU_GRAPH = True
|
||||
npu_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
|
||||
cuda_graph_runner = npu_graph_runner
|
||||
|
||||
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids,
|
||||
cache_position, past_key_values, logits_warper, generation_config,
|
||||
use_cuda_graph).to(torch_device)
|
||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||
generated_ids[:, cache_position] = next_token.int()
|
||||
tokens.append(int(next_token))
|
||||
seq_length += 1
|
||||
|
||||
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
|
||||
next_token.tolist()) == '<|im_end|>':
|
||||
print(stream.end(), end="", flush=True)
|
||||
break
|
||||
else:
|
||||
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
cache_position += 1
|
||||
past_key_values.position[0] += 1
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
if prof is not None:
|
||||
prof.step()
|
||||
npu_graph_runner.destroy()
|
||||
_USE_NPU_GRAPH = False
|
||||
else:
|
||||
for i in range(1, max_new_tokens):
|
||||
if use_flashinfer_mla:
|
||||
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
|
||||
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
|
||||
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
|
||||
torch.bfloat16)
|
||||
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position,
|
||||
past_key_values, logits_warper, generation_config, use_cuda_graph).to(
|
||||
torch_device)
|
||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||
generated_ids[:, cache_position] = next_token.int()
|
||||
tokens.append(int(next_token))
|
||||
seq_length += 1
|
||||
|
||||
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
|
||||
next_token.tolist()) == '<|im_end|>':
|
||||
print(stream.end(), end="", flush=True)
|
||||
break
|
||||
else:
|
||||
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
cache_position += 1
|
||||
past_key_values.position[0] += 1
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
if prof is not None:
|
||||
prof.step()
|
||||
if prof is not None:
|
||||
prof.stop()
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(torch_device)
|
||||
elif torch.xpu.is_available():
|
||||
torch.xpu.set_device(torch_device)
|
||||
elif use_torch_npu:
|
||||
torch_npu.set_device(torch_device)
|
||||
else:
|
||||
raise RuntimeError(f"The device: {torch_device} is not available")
|
||||
with torch.no_grad():
|
||||
|
@ -304,6 +479,16 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
|
||||
else:
|
||||
past_key_values = DynamicNormalCache.from_legacy_cache(None)
|
||||
elif use_torch_npu and static_cache:
|
||||
assert isinstance(static_cache, StaticCache), '[ERROR] static_cache format not equal to StaticCache'
|
||||
past_key_values = static_cache
|
||||
if past_key_values.max_batch_size < batch_size or past_key_values.max_cache_len < seq_length + max_new_tokens:
|
||||
print('[WARN] current staticCache size exceeded, try create new staticCache...')
|
||||
past_key_values = StaticCache(
|
||||
config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype
|
||||
)
|
||||
else:
|
||||
past_key_values.reset()
|
||||
elif mode != 'long_context':
|
||||
past_key_values = StaticCache(
|
||||
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
|
||||
|
@ -320,19 +505,67 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
logits_warper = tf_logits_warper(generation_config)
|
||||
|
||||
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
|
||||
if use_torch_npu:
|
||||
past_key_values.position[0] = seq_length + 1
|
||||
|
||||
generated_ids = torch.zeros(
|
||||
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
|
||||
)
|
||||
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
|
||||
start_time = time.time()
|
||||
|
||||
chunk_start = 0
|
||||
while chunk_start < seq_length:
|
||||
chunk_end = min(chunk_start + chunk_size, seq_length)
|
||||
if past_key_values != None:
|
||||
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
|
||||
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
|
||||
chunk_start += chunk_size
|
||||
logits = None
|
||||
|
||||
def prefill_wrapper(prof=None):
|
||||
nonlocal logits
|
||||
chunk_start = 0
|
||||
while chunk_start < seq_length:
|
||||
chunk_end = min(chunk_start + chunk_size, seq_length)
|
||||
if past_key_values != None:
|
||||
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
|
||||
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
|
||||
chunk_start += chunk_size
|
||||
if prof is not None:
|
||||
prof.step()
|
||||
if prof is not None:
|
||||
prof.stop()
|
||||
if logits is None:
|
||||
raise ValueError('logits cannot be None')
|
||||
|
||||
if use_torch_npu:
|
||||
global WARM_UP_SKIP_CNT
|
||||
prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0"
|
||||
if prof_prefill == "1" and WARM_UP_SKIP_CNT[0] <= 0:
|
||||
experimental_config = torch_npu.profiler._ExperimentalConfig(
|
||||
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
|
||||
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
|
||||
)
|
||||
with torch_npu.profiler.profile(
|
||||
activities=[
|
||||
torch_npu.profiler.ProfilerActivity.CPU,
|
||||
torch_npu.profiler.ProfilerActivity.NPU
|
||||
],
|
||||
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
|
||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof"),
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=False,
|
||||
with_flops=False,
|
||||
with_modules=False,
|
||||
experimental_config=experimental_config) as prof:
|
||||
prefill_wrapper(prof)
|
||||
else:
|
||||
prefill_wrapper()
|
||||
WARM_UP_SKIP_CNT[0] -= 1
|
||||
else:
|
||||
|
||||
chunk_start = 0
|
||||
while chunk_start < seq_length:
|
||||
chunk_end = min(chunk_start + chunk_size, seq_length)
|
||||
if past_key_values != None:
|
||||
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
|
||||
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
|
||||
chunk_start += chunk_size
|
||||
|
||||
next_token_scores = logits_warper(inputs, logits[:, -1, :])
|
||||
if generation_config.do_sample:
|
||||
|
@ -348,56 +581,106 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
|
|||
|
||||
prefill_count = seq_length
|
||||
prefill_time = first_token_time
|
||||
if force_think:
|
||||
print("<think>")
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
if use_torch_npu and torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
|
||||
if force_think:
|
||||
print("<think>")
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
elif not use_torch_npu:
|
||||
if force_think:
|
||||
print("<think>")
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
|
||||
generated_ids[:, seq_length] = next_token
|
||||
tokens.append(int(next_token))
|
||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||
cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
seq_length += 1
|
||||
if use_torch_npu:
|
||||
past_key_values.position += 1
|
||||
|
||||
cuda_graph_runner = None
|
||||
|
||||
start_time = time.time()
|
||||
for i in range(1, max_new_tokens):
|
||||
if use_flashinfer_mla:
|
||||
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
|
||||
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
|
||||
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
|
||||
global warm_uped
|
||||
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
|
||||
warm_uped = True
|
||||
cuda_graph_runner = CUDAGraphRunner()
|
||||
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
|
||||
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
|
||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||
generated_ids[:, cache_position] = next_token.int()
|
||||
tokens.append(int(next_token))
|
||||
seq_length += 1
|
||||
|
||||
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
|
||||
print(stream.end(), end="", flush=True)
|
||||
break
|
||||
|
||||
if not use_torch_npu:
|
||||
for i in range(1, max_new_tokens):
|
||||
if use_flashinfer_mla:
|
||||
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
|
||||
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
|
||||
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
|
||||
global warm_uped
|
||||
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
|
||||
warm_uped = True
|
||||
cuda_graph_runner = CUDAGraphRunner()
|
||||
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
|
||||
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
|
||||
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
|
||||
generated_ids[:, cache_position] = next_token.int()
|
||||
tokens.append(int(next_token))
|
||||
seq_length += 1
|
||||
|
||||
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
|
||||
print(stream.end(), end="", flush=True)
|
||||
break
|
||||
else:
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
cache_position += 1
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
else:
|
||||
prof_decode = os.environ["PROF_DECODE"] if "PROF_DECODE" in os.environ else "0"
|
||||
prof_ranks = os.environ["PROF_RANK"] if "PROF_RANK" in os.environ else "0"
|
||||
prof_ranks = [int(r.strip()) for r in prof_ranks.split(",")]
|
||||
if prof_decode == "1" and torch.distributed.get_rank() in prof_ranks and WARM_UP_SKIP_CNT[1] <= 0:
|
||||
experimental_config = torch_npu.profiler._ExperimentalConfig(
|
||||
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
|
||||
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
|
||||
)
|
||||
with torch_npu.profiler.profile(
|
||||
activities=[
|
||||
torch_npu.profiler.ProfilerActivity.CPU,
|
||||
torch_npu.profiler.ProfilerActivity.NPU
|
||||
],
|
||||
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=_MAX_DECODE_PROFILE, repeat=1, skip_first=0),
|
||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./decode_prof"),
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=False,
|
||||
with_flops=False,
|
||||
with_modules=False,
|
||||
experimental_config=experimental_config) as prof:
|
||||
decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof)
|
||||
else:
|
||||
print(stream.put(next_token.item()), end="", flush=True)
|
||||
cache_position += 1
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length)
|
||||
WARM_UP_SKIP_CNT[1] -= 1
|
||||
|
||||
|
||||
total_time = time.time() - start_time
|
||||
tokens_generated = len(tokens)
|
||||
tokens_per_second = tokens_generated / total_time
|
||||
|
||||
print("")
|
||||
if not use_torch_npu:
|
||||
print("")
|
||||
|
||||
print(f"prompt eval count: {prefill_count} token(s)")
|
||||
print(f"prompt eval duration: {prefill_time}s")
|
||||
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
|
||||
print(f"eval count: {tokens_generated} token(s)")
|
||||
print(f"eval duration: {total_time}s")
|
||||
print(f"eval rate: {tokens_per_second} tokens/s")
|
||||
else:
|
||||
tp_size = get_tensor_parallel_size()
|
||||
if torch.distributed.get_rank() % tp_size == 0:
|
||||
rank = f"[rank:{torch.distributed.get_rank()}]"
|
||||
msg = f"\n{rank} Eval Time\n"
|
||||
msg += rank + f"prompt eval count: {prefill_count} token(s)\n"
|
||||
msg += rank + f"prompt eval duration: {prefill_time:.9f}s\n"
|
||||
msg += rank + f"prompt eval rate: {prefill_count/prefill_time:.9f} tokens/s\n"
|
||||
msg += rank + f"eval count: {tokens_generated} token(s)\n"
|
||||
msg += rank + f"eval duration: {total_time:.9f}s\n"
|
||||
msg += rank + f"eval rate: {tokens_per_second:.9f} tokens/s\n"
|
||||
print(msg)
|
||||
|
||||
print(f"prompt eval count: {prefill_count} token(s)")
|
||||
print(f"prompt eval duration: {prefill_time}s")
|
||||
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
|
||||
print(f"eval count: {tokens_generated} token(s)")
|
||||
print(f"eval duration: {total_time}s")
|
||||
print(f"eval rate: {tokens_per_second} tokens/s")
|
||||
|
||||
return tokens
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue