kvcache-ai-ktransformers/KT-SFT/ktransformers/sft/flops_utils/custom_profile.py
2025-11-04 04:24:30 +00:00

263 lines
8.6 KiB
Python

from distutils.version import LooseVersion
from thop.vision.basic_hooks import *
from thop.rnn_hooks import *
from thop.utils import prGreen, prRed, prYellow
import sys, os
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, project_dir)
from ktransformers.util.utils import prefill_and_generate
# logger = logging.getLogger(__name__)
# logger.setLevel(logging.INFO)
# THOP requires PyTorch >= 1.0.0; warn (but keep going) on anything older.
_torch_version = LooseVersion(torch.__version__)
if _torch_version < LooseVersion("1.0.0"):
    logging.warning(
        "You are using an old version PyTorch {version}, which THOP does NOT support.".format(
            version=torch.__version__
        )
    )

# Counters accumulate in float64 so huge op counts do not lose precision.
default_dtype = torch.float64
# Map each supported nn.Module type to the forward hook that counts its MACs.
# Grouped by hook so each family of layers is registered in one place; the
# dict is only ever used for membership tests and lookups, never iterated in
# order.
_zero_op_layers = (
    nn.ZeroPad2d,  # padding does not involve any multiplication.
    nn.ReLU,
    nn.ReLU6,
    nn.MaxPool1d,
    nn.MaxPool2d,
    nn.MaxPool3d,
    nn.AdaptiveMaxPool1d,
    nn.AdaptiveMaxPool2d,
    nn.AdaptiveMaxPool3d,
    nn.Dropout,
    nn.Sequential,
    nn.PixelShuffle,
)
_conv_layers = (
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.ConvTranspose1d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
)
_norm_layers = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
)
_avg_pool_layers = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)
_adaptive_avg_pool_layers = (
    nn.AdaptiveAvgPool1d,
    nn.AdaptiveAvgPool2d,
    nn.AdaptiveAvgPool3d,
)
_upsample_layers = (nn.Upsample, nn.UpsamplingBilinear2d, nn.UpsamplingNearest2d)

register_hooks = {}
for _layer in _zero_op_layers:
    register_hooks[_layer] = zero_ops
for _layer in _conv_layers:
    register_hooks[_layer] = count_convNd
for _layer in _norm_layers:
    register_hooks[_layer] = count_normalization
for _layer in _avg_pool_layers:
    register_hooks[_layer] = count_avgpool
for _layer in _adaptive_avg_pool_layers:
    register_hooks[_layer] = count_adap_avgpool
for _layer in _upsample_layers:
    register_hooks[_layer] = count_upsample

# Layers with a dedicated counting rule each.
register_hooks[nn.PReLU] = count_prelu
register_hooks[nn.Softmax] = count_softmax
register_hooks[nn.LeakyReLU] = count_relu
register_hooks[nn.Linear] = count_linear
register_hooks[nn.RNNCell] = count_rnn_cell
register_hooks[nn.GRUCell] = count_gru_cell
register_hooks[nn.LSTMCell] = count_lstm_cell
register_hooks[nn.RNN] = count_rnn
register_hooks[nn.GRU] = count_gru
register_hooks[nn.LSTM] = count_lstm

# SyncBatchNorm only exists from PyTorch 1.1.0 onwards.
if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"):
    register_hooks[nn.SyncBatchNorm] = count_normalization
def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
    """Count total MACs and parameters of ``model`` with one forward pass.

    Temporary ``total_ops`` / ``total_params`` buffers are attached to every
    leaf module, forward hooks accumulate into them during ``model(*inputs)``,
    and everything is cleaned up before returning.

    Args:
        model: the ``nn.Module`` to profile (restored to its original
            train/eval state afterwards).
        inputs: tuple of positional args passed straight to ``model``.
        custom_ops: optional ``{module_type: hook}`` map that overrides the
            built-in ``register_hooks`` rules.
        verbose: print which counting rule each module type uses.
        report_missing: warn about leaf types with no counting rule
            (implies ``verbose``).

    Returns:
        ``(total_ops, total_params)`` as Python floats.
    """
    hook_handles = []
    seen_types = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        # report_missing only makes sense with per-type messages enabled.
        verbose = True

    def _attach(module):
        # Only leaf modules get counters; parents are sums of their children.
        if list(module.children()):
            return
        if hasattr(module, "total_ops") or hasattr(module, "total_params"):
            logging.warning(
                "Either .total_ops or .total_params is already defined in %s. "
                "Be careful, it might change your code's behavior." % str(module)
            )
        module.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype))
        module.register_buffer("total_params", torch.zeros(1, dtype=default_dtype))
        for param in module.parameters():
            module.total_params += torch.DoubleTensor([param.numel()])

        mod_type = type(module)
        counter = None
        if mod_type in custom_ops:
            # custom_ops wins when a type is present in both maps.
            counter = custom_ops[mod_type]
            if mod_type not in seen_types and verbose:
                print("[INFO] Customize rule %s() %s." % (counter.__qualname__, mod_type))
        elif mod_type in register_hooks:
            counter = register_hooks[mod_type]
            if mod_type not in seen_types and verbose:
                print("[INFO] Register %s() for %s." % (counter.__qualname__, mod_type))
        elif mod_type not in seen_types and report_missing:
            prRed(
                "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params."
                % mod_type
            )
        if counter is not None:
            hook_handles.append(module.register_forward_hook(counter))
        seen_types.add(mod_type)

    was_training = model.training
    model.eval()
    model.apply(_attach)

    with torch.no_grad():
        model(*inputs)

    # Aggregate over leaf modules only; non-leaf totals would double-count.
    total_ops = sum(m.total_ops for m in model.modules() if not list(m.children()))
    total_params = sum(m.total_params for m in model.modules() if not list(m.children()))
    total_ops = total_ops.item()
    total_params = total_params.item()

    # Restore the model: original train/eval mode, no hooks, no temp buffers.
    model.train(was_training)
    for handle in hook_handles:
        handle.remove()
    for _, module in model.named_modules():
        if list(module.children()):
            continue
        module._buffers.pop("total_ops", None)
        module._buffers.pop("total_params", None)

    return total_ops, total_params
