Update triton_attention.py

This commit is contained in:
Atream 2025-02-15 15:41:01 +08:00 committed by GitHub
parent 1548c99234
commit d90749d35d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,3 +1,9 @@
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
import triton
import triton.language as tl
@ -376,4 +382,4 @@ def decode_attention_fwd_grouped(
)
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
num_kv_splits)
num_kv_splits)