mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 12:09:48 +00:00
206 lines
No EOL
5.8 KiB
Python
206 lines
No EOL
5.8 KiB
Python
|
|
# Adapted from
|
|
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
|
|
# which was originally adapted from
|
|
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
|
|
|
|
"""
|
|
Memory-efficient attention for prefill.
|
|
It supporst page size = 1.
|
|
"""
|
|
|
|
# Adapted from
|
|
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
is_cuda_available = torch.cuda.is_available()
|
|
if is_cuda_available:
|
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
|
|
|
|
|
@triton.jit
|
|
def _fwd_kernel(
|
|
Q,
|
|
K,
|
|
V,
|
|
sm_scale,
|
|
B_Start_Loc,
|
|
B_Seqlen,
|
|
Out,
|
|
stride_qbs,
|
|
stride_qh,
|
|
stride_kbs,
|
|
stride_kh,
|
|
stride_vbs,
|
|
stride_vh,
|
|
stride_obs,
|
|
stride_oh,
|
|
kv_group_num: tl.constexpr,
|
|
BLOCK_M: tl.constexpr,
|
|
BLOCK_DMODEL: tl.constexpr,
|
|
BLOCK_N: tl.constexpr,
|
|
IS_CAUSAL: tl.constexpr,
|
|
Lk: tl.constexpr,
|
|
):
|
|
cur_batch = tl.program_id(0)
|
|
cur_head = tl.program_id(1)
|
|
start_m = tl.program_id(2)
|
|
|
|
cur_kv_head = cur_head // kv_group_num
|
|
|
|
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
|
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
|
|
|
block_start_loc = BLOCK_M * start_m
|
|
|
|
# initialize offsets
|
|
offs_n = tl.arange(0, BLOCK_N)
|
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
|
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
off_q = (
|
|
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
|
|
+ cur_head * stride_qh
|
|
+ offs_d[None, :]
|
|
)
|
|
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
|
|
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
|
|
|
|
mask_d = offs_d < Lk
|
|
|
|
q = tl.load(
|
|
Q + off_q,
|
|
mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
|
|
other=0.0,
|
|
)
|
|
|
|
k_ptrs = K + off_k
|
|
v_ptrs = V + off_v
|
|
|
|
# initialize pointer to m and l
|
|
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
|
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
|
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
|
|
|
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
|
|
|
end_n = (
|
|
cur_batch_seq_len
|
|
if not IS_CAUSAL
|
|
else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
|
|
)
|
|
for start_n in range(0, block_mask * end_n, BLOCK_N):
|
|
start_n = tl.multiple_of(start_n, BLOCK_N)
|
|
# -- compute qk ----
|
|
k = tl.load(
|
|
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
|
mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
|
|
other=0.0,
|
|
)
|
|
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
|
|
|
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
|
qk += tl.dot(q, k)
|
|
qk *= sm_scale
|
|
|
|
if IS_CAUSAL:
|
|
qk += tl.where(
|
|
(start_n + offs_n[None, :] < cur_batch_seq_len)
|
|
& (offs_m[:, None] >= (start_n + offs_n[None, :])),
|
|
0,
|
|
float("-inf"),
|
|
)
|
|
else:
|
|
qk += tl.where(
|
|
(start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
|
|
)
|
|
|
|
# -- compute m_ij, p, l_ij
|
|
m_ij = tl.max(qk, 1)
|
|
p = tl.exp(qk - m_ij[:, None])
|
|
l_ij = tl.sum(p, 1)
|
|
# -- update m_i and l_i
|
|
m_i_new = tl.maximum(m_i, m_ij)
|
|
alpha = tl.exp(m_i - m_i_new)
|
|
beta = tl.exp(m_ij - m_i_new)
|
|
l_i_new = alpha * l_i + beta * l_ij
|
|
# -- update output accumulator --
|
|
# scale p
|
|
p_scale = beta / l_i_new
|
|
p = p * p_scale[:, None]
|
|
# scale acc
|
|
acc_scale = l_i / l_i_new * alpha
|
|
acc = acc * acc_scale[:, None]
|
|
# update acc
|
|
v = tl.load(
|
|
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
|
mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
|
|
other=0.0,
|
|
)
|
|
|
|
p = p.to(v.dtype)
|
|
acc += tl.dot(p, v)
|
|
# update m_i and l_i
|
|
l_i = l_i_new
|
|
m_i = m_i_new
|
|
# initialize pointers to output
|
|
off_o = (
|
|
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
|
|
+ cur_head * stride_oh
|
|
+ offs_d[None, :]
|
|
)
|
|
out_ptrs = Out + off_o
|
|
tl.store(
|
|
out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
|
|
)
|
|
|
|
|
|
def context_attention_fwd(
|
|
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
|
|
):
|
|
"""
|
|
q, k, v: [b * s, head, head_dim]
|
|
b_start_loc: [b]
|
|
b_seq_len: [b]
|
|
out: [b * s, head, head_dim]
|
|
"""
|
|
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
|
|
BLOCK = 128
|
|
else:
|
|
BLOCK = 64
|
|
|
|
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
|
|
|
sm_scale = 1.0 / (Lq**0.5)
|
|
batch, head = b_seq_len.shape[0], q.shape[1]
|
|
kv_group_num = q.shape[1] // k.shape[1]
|
|
|
|
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
|
num_warps = 4 if Lk <= 64 else 8
|
|
|
|
_fwd_kernel[grid](
|
|
q,
|
|
k,
|
|
v,
|
|
sm_scale,
|
|
b_start_loc,
|
|
b_seq_len,
|
|
o,
|
|
q.stride(0),
|
|
q.stride(1),
|
|
k.stride(0),
|
|
k.stride(1),
|
|
v.stride(0),
|
|
v.stride(1),
|
|
o.stride(0),
|
|
o.stride(1),
|
|
kv_group_num=kv_group_num,
|
|
BLOCK_M=BLOCK,
|
|
BLOCK_DMODEL=triton.next_power_of_2(Lk),
|
|
BLOCK_N=BLOCK,
|
|
IS_CAUSAL=is_causal,
|
|
num_warps=num_warps,
|
|
num_stages=1,
|
|
Lk=Lk,
|
|
) |