#!/usr/bin/env python
# coding=utf-8
'''
Description : Utility helpers for weight loading (GGUF/SafeTensor), device mapping,
              logits warping, and the chunked-prefill + decode generation loop.
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import torch
from torch import nn
import itertools
import time
import enum
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    MinPLogitsWarper,
    TypicalLogitsWarper,
    EpsilonLogitsWarper,
    EtaLogitsWarper,
    LogitNormalization,  # used below when generation_config.renormalize_logits is True
)
from ktransformers.util.custom_loader import ModelLoaderFactory, ModelLoader, SafeTensorLoader, GGUFLoader, translate_name_to_gguf
from ktransformers.operators import base_operator
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.util.textstream import TextStreamer
if not torch.xpu.is_available():
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
import socket
import os
import re
import torch.distributed as dist
try:
import torch_npu
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size
use_torch_npu = torch_npu.npu.is_available()
except Exception:
    use_torch_npu = False
warm_uped = False
W8A8_ENABLE = False
Q4_GGUF_LODER = None
USE_NPU_GRAPH = None
WARM_UP_SKIP_CNT = [1, 1]
_USE_NPU_GRAPH = False
_MAX_DECODE_PROFILE = 3
CUR_DEVICE = None
# Read _MAX_CHUNK_SIZE from the environment, clamped to a minimum of 512.
_MAX_CHUNK_SIZE = max(int(os.getenv("_MAX_CHUNK_SIZE", 4096)), 512)
def get_use_npu_graph():
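    """Return whether decoding is currently routed through the NPU graph runner."""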
    assert _USE_NPU_GRAPH is not None, "_USE_NPU_GRAPH has not been set"
return _USE_NPU_GRAPH
def get_free_ports(n: int, continue_prot: list):
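    """Find n free TCP ports, skipping any port listed in continue_prot.

    Sockets stay bound until all ports are collected (so the same port is not
    returned twice) and are closed before returning.
    """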
sockets = []
ports = []
    while len(ports) < n:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(("", 0))
        port = s.getsockname()[1]
        if port in continue_prot:
            # Excluded port: close it and retry so that exactly n ports are returned.
            s.close()
            continue
        ports.append(port)
        sockets.append(s)
for s in sockets:
s.close()
return ports
def get_compute_capability(device:torch.device = None):
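    """Return the minimum CUDA compute-capability major version across visible GPUs,
    or the capability of the given device; returns 0 on NPU or CPU-only setups.
    """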
if use_torch_npu:
return 0
if torch.cuda.is_available():
if device is None:
num_gpus = torch.cuda.device_count()
min_compute_capability_major = 100
for gpu_id in range(num_gpus):
gpu_props = torch.cuda.get_device_properties(gpu_id)
min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
return min_compute_capability_major
else:
            # Return just the major version so this branch matches the device=None branch.
            return torch.cuda.get_device_properties(device).major
else:
return 0
def set_module(model, submodule_key, module):
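    """Replace the sub-module at a dotted path (e.g. "model.layers.0.mlp") with `module`,
    indexing into containers such as nn.ModuleList by integer where needed.
    """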
tokens = submodule_key.split('.')
sub_tokens = tokens[:-1]
cur_mod = model
for s in sub_tokens:
if hasattr(cur_mod, s):
cur_mod = getattr(cur_mod, s)
        else:  # indexed container such as nn.ModuleList or nn.Sequential
            cur_mod = cur_mod[int(s)]
if hasattr(cur_mod, tokens[-1]):
setattr(cur_mod, tokens[-1], module)
    else:  # indexed container such as nn.ModuleList or nn.Sequential
        cur_mod[int(tokens[-1])] = module
def set_param(module: nn.Module, name: str, weights: torch.Tensor):
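    """Attach `weights` to `module` under `name` as a frozen Parameter; 1-D weights on nn.Linear are unsqueezed to 2-D."""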
param=nn.parameter.Parameter(weights, requires_grad=False)
if isinstance(module, nn.Linear) and len(weights.shape)==1:
param.unsqueeze_(0)
setattr(module, name, param)
def get_device(gguf_module_key:str, device_map:dict):
if gguf_module_key in device_map:
return device_map[gguf_module_key]["generate_device"]
else:
return "cuda"
def get_all_used_cuda_device(device_map:dict):
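    """Collect every non-CPU device referenced by the device map; on Ascend, "cuda" names are rewritten to "npu"."""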
all_device_list = set()
for key in device_map:
        if "generate_device" in device_map[key]:
            all_device_list.add(device_map[key]["generate_device"])
        if "prefill_device" in device_map[key]:
            all_device_list.add(device_map[key]["prefill_device"])
if "cpu" in all_device_list:
all_device_list.remove("cpu")
if use_torch_npu:
all_device_list = set([device.replace("cuda", "npu") for device in all_device_list])
all_device_list = list(all_device_list)
return all_device_list
# TODO: support NPU
def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
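    """Load this module's own parameters and persistent buffers from the loader onto their mapped devices.

    kv_b_proj weights missing from the checkpoint are reassembled from the separate
    attn_k_b / attn_v_b tensors.
    """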
prefix = prefix.replace("orig_module.", "")
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
translated_key = key
# TODO: Merge all loader.
# I know this is ugly but lets do it for now.
if isinstance(gguf_loader, SafeTensorLoader):
load_dequantized_tensor = gguf_loader.load_dequantized_tensor
else:
load_dequantized_tensor = gguf_loader.load_gguf_tensor
tensor_file_map = gguf_loader.tensor_file_map
if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key:
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
if use_torch_npu:
device = "cpu" if "embd" in translated_key else CUR_DEVICE
print(f"loading layer {translated_key} to {device}") if torch.distributed.get_rank() == 0 else None
else:
print(f"loading {translated_key} to {device}")
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()
if "kv_b_proj" in translated_key and not gguf_loader.has_tensor(translated_key):
attn_k_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_k_b"), device=device).to(dtype=target_dtype)
attn_k_b = attn_k_b.transpose(1, 2).contiguous()
attn_v_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_v_b"), device=device).to(dtype=target_dtype)
kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)
kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous()
set_param(module, name, kv_b_proj)
del attn_k_b
del attn_v_b
else:
weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
set_param(module, name, weights)
del weights
else:
#print(load_config.tensor_file_map.keys())
raise Exception(f"can't find {translated_key} in GGUF file!")
def sync_all_device(all_device_list):
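    """Synchronize every CUDA/XPU/NPU device in the list, raising for unknown device types."""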
for device in all_device_list:
if "cuda" in device.lower():
torch.cuda.synchronize(device)
elif "xpu" in device.lower():
torch.xpu.synchronize(device)
        elif use_torch_npu:
            torch_npu.npu.synchronize(device)
else:
raise RuntimeError("The device {} is not available".format(device))
torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"}
def xpu_fp16_model(config):
# This function is to check if we run this model on XPU with FP16 dtype
if not torch.xpu.is_available():
return False
if config.architectures[0] == "DeepseekV3ForCausalLM":
return True
if config.architectures[0] == "Qwen3MoeForCausalLM" and config.hidden_size == 4096:
        # Qwen3-30B seems to have precision issues with FP16,
        # so for now we only use FP16 for Qwen3-235B.
return True
return False
def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
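    """Recursively load weights: injected modules load themselves via .load(), plain modules go through load_cur_state_dict."""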
#print(f"recursively loading weights {prefix}")
if not isinstance(module, base_operator.BaseInjectedModule):
load_cur_state_dict(module, gguf_loader, prefix, device=device)
for name, child in module._modules.items():
load_weights(child, gguf_loader, prefix+name+".", device=device)
else:
module.load()
def tf_logits_warper(generation_config):
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
used for multinomial sampling.
"""
# instantiate warpers list
warpers = LogitsProcessorList()
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
# better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
if generation_config.num_beams > 1:
if isinstance(generation_config._eos_token_tensor, list):
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
else:
min_tokens_to_keep = 2
else:
min_tokens_to_keep = 1
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.min_p is not None:
# Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
warpers.append(
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
warpers.append(
EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
        warpers.append(
            # No `device` is defined in this helper, so EtaLogitsWarper falls back to its default device.
            EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep)
        )
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
warpers.append(LogitNormalization())
return warpers
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,
num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None, static_cache = None):
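    """Run chunked prefill over `inputs`, then decode up to `max_new_tokens` tokens.

    Decoded text is streamed to stdout, a CUDA/NPU graph is optionally captured for
    the decode step, and the list of generated token ids is returned.
    """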
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._dynamo.config.suppress_errors = True
batch_size, seq_length = inputs.shape
device_map = model.gguf_loader.tensor_device_map
if use_torch_npu:
vocabulary_size = model.config.vocab_size
topp = torch.tensor([[model.generation_config.top_p]], dtype=torch.float16).npu()
topk = torch.tensor([[model.generation_config.top_k]], dtype=torch.int32).npu()
temperature = torch.tensor([[model.generation_config.temperature]], dtype=torch.float16).npu()
next_token_fake = torch.tensor([[1]], dtype=torch.int32).npu()
next_token_probs = torch.tensor([[1.0]], dtype=torch.float16).npu()
torch_device = CUR_DEVICE
else:
torch_device = get_device('model.layers.0.self_attn', device_map)
torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device
inputs = inputs.to(torch_device)
all_cuda_device = get_all_used_cuda_device(device_map)
tokens = []
def decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
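        """Single decode step on Ascend NPU: embed the current token, run the model
        (optionally through a captured NPU graph) and sample or argmax the next token.
        """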
if cuda_graph_runner is None:
use_cuda_graph = False
inputs_embeds = model.model.embed_tokens(cur_token.to('cpu')).to(torch_device)
if use_cuda_graph:
logits = cuda_graph_runner(inputs_embeds, position_ids, cache_position)
else:
# custom_stream = torch.cuda.Stream()
# torch.cuda.set_device(torch_device)
torch_npu.npu.set_device(torch_device)
# with torch.cuda.stream(custom_stream):
logits=model(inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
        if past_key_values is not None:
            past_key_values.change_seq_length(1)
all_cuda_device = ['npu:' + str(index) for index in range(torch.distributed.get_world_size())]
for device in all_cuda_device:
# torch.cuda.synchronize(device)
torch_npu.npu.synchronize(device)
if generation_config.do_sample:
logits = logits / temperature
torch.manual_seed(0)
probs = logits.view(batch_size, vocabulary_size)
sm = nn.Softmax(dim=-1)
probs = sm(probs).half().npu()
next_token = next_token_fake
torch_npu._npu_topk_topp_sampling(probs, topk, topp, next_token, next_token_probs)
next_token = next_token.squeeze(-1)
else:
next_token_scores = logits_warper(inputs, logits[:, -1, :])
next_token = torch.argmax(next_token_scores, dim=-1)
return next_token
def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
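        """Single decode step: dispatch to the NPU path when available, otherwise run the
        model on CUDA/XPU and sample (or argmax) the next token from the warped logits.
        """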
if use_torch_npu:
return decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)
if cuda_graph_runner is None:
use_cuda_graph = False
if use_cuda_graph:
logits = cuda_graph_runner(cur_token, position_ids, cache_position)
else:
# custom_stream = torch.cuda.Stream()
if torch.cuda.is_available():
torch.cuda.set_device(torch_device)
elif torch.xpu.is_available():
torch.xpu.set_device(torch_device)
            elif use_torch_npu:
                torch_npu.npu.set_device(torch_device)
else:
raise RuntimeError(f"The device: {torch_device} is not available")
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device)
# with torch.cuda.stream(custom_stream):
logits=model(inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
        if past_key_values is not None and isinstance(past_key_values, StaticCache):
            past_key_values.change_seq_length(1)
sync_all_device(all_cuda_device)
#print(logits)
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_token = torch.argmax(next_token_scores, dim=-1)
return next_token
# TODO: use CUDA Graph for chunk prefill, may get small improvement
def chunk_prefill(inputs, cache_position, past_key_values):
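        """Prefill one chunk of the prompt and return the logits of its last position."""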
if mode == "long_context":
inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
else:
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
if use_flashinfer_mla:
MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
MLAWrapperSingleton.need_plan_all()
logits = model(
inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
return logits
def decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof=None):
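        """Decode loop used on NPU: optionally capture an NPU graph after warm-up, then
        generate until EOS / '<|im_end|>' or max_new_tokens, streaming tokens on ranks
        where rank % tensor_parallel_size == 0.
        """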
global warm_uped
global _USE_NPU_GRAPH
if use_cuda_graph:
from ktransformers.util.npu_graph_runner import get_or_create_runner
npu_graph_runner = get_or_create_runner(CUR_DEVICE)
npu_graph_runner.init(batch_size, seq_length)
with torch_npu.npu.stream(npu_graph_runner.main_stream):
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
torch.bfloat16)
                    if use_cuda_graph and ((warm_uped and i == 1) or (not warm_uped and i == 2)):
warm_uped = True
_USE_NPU_GRAPH = True
npu_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
cuda_graph_runner = npu_graph_runner
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids,
cache_position, past_key_values, logits_warper, generation_config,
use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
past_key_values.position[0] += 1
position_ids = cache_position.unsqueeze(0)
if prof is not None:
prof.step()
npu_graph_runner.destroy()
_USE_NPU_GRAPH = False
else:
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1, None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16,
torch.bfloat16)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position,
past_key_values, logits_warper, generation_config, use_cuda_graph).to(
torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(
next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
past_key_values.position[0] += 1
position_ids = cache_position.unsqueeze(0)
if prof is not None:
prof.step()
if prof is not None:
prof.stop()
if torch.cuda.is_available():
torch.cuda.set_device(torch_device)
elif torch.xpu.is_available():
torch.xpu.set_device(torch_device)
    elif use_torch_npu:
        torch_npu.npu.set_device(torch_device)
else:
raise RuntimeError(f"The device: {torch_device} is not available")
with torch.no_grad():
stream = TextStreamer(tokenizer)
if torch.xpu.is_available():
from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache
if model.config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)
else:
past_key_values = DynamicNormalCache.from_legacy_cache(None)
elif use_torch_npu and static_cache:
            assert isinstance(static_cache, StaticCache), '[ERROR] static_cache must be an instance of StaticCache'
past_key_values = static_cache
if past_key_values.max_batch_size < batch_size or past_key_values.max_cache_len < seq_length + max_new_tokens:
                print('[WARN] the provided StaticCache is too small, creating a new StaticCache...')
past_key_values = StaticCache(
config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype
)
else:
past_key_values.reset()
elif mode != 'long_context':
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
)
else:
past_key_values = None
generation_config, model_kwargs = model._prepare_generation_config(
None, do_sample=True
# change this to modify generate config
#top_k=5, top_p=0.85, temperature=0.1
)
logits_warper = tf_logits_warper(generation_config)
cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
if use_torch_npu:
past_key_values.position[0] = seq_length + 1
generated_ids = torch.zeros(
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
start_time = time.time()
logits = None
def prefill_wrapper(prof=None):
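            """Prefill the prompt in chunks of `chunk_size`, keeping the last chunk's logits in the enclosing `logits` variable."""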
nonlocal logits
chunk_start = 0
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_size
if prof is not None:
prof.step()
if prof is not None:
prof.stop()
if logits is None:
raise ValueError('logits cannot be None')
if use_torch_npu:
global WARM_UP_SKIP_CNT
prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0"
if prof_prefill == "1" and WARM_UP_SKIP_CNT[0] <= 0:
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof"),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
prefill_wrapper(prof)
else:
prefill_wrapper()
WARM_UP_SKIP_CNT[0] -= 1
else:
chunk_start = 0
while chunk_start < seq_length:
chunk_end = min(chunk_start + chunk_size, seq_length)
if past_key_values != None:
past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
chunk_start += chunk_size
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_token = torch.argmax(next_token_scores, dim=-1)
first_token_time = time.time() - start_time
if use_flashinfer_mla:
MLAWrapperSingleton.reset_buffer()
prefill_count = seq_length
prefill_time = first_token_time
if use_torch_npu and torch.distributed.get_rank() % get_tensor_parallel_size() == 0:
if force_think:
print("<think>")
print(stream.put(next_token.item()), end="", flush=True)
elif not use_torch_npu:
if force_think:
print("<think>")
print(stream.put(next_token.item()), end="", flush=True)
generated_ids[:, seq_length] = next_token
tokens.append(int(next_token))
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
position_ids = cache_position.unsqueeze(0)
seq_length += 1
if use_torch_npu:
past_key_values.position += 1
cuda_graph_runner = None
start_time = time.time()
if not use_torch_npu:
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
global warm_uped
            # Capture the CUDA graph on the first decode step if already warmed up,
            # otherwise run one eager step first and capture on the second.
            if use_cuda_graph and ((warm_uped and i == 1) or (not warm_uped and i == 2)):
warm_uped = True
cuda_graph_runner = CUDAGraphRunner()
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':
print(stream.end(), end="", flush=True)
break
else:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
position_ids = cache_position.unsqueeze(0)
else:
prof_decode = os.environ["PROF_DECODE"] if "PROF_DECODE" in os.environ else "0"
prof_ranks = os.environ["PROF_RANK"] if "PROF_RANK" in os.environ else "0"
prof_ranks = [int(r.strip()) for r in prof_ranks.split(",")]
if prof_decode == "1" and torch.distributed.get_rank() in prof_ranks and WARM_UP_SKIP_CNT[1] <= 0:
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
)
with torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.CPU,
torch_npu.profiler.ProfilerActivity.NPU
],
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=_MAX_DECODE_PROFILE, repeat=1, skip_first=0),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./decode_prof"),
record_shapes=True,
profile_memory=True,
with_stack=False,
with_flops=False,
with_modules=False,
experimental_config=experimental_config) as prof:
decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof)
else:
decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length)
WARM_UP_SKIP_CNT[1] -= 1
total_time = time.time() - start_time
tokens_generated = len(tokens)
tokens_per_second = tokens_generated / total_time
if not use_torch_npu:
print("")
print(f"prompt eval count: {prefill_count} token(s)")
print(f"prompt eval duration: {prefill_time}s")
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
print(f"eval count: {tokens_generated} token(s)")
print(f"eval duration: {total_time}s")
print(f"eval rate: {tokens_per_second} tokens/s")
else:
tp_size = get_tensor_parallel_size()
if torch.distributed.get_rank() % tp_size == 0:
rank = f"[rank:{torch.distributed.get_rank()}]"
msg = f"\n{rank} Eval Time\n"
msg += rank + f"prompt eval count: {prefill_count} token(s)\n"
msg += rank + f"prompt eval duration: {prefill_time:.9f}s\n"
msg += rank + f"prompt eval rate: {prefill_count/prefill_time:.9f} tokens/s\n"
msg += rank + f"eval count: {tokens_generated} token(s)\n"
msg += rank + f"eval duration: {total_time:.9f}s\n"
msg += rank + f"eval rate: {tokens_per_second:.9f} tokens/s\n"
print(msg)
return tokens
class InferenceState(enum.Enum):
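    """Inference state of a module's weights: unloaded, set up for prefill, set up for generation, or being restored."""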
UNLOAD = 0
PREFILL = 1
GENERATE = 2
RESTORE = 3