kvcache-ai-ktransformers/ktransformers/util/ascend/ascend_utils.py
2025-07-22 10:58:25 +00:00

210 lines
No EOL
7.9 KiB
Python

import os
from datetime import timedelta
import torch
# torch_npu is optional at import time so this module can still be imported on
# hosts without the Ascend stack; NPU-specific functions will fail later when
# they actually reference torch_npu.
try:
    import torch_npu
except Exception as exc:  # not only ImportError: the package may fail while loading
    import warnings

    # The original code built a Warning object and discarded it, so nothing was
    # ever reported. Emit a real warning instead.
    warnings.warn(
        f"torch_npu not found, please install torch_npu for NPU support. ({exc!r})"
    )
import torch.distributed as dist
# Module-level parallel state, filled in by setup_model_parallel().
#
# The two sizes start at 0 rather than None, so the `is not None` assertions in
# get_data_parallel_size() / get_tensor_parallel_size() pass even before setup
# and those getters simply return 0.
_DATA_PARALLEL_SIZE = 0
_TENSOR_PARALLEL_SIZE = 0
# Process group / rank list this process belongs to (None until setup).
_DATA_PARALLEL_GROUP = None
_TENSOR_PARALLEL_RANKS = None
_TENSOR_PARALLEL_GROUP = None
# NOTE(review): never assigned in the visible part of this module, so
# get_data_parallel_gloo() asserts unless it is set elsewhere.
_DATA_PARALLEL_GROUP_GLOO = None
_DATA_PARALLEL_RANKS = None
# NOTE(review): never assigned in the visible part of this module, so
# get_global_group() asserts unless it is set elsewhere.
_GLOBAL_GROUP = None
# NOTE(review): neither assigned nor read in the visible part of this module.
_LM_HEAD_GROUP = None
def setup_model_parallel(distributed_timeout_minutes: int = 30, tp: int = 1):
    """Initialise the HCCL process group and build the DP/TP sub-groups.

    Reads ``LOCAL_RANK`` (default 0) and ``WORLD_SIZE`` (default 1) from the
    environment, forces ``MASTER_ADDR`` to ``localhost``, binds the process to
    NPU ``local_rank``, initialises the default process group with the
    ``hccl`` backend and then creates:

    * ``tp`` data-parallel groups, each holding the ranks
      ``range(i, world_size, tp)``;
    * ``world_size // tp`` tensor-parallel groups, each holding ``tp``
      consecutive ranks.

    The sizes, plus the group and rank list containing the calling rank, are
    stored in this module's globals for the ``get_*`` accessors.

    Args:
        distributed_timeout_minutes: Timeout, in minutes, applied to the
            default process group and to every sub-group.
        tp: Tensor-parallel size. Must be positive and divide ``WORLD_SIZE``.

    Returns:
        ``(local_rank, world_size)`` as read from the environment.

    Raises:
        ValueError: If ``tp`` is not a positive divisor of ``WORLD_SIZE``.
    """
    global _DATA_PARALLEL_SIZE
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_RANKS
    global _TENSOR_PARALLEL_SIZE
    global _TENSOR_PARALLEL_RANKS
    global _TENSOR_PARALLEL_GROUP

    os.environ["MASTER_ADDR"] = "localhost"
    local_rank = int(os.getenv("LOCAL_RANK", '0'))
    world_size = int(os.getenv("WORLD_SIZE", '1'))

    tp_size = tp
    # Fail early on a bad layout: otherwise some ranks end up in no
    # tensor-parallel group (or tp == 0 raises ZeroDivisionError below).
    if tp_size < 1 or world_size % tp_size != 0:
        raise ValueError(
            f"tp must be a positive divisor of WORLD_SIZE, got tp={tp_size}, world_size={world_size}"
        )

    torch_npu.npu.set_device(local_rank)
    dp_size = world_size // tp_size
    _DATA_PARALLEL_SIZE = dp_size
    _TENSOR_PARALLEL_SIZE = tp_size

    torch.set_num_threads(8)
    timeout = timedelta(minutes=distributed_timeout_minutes)

    print(f"start to init process group ------rank is {local_rank}, world_size is {world_size}")
    # LOCAL_RANK is used as the global rank; MASTER_ADDR is pinned to
    # localhost above, so this setup targets a single node.
    torch.distributed.init_process_group(
        backend='hccl',
        world_size=world_size, rank=local_rank,
        # Previously the timeout was only applied to the sub-groups.
        timeout=timeout,
    )
    print(f"init process group success ------rank is {local_rank}, world_size is {world_size}")

    rank = torch.distributed.get_rank()
    # Always empty here, so get_nccl_options() returns None for both groups.
    nccl_comm_cfgs = {}

    # Every rank creates every group (new_group must be entered by all
    # processes), but only remembers the one it is a member of.
    for dp_group_id in range(tp_size):
        ranks = list(range(dp_group_id, world_size, tp_size))
        dp_group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _DATA_PARALLEL_GROUP = dp_group
            _DATA_PARALLEL_RANKS = ranks

    for tp_group_id in range(dp_size):
        start_rank = tp_group_id * tp_size
        end_rank = (tp_group_id + 1) * tp_size
        ranks = list(range(start_rank, end_rank))
        tp_group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _TENSOR_PARALLEL_GROUP = tp_group
            _TENSOR_PARALLEL_RANKS = ranks

    torch.manual_seed(1)
    return local_rank, world_size
