import logging
import os
import sys
from distutils.version import LooseVersion

import torch
import torch.nn as nn

from thop.vision.basic_hooks import *
from thop.rnn_hooks import *
from thop.utils import prGreen, prRed, prYellow

# Make the repository root importable before pulling in ktransformers.
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, project_dir)

from ktransformers.util.utils import prefill_and_generate

# logger = logging.getLogger(__name__)
# logger.setLevel(logging.INFO)

if LooseVersion(torch.__version__) < LooseVersion("1.0.0"):
    logging.warning(
        "You are using an old version of PyTorch {version}, which THOP does NOT support.".format(
            version=torch.__version__
        )
    )

default_dtype = torch.float64

register_hooks = {
    nn.ZeroPad2d: zero_ops,  # padding does not involve any multiplication.
    nn.Conv1d: count_convNd,
    nn.Conv2d: count_convNd,
    nn.Conv3d: count_convNd,
    nn.ConvTranspose1d: count_convNd,
    nn.ConvTranspose2d: count_convNd,
    nn.ConvTranspose3d: count_convNd,
    nn.BatchNorm1d: count_normalization,
    nn.BatchNorm2d: count_normalization,
    nn.BatchNorm3d: count_normalization,
    nn.LayerNorm: count_normalization,
    nn.InstanceNorm1d: count_normalization,
    nn.InstanceNorm2d: count_normalization,
    nn.InstanceNorm3d: count_normalization,
    nn.PReLU: count_prelu,
    nn.Softmax: count_softmax,
    nn.ReLU: zero_ops,
    nn.ReLU6: zero_ops,
    nn.LeakyReLU: count_relu,
    nn.MaxPool1d: zero_ops,
    nn.MaxPool2d: zero_ops,
    nn.MaxPool3d: zero_ops,
    nn.AdaptiveMaxPool1d: zero_ops,
    nn.AdaptiveMaxPool2d: zero_ops,
    nn.AdaptiveMaxPool3d: zero_ops,
    nn.AvgPool1d: count_avgpool,
    nn.AvgPool2d: count_avgpool,
    nn.AvgPool3d: count_avgpool,
    nn.AdaptiveAvgPool1d: count_adap_avgpool,
    nn.AdaptiveAvgPool2d: count_adap_avgpool,
    nn.AdaptiveAvgPool3d: count_adap_avgpool,
    nn.Linear: count_linear,
    nn.Dropout: zero_ops,
    nn.Upsample: count_upsample,
    nn.UpsamplingBilinear2d: count_upsample,
    nn.UpsamplingNearest2d: count_upsample,
    nn.RNNCell: count_rnn_cell,
    nn.GRUCell: count_gru_cell,
    nn.LSTMCell: count_lstm_cell,
    nn.RNN: count_rnn,
    nn.GRU: count_gru,
    nn.LSTM: count_lstm,
    nn.Sequential: zero_ops,
    nn.PixelShuffle: zero_ops,
}

if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"):
    register_hooks.update({nn.SyncBatchNorm: count_normalization})
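
# Illustrative sketch, not part of the original hook table: module types that
# have no entry in `register_hooks` can still be counted by passing a
# `custom_ops` mapping to the profilers below. A counting rule follows the
# thop forward-hook convention: it receives the module, its inputs, and its
# output, and accumulates into `module.total_ops`. The per-element cost
# assumed here is a placeholder, not a vetted count.
def count_gelu_example(m, x, y):
    # Assumption: roughly one fused op per output element.
    m.total_ops += torch.DoubleTensor([int(y.numel())])

# e.g. profile_origin(model, (x,), custom_ops={nn.GELU: count_gelu_example})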

def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
    handler_collection = []
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        verbose = True

    def add_hooks(m):
        if len(list(m.children())) > 0:
            return

        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logging.warning(
                "Either .total_ops or .total_params is already defined in %s. "
                "Be careful, it might change your code's behavior." % str(m)
            )

        m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype))
        m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype))

        for p in m.parameters():
            m.total_params += torch.DoubleTensor([p.numel()])

        m_type = type(m)

        fn = None
        if m_type in custom_ops:
            # If a type appears in both op maps, custom_ops takes precedence.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
        else:
            if m_type not in types_collection and report_missing:
                prRed(
                    "[WARN] Cannot find rule for %s. Treat it as zero MACs and zero Params."
                    % m_type
                )

        if fn is not None:
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)
        types_collection.add(m_type)

    training = model.training

    model.eval()
    model.apply(add_hooks)

    with torch.no_grad():
        model(*inputs)

    total_ops = 0
    total_params = 0
    for m in model.modules():
        if len(list(m.children())) > 0:  # skip non-leaf modules
            continue
        total_ops += m.total_ops
        total_params += m.total_params

    total_ops = total_ops.item()
    total_params = total_params.item()

    # Reset the model to its original training status.
    model.train(training)
    for handler in handler_collection:
        handler.remove()

    # Remove the temporary buffers.
    for n, m in model.named_modules():
        if len(list(m.children())) > 0:
            continue
        if "total_ops" in m._buffers:
            m._buffers.pop("total_ops")
        if "total_params" in m._buffers:
            m._buffers.pop("total_params")

    return total_ops, total_params
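
# Usage sketch (the model and input shape are placeholders, not part of this file):
#
#     import torchvision
#     net = torchvision.models.resnet18()
#     macs, params = profile_origin(net, (torch.randn(1, 3, 224, 224),))
#     prGreen("MACs: %.3f G, Params: %.3f M" % (macs / 1e9, params / 1e6))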

def custom_profile(
    model: nn.Module,
    inputs,
    content,
    tokenizer,
    custom_ops=None,
    verbose=True,
    ret_layer_info=False,
    report_missing=False,
):
    handler_collection = {}
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        # `report_missing` implies `verbose`, so overwrite it here.
        verbose = True

    def add_hooks(m: nn.Module):
        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))

        # for p in m.parameters():
        #     m.total_params += torch.DoubleTensor([p.numel()])

        m_type = type(m)

        fn = None
        if m_type in custom_ops:
            # If a type appears in both op maps, custom_ops takes precedence.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
        else:
            if m_type not in types_collection and report_missing:
                prRed(
                    "[WARN] Cannot find rule for %s. Treat it as zero MACs and zero Params."
                    % m_type
                )

        if fn is not None:
            handler_collection[m] = (
                m.register_forward_hook(fn),
                m.register_forward_hook(count_parameters),
            )
        types_collection.add(m_type)

    prev_training_status = model.training

    model.eval()
    model.apply(add_hooks)

    messages = [{"role": "user", "content": content}]
    input_tensor = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )

    with torch.no_grad():
        # model(*inputs)
        # TODO: model.model is a temporary workaround for PeftModelForCausalLM.
        prefill_and_generate(
            model.model,
            tokenizer,
            input_tensor.cuda(),
            max_new_tokens=1000,
            use_cuda_graph=False,
            mode="normal",
            force_think=False,
            chunk_prefill_size=8192,
        )

    def dfs_count(module: nn.Module, prefix="\t"):
        total_ops, total_params = module.total_ops.item(), 0
        ret_dict = {}
        for n, m in module.named_children():
            # if not hasattr(m, "total_ops") and not hasattr(m, "total_params"):  # and len(list(m.children())) > 0:
            #     m_ops, m_params = dfs_count(m, prefix=prefix + "\t")
            # else:
            #     m_ops, m_params = m.total_ops, m.total_params
            next_dict = {}
            if m in handler_collection and not isinstance(
                m, (nn.Sequential, nn.ModuleList)
            ):
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
            ret_dict[n] = (m_ops, m_params, next_dict)
            total_ops += m_ops
            total_params += m_params
        # print(prefix, module._get_name(), (total_ops, total_params))
        return total_ops, total_params, ret_dict

    total_ops, total_params, ret_dict = dfs_count(model)

    # Reset the model to its original training status.
    model.train(prev_training_status)
    for m, (op_handler, params_handler) in handler_collection.items():
        op_handler.remove()
        params_handler.remove()
        m._buffers.pop("total_ops")
        m._buffers.pop("total_params")

    if ret_layer_info:
        return total_ops, total_params, ret_dict
    return total_ops, total_params
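
# Usage sketch (assumes transformers' AutoTokenizer; the model path is a
# placeholder): MACs are accumulated over a full chat-driven prefill+generate
# pass rather than a single dummy forward. `inputs` is unused by this variant,
# so an empty tuple is passed.
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("<model_path>")
#     macs, params = custom_profile(
#         model, (), "Hello, how are you?", tokenizer, report_missing=True
#     )
#     prGreen("Total MACs: %.3f G" % (macs / 1e9))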