def custom_profile(
    model: nn.Module,
    inputs,
    content,
    tokenizer,
    custom_ops=None,
    verbose=True,
    ret_layer_info=False,
    report_missing=False,
):
    """Profile MACs/params of ``model`` while it generates a chat reply.

    ``content`` is wrapped into a single-user-turn chat template via
    ``tokenizer``, the prompt runs through ktransformers'
    ``prefill_and_generate`` under ``torch.no_grad()``, and forward hooks
    accumulate per-module op/param counts, which are then summed recursively.

    Args:
        model: module to profile; ``model.model`` is what actually runs
            (unwraps the PEFT wrapper — TODO confirm against callers).
        inputs: unused here (kept for signature compatibility with
            ``profile_origin``-style callers).
        content: user message text to generate from.
        tokenizer: must support ``apply_chat_template``.
        custom_ops: optional ``{module_type: hook}`` overrides.
        verbose: print which counting rule each module type uses.
        ret_layer_info: also return the per-layer ``{name: (ops, params, children)}``
            breakdown.
        report_missing: warn about module types with no counting rule
            (implies ``verbose``).

    Returns:
        ``(total_ops, total_params)`` or, with ``ret_layer_info``,
        ``(total_ops, total_params, layer_dict)``.
    """
    handler_collection = {}
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        # overwrite `verbose` option when enable report_missing
        verbose = True

    def add_hooks(m: nn.Module):
        # float64 buffers so very large op counts do not lose precision.
        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))

        m_type = type(m)
        fn = None
        if m_type in custom_ops:
            # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
        else:
            if m_type not in types_collection and report_missing:
                prRed(
                    "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params."
                    % m_type
                )
        if fn is not None:
            handler_collection[m] = (
                m.register_forward_hook(fn),
                m.register_forward_hook(count_parameters),
            )
        types_collection.add(m_type)

    prev_training_status = model.training
    model.eval()
    model.apply(add_hooks)

    messages = [{"role": "user", "content": content}]
    input_tensor = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    with torch.no_grad():
        # BUG FIX: the original called `simple_prefill_and_generate_for_test`,
        # which is neither defined nor imported anywhere in this file and
        # raised NameError. `prefill_and_generate` (imported at the top of
        # this file and otherwise unused) is the intended entry point —
        # TODO confirm its keyword signature matches.
        # TODO: model.model to deal with the PeftModelForCausalLM temp
        prefill_and_generate(
            model.model,
            tokenizer,
            input_tensor.cuda(),
            max_new_tokens=1000,
            use_cuda_graph=False,
            mode='normal',
            force_think=False,
            chunk_prefill_size=8192,
        )

    def dfs_count(module: nn.Module, prefix="\t") -> tuple:
        # Recursively sum counts: hooked leaves report their own buffers,
        # containers are the sum of their children. Returns a 3-tuple
        # (ops, params, {child_name: (ops, params, sub_dict)}).
        total_ops, total_params = module.total_ops.item(), 0
        ret_dict = {}
        for n, m in module.named_children():
            next_dict = {}
            if m in handler_collection and not isinstance(
                m, (nn.Sequential, nn.ModuleList)
            ):
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
            ret_dict[n] = (m_ops, m_params, next_dict)
            total_ops += m_ops
            total_params += m_params
        return total_ops, total_params, ret_dict

    total_ops, total_params, ret_dict = dfs_count(model)

    # reset model to original status
    model.train(prev_training_status)
    for m, (op_handler, params_handler) in handler_collection.items():
        op_handler.remove()
        params_handler.remove()
        m._buffers.pop("total_ops")
        m._buffers.pop("total_params")

    if ret_layer_info:
        return total_ops, total_params, ret_dict
    return total_ops, total_params