def get_tensor_parallel_size():
    """Return the tensor-parallel size recorded by setup_model_parallel().

    The module initialises this value to 0, so before setup this returns 0
    rather than tripping the assertion.
    """
    size = _TENSOR_PARALLEL_SIZE
    assert size is not None, "tensor parallel size is not set"
    return size
def get_tensor_parallel_group():
    """Return the tensor-parallel process group this rank belongs to.

    Asserts if the group has not been created yet.
    """
    group = _TENSOR_PARALLEL_GROUP
    assert group is not None, "tensor parallel group is not initialized"
    return group
def get_tensor_parallel_ranks():
    """Return the list of ranks in this rank's tensor-parallel group.

    Asserts if the list has not been recorded yet.
    """
    ranks = _TENSOR_PARALLEL_RANKS
    assert ranks is not None, "tensor parallel ranks is not initialized"
    return ranks
def get_data_parallel_size():
    """Return the data-parallel size recorded by setup_model_parallel().

    The module initialises this value to 0, so before setup this returns 0
    rather than tripping the assertion.
    """
    size = _DATA_PARALLEL_SIZE
    assert size is not None, "data parallel size is not initialized"
    return size
def get_data_parallel_gloo():
    """Return the value stored in ``_DATA_PARALLEL_GROUP_GLOO``.

    Asserts while that global is still None.
    NOTE(review): nothing in the visible part of this module assigns it.
    """
    group = _DATA_PARALLEL_GROUP_GLOO
    assert group is not None, "data parallel gloo group is not initialized"
    return group
def get_data_parallel_group():
    """Return the data-parallel process group this rank belongs to.

    Asserts if the group has not been created yet.
    """
    group = _DATA_PARALLEL_GROUP
    assert group is not None, "data parallel group is not initialized"
    return group
def get_data_parallel_ranks():
    """Return the list of ranks in this rank's data-parallel group.

    Asserts if the list has not been recorded yet.
    """
    ranks = _DATA_PARALLEL_RANKS
    assert ranks is not None, "data parallel ranks is not initialized"
    return ranks
def get_global_group():
    """Return the value stored in ``_GLOBAL_GROUP``.

    Asserts while that global is still None.
    NOTE(review): nothing in the visible part of this module assigns it.
    """
    group = _GLOBAL_GROUP
    assert group is not None, "global group is not initialized"
    return group
def get_nccl_options(pg_name, nccl_comm_cfgs):
    """Build ``ProcessGroupNCCL.Options`` for a named process group.

    Args:
        pg_name: Key identifying the process group (e.g. ``'dp'`` or ``'tp'``).
        nccl_comm_cfgs: Mapping from group name to a dict that may contain
            ``cga_cluster_size``, ``max_ctas`` and ``min_ctas``.

    Returns:
        An options object populated from the group's entry (defaults 4, 32
        and 1 respectively), or ``None`` when ``pg_name`` has no entry.
    """
    if pg_name not in nccl_comm_cfgs:
        return None

    options = torch.distributed.ProcessGroupNCCL.Options()
    group_cfg = nccl_comm_cfgs[pg_name]
    options.config.cga_cluster_size = group_cfg.get('cga_cluster_size', 4)
    options.config.max_ctas = group_cfg.get('max_ctas', 32)
    options.config.min_ctas = group_cfg.get('min_ctas', 1)
    return options
def get_safetensors_cut_weight(name: str, weights: torch.Tensor):
    """Return this tensor-parallel rank's shard of a named weight tensor.

    The tensor is matched by substring against three name lists:

    * ``ffn_down`` / ``attn_output``: sliced along the second axis into
      ``tp`` equal chunks sized from the last dimension (1-D tensors are
      returned whole);
    * ``ffn_gate`` / ``ffn_up`` / ``attn_q_b`` and ``output``: sliced along
      the first axis into ``tp`` equal chunks.

    Anything else, a tensor of shape ``[1]``, or a tensor-parallel size of 1
    is returned unchanged.

    Args:
        name: Tensor name used for the substring match.
        weights: The full (unsharded) tensor.

    Returns:
        ``weights`` itself or a slice of it for rank ``global_rank % tp``.

    Raises:
        AssertionError: If the sharded dimension is not divisible by ``tp``.
    """
    translate_col_cut_tensors = ["ffn_down", "attn_output"]
    translate_row_cut_tensors = ["ffn_gate", "ffn_up", "attn_q_b"]
    translate_lm_cut_tensor = ["output"]

    tp = get_tensor_parallel_size()
    if tp == 1 or weights.shape == torch.Size([1]):
        return weights

    # Position of this process inside its tensor-parallel group.
    rank = torch.distributed.get_rank() % tp
    assert 0 <= rank < tp and tp > 0, f"unexpected {rank=}, {tp=}"

    # The column list must be tested first: "attn_output" also contains the
    # lm-head pattern "output".
    if any(t in name for t in translate_col_cut_tensors):
        if weights.dim() == 1:
            return weights
        dim = weights.shape[-1]
        assert dim % tp == 0, f"unexpected division {dim=}, {tp=}"
        chunk_size = dim // tp
        # NOTE(review): chunk size comes from the last axis but the slice is
        # on axis 1; these coincide only for 2-D tensors — confirm inputs.
        return weights[:, rank * chunk_size: (rank + 1) * chunk_size]

    # Row-cut and lm-head tensors are sharded identically along axis 0. The
    # old ``tp > 1`` guard on the lm-head branch was redundant here, and its
    # assertion message reported world_size instead of tp.
    if any(t in name for t in translate_row_cut_tensors + translate_lm_cut_tensor):
        dim = weights.shape[0]
        assert dim % tp == 0, f"unexpected division {dim=}, {tp=}"
        chunk_size = dim // tp
        return weights[rank * chunk_size: (rank + 1) * chunk_size]

    return weights
def get_absort_weight(model, config):
    """Attach per-rank ``q_absorb`` / ``out_absorb`` tensors to every attention layer.

    For each of the ``config.num_hidden_layers`` layers in
    ``model.model.layers`` whose ``self_attn`` does not already have both
    attributes, ``kv_b_proj.weight`` is viewed as
    ``(num_attention_heads, -1, kv_lora_rank)`` and split along the middle
    axis at ``qk_nope_head_dim``. The head slice belonging to rank
    ``global_rank % tp`` is kept, ``out_absorb`` is transposed on its last two
    axes, and both are stored on the attention module.

    Args:
        model: Object exposing ``model.layers[i].self_attn``.
        config: Object exposing ``num_attention_heads`` and ``num_hidden_layers``.

    Returns:
        None; the attention modules are modified in place.

    NOTE(review): the name is presumably a typo for "absorb"; it is kept
    because callers use it.
    """
    local_rank = torch.distributed.get_rank()
    tp = get_tensor_parallel_size()
    # Index of this process within its tensor-parallel group.
    local_rank %= tp
    # Number of attention heads owned by each tensor-parallel rank.
    tp_heads = config.num_attention_heads // tp
    for i in range(config.num_hidden_layers):
        # `self` is just a local alias for the layer's attention module.
        self = model.model.layers[i].self_attn
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
            kv_b_proj = self.kv_b_proj.weight.view(config.num_attention_heads, -1, self.kv_lora_rank)
            # First qk_nope_head_dim rows per head -> q_absorb, remainder -> out_absorb.
            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].clone()
            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].clone()
            # Keep only the heads assigned to this rank.
            q_absorb = q_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
            out_absorb = out_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()
            out_absorb = out_absorb.transpose(1, 2).contiguous()
            setattr(self, "q_absorb", q_absorb)
            setattr(self, "out_absorb", out_absorb)
            # NOTE(review): the source's indentation was lost; this statement is
            # assumed to belong inside the `if` so it runs once per converted
            # layer — confirm against the upstream file.
            del self.orig_module.kv_b_proj
def allreduce_wrapper(func):
    """Wrap ``func`` so its output is sum-all-reduced over the tensor-parallel group.

    If ``func`` returns a tuple, only element 0 is reduced and a new tuple
    ``(reduced,) + rest`` is returned; otherwise the returned value itself is
    reduced. When the tensor-parallel size is not greater than 1 the output is
    passed through without any communication.

    bfloat16 tensors are converted to float16 for the collective and converted
    back afterwards (presumably because the backend in use does not reduce
    bfloat16 — TODO confirm).

    Args:
        func: Callable returning a tensor, or a tuple whose first element is
            a tensor.

    Returns:
        The wrapping callable.
    """
    import functools

    def _reduce_over_tp(tensor):
        """Sum ``tensor`` across the tensor-parallel group, preserving its dtype."""
        org_dtype = tensor.dtype
        if org_dtype == torch.bfloat16:
            # all_reduce writes into the tensor it is given, so the float16
            # copy must be kept and converted back. The previous tuple path
            # reduced a temporary copy and returned the un-reduced original.
            tensor = tensor.to(dtype=torch.float16)
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())
        if org_dtype == torch.bfloat16:
            tensor = tensor.to(dtype=org_dtype)
        return tensor

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        orig_output = func(*args, **kwargs)
        if isinstance(orig_output, tuple):
            first = orig_output[0]
            if get_tensor_parallel_size() > 1:
                first = _reduce_over_tp(first)
            return (first,) + orig_output[1:]
        if get_tensor_parallel_size() > 1:
            orig_output = _reduce_over_tp(orig_output)
        return orig_output

    return wrapper