#!/usr/bin/env python
# coding=utf-8
"""
Description : This script defines the `CPUInferKVCache` and `CPUInfer` classes for performing inference
              with a Key-Value Cache on the CPU. The `CPUInferKVCache` class is responsible for configuring
              and managing key-value caches, updating and retrieving cache data, and handling attention
              operations. It supports different cache types (e.g., Q4_0, FP16) and retrieval strategies
              (e.g., shared, separate). The `CPUInfer` class handles task submission and synchronization
              on the CPU, with optional CUDA stream integration for tasks involving GPU acceleration.
              These classes facilitate efficient caching and memory management for deep learning models
              that leverage key-value attention mechanisms, particularly on CPU-based systems.
Author       : djw
Date         : 2024-08-26 23:25:24
Version      : 1.0.0
LastEditors  : djw
LastEditTime : 2024-08-26 23:25:24
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import sys, os
from typing import Any
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.server.config.config import Config


class CPUInferKVCache:
    def __init__(
        self,
        layer_num: int = 32,
        kv_head_num: int = 8,
        q_head_num: int = 32,
        head_dim: int = 128,
        block_len: int = 256,
        anchor_num: int = 4,
        anchor_type: str = "FIXED",
        kv_type: str = "Q4_0",
        retrieval_type: str = "SHARED",
        layer_step: int = 1,
        token_step: int = 1,
        layer_offset: int = 0,
        max_thread_num: int = 32,
        max_batch_size: int = 4,
        max_block_num: int = 512,
    ):
        if anchor_type == "FIXED":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.FIXED
        elif anchor_type == "QUEST":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.QUEST
        elif anchor_type == "DYNAMIC":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
        elif anchor_type == "BLOCK_MEAN":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MEAN
        elif anchor_type == "BLOCK_MAX":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MAX
        else:
            raise ValueError(f"Unknown anchor type: {anchor_type}")

        if kv_type == "FP16":
            kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
        elif kv_type == "FP32":
            assert False, "FP32 is not supported yet."
            kv_type = cpuinfer_ext.kvcache.ggml_type.FP32
        elif kv_type == "Q4_0":
            kv_type = cpuinfer_ext.kvcache.ggml_type.Q4_0
        elif kv_type == "Q8_0":
            kv_type = cpuinfer_ext.kvcache.ggml_type.Q8_0
        else:
            raise ValueError(f"Unknown kv type: {kv_type}")

        if retrieval_type == "SHARED":
            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
        elif retrieval_type == "INDIVIDUAL":
            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.QHEAD
        elif retrieval_type == "SEPARATE":
            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.KVHEAD
        else:
            raise ValueError(f"Unknown retrieval type: {retrieval_type}")

        self.config = cpuinfer_ext.kvcache.KVCacheConfig(
            layer_num,
            kv_head_num,
            q_head_num,
            head_dim,
            block_len,
            anchor_num,
            anchor_type,
            kv_type,
            retrieval_type,
            layer_step,
            token_step,
            layer_offset,
            max_block_num,
            max_batch_size,
            max_thread_num,
        )
        self.kvcache = cpuinfer_ext.kvcache.KVCache(self.config)
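
    # Illustrative sketch (not executed): constructing a cache wrapper. The string
    # arguments are mapped onto the cpuinfer_ext enums above; the name
    # `kvcache_wrapper` and the parameter values below are placeholders chosen only
    # to show the shape of a typical configuration, not recommended settings.
    #
    #   kvcache_wrapper = CPUInferKVCache(
    #       layer_num=32,             # number of transformer layers to cache
    #       kv_head_num=8,            # KV heads per layer
    #       q_head_num=32,            # query heads per layer
    #       head_dim=128,             # per-head dimension
    #       block_len=256,            # tokens per cache block
    #       anchor_type="DYNAMIC",    # FIXED / QUEST / DYNAMIC / BLOCK_MEAN / BLOCK_MAX
    #       kv_type="Q4_0",           # FP16 / Q4_0 / Q8_0 (FP32 not supported yet)
    #       retrieval_type="SHARED",  # SHARED -> LAYER, INDIVIDUAL -> QHEAD, SEPARATE -> KVHEAD
    #   )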

    def load_kvcache(self, tensor_file_path: str):
        if not os.path.exists(tensor_file_path):
            raise FileNotFoundError(f"The file {tensor_file_path} does not exist.")
        return self.kvcache.load_kvcache(tensor_file_path)

    def dump_kvcache(
        self, block_table: torch.Tensor, cache_total_len: int, tensor_file_path: str
    ):
        assert (
            block_table.dim() == 1
            and block_table.dtype == torch.int
            and block_table.is_contiguous()
            and block_table.device == torch.device("cpu")
        ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            block_table.dim(),
            block_table.size(),
            block_table.dtype,
            block_table.is_contiguous(),
            block_table.device,
        )

        assert (
            cache_total_len > 0
            and cache_total_len <= self.config.block_len * block_table.size(0)
        ), "cache_total_len: {}".format(cache_total_len)

        if not os.path.exists(os.path.dirname(tensor_file_path)):
            os.makedirs(os.path.dirname(tensor_file_path))

        return self.kvcache.dump_kvcache(
            block_table.data_ptr(),
            cache_total_len,
            tensor_file_path,
        )
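
    # Illustrative sketch (not executed): dumping and reloading a cache, continuing
    # the placeholder `kvcache_wrapper` from the sketch above. Per the assertions,
    # `block_table` must be a contiguous 1-D int32 CPU tensor and `cache_total_len`
    # may not exceed block_len * block_table.numel(); the file path is a placeholder.
    #
    #   block_table = torch.arange(8, dtype=torch.int32)   # 8 cache blocks
    #   cache_total_len = 2000                              # <= 8 * block_len
    #   kvcache_wrapper.dump_kvcache(block_table, cache_total_len, "/tmp/kvcache.bin")
    #   kvcache_wrapper.load_kvcache("/tmp/kvcache.bin")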

    def update_cache_total_len(self, cache_total_len: int):
        assert cache_total_len > 0, "cache_total_len: {}".format(cache_total_len)
        self.kvcache.update_cache_total_len(cache_total_len)

    # q_in: (bsz, q_len, q_head_num, head_dim)
    # output: (bsz, q_len, q_head_num, head_dim)
    # attn_lse: (bsz, q_len, q_head_num)
    # block_table: (bsz, max_block_num)
    def attn(
        self,
        q_in: torch.Tensor,
        output: torch.Tensor,
        attn_lse: torch.Tensor,
        layer_idx: int,
        generate_token_idx: int,
        block_table: torch.Tensor | None = None,
        cache_seqlens: torch.Tensor | None = None,
        pick_block_num: int | None = None,
        init_block_num: int | None = None,
        local_block_num: int | None = None,
    ):
        assert (
            q_in.dim() == 4
            and q_in.size(2) == self.config.q_head_num
            and q_in.size(3) == self.config.head_dim
            and q_in.dtype == torch.float16
            and q_in.is_contiguous()
            and q_in.device == torch.device("cpu")
        ), "q_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            q_in.dim(), q_in.size(), q_in.dtype, q_in.is_contiguous(), q_in.device
        )

        batch_size = q_in.size(0)
        q_len = q_in.size(1)

        assert (block_table is None) or (
            block_table.dim() == 2
            and block_table.size(0) == batch_size
            and block_table.dtype == torch.int
            and block_table.is_contiguous()
            and block_table.device == torch.device("cpu")
        ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            block_table.dim(),
            block_table.size(),
            block_table.dtype,
            block_table.is_contiguous(),
            block_table.device,
        )

        max_block_num = block_table.size(1) if block_table is not None else 0

        assert (
            output.dim() == 4
            and output.size(0) == batch_size
            and output.size(1) == q_len
            and output.size(2) == self.config.q_head_num
            and output.size(3) == self.config.head_dim
            and output.dtype == torch.float16
            and output.is_contiguous()
            and output.device == torch.device("cpu")
        ), "output dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            output.dim(),
            output.size(),
            output.dtype,
            output.is_contiguous(),
            output.device,
        )

        assert (
            attn_lse.dim() == 3
            and attn_lse.size(0) == batch_size
            and attn_lse.size(1) == q_len
            and attn_lse.size(2) == self.config.q_head_num
            and attn_lse.dtype == torch.float32
            and attn_lse.is_contiguous()
            and attn_lse.device == torch.device("cpu")
        ), "attn_lse dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            attn_lse.dim(),
            attn_lse.size(),
            attn_lse.dtype,
            attn_lse.is_contiguous(),
            attn_lse.device,
        )

        assert (
            layer_idx >= 0 and layer_idx < self.config.layer_num
        ), "layer_idx: {}".format(layer_idx)

        assert (cache_seqlens is None) or (
            cache_seqlens.dim() == 1
            and cache_seqlens.size(0) == batch_size
            and cache_seqlens.dtype == torch.int
            and cache_seqlens.is_contiguous()
            and cache_seqlens.device == torch.device("cpu")
        ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            cache_seqlens.dim(),
            cache_seqlens.size(),
            cache_seqlens.dtype,
            cache_seqlens.is_contiguous(),
            cache_seqlens.device,
        )

        return self.kvcache.attn(
            q_in.data_ptr(),
            output.data_ptr(),
            attn_lse.data_ptr(),
            layer_idx,
            generate_token_idx,
            q_len,
            batch_size,
            max_block_num,
            block_table.data_ptr() if block_table is not None else 0,
            cache_seqlens.data_ptr() if cache_seqlens is not None else 0,
            pick_block_num,
            init_block_num,
            local_block_num,
        )
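
    # Illustrative sketch (not executed): calling `attn` for one decode step with
    # the placeholder `kvcache_wrapper` from above. Shapes and dtypes follow the
    # assertions: q_in/output are fp16 (bsz, q_len, q_head_num, head_dim), attn_lse
    # is fp32 (bsz, q_len, q_head_num), block_table is int32 (bsz, max_block_num),
    # and cache_seqlens is int32 (bsz,), all contiguous CPU tensors. Whether the
    # call runs synchronously or returns a task handle to hand to `CPUInfer.submit`
    # depends on the cpuinfer_ext binding; treat the last line as an assumption.
    #
    #   bsz, q_len = 1, 1
    #   cfg = kvcache_wrapper.config
    #   q_in     = torch.zeros(bsz, q_len, cfg.q_head_num, cfg.head_dim, dtype=torch.float16)
    #   output   = torch.empty_like(q_in)
    #   attn_lse = torch.empty(bsz, q_len, cfg.q_head_num, dtype=torch.float32)
    #   block_table   = torch.arange(4, dtype=torch.int32).repeat(bsz, 1)   # 4 blocks per sequence
    #   cache_seqlens = torch.tensor([1024], dtype=torch.int32)             # tokens already cached
    #   task = kvcache_wrapper.attn(q_in, output, attn_lse, layer_idx=0,
    #                               generate_token_idx=0, block_table=block_table,
    #                               cache_seqlens=cache_seqlens)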

    # k_in: (kv_head_num, block_len, head_dim)
    # v_in: (kv_head_num, block_len, head_dim)
    def update_kvcache_one_block_fp16(
        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int
    ):
        assert (
            k_in.dim() == 3
            and k_in.size(1) == self.config.block_len
            and k_in.size(0) == self.config.kv_head_num
            and k_in.size(2) == self.config.head_dim
            and k_in.dtype == torch.float16
            and k_in.is_contiguous()
            and k_in.device == torch.device("cpu")
        ), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device
        )
        assert (
            v_in.dim() == 3
            and v_in.size(1) == self.config.block_len
            and v_in.size(0) == self.config.kv_head_num
            and v_in.size(2) == self.config.head_dim
            and v_in.dtype == torch.float16
            and v_in.is_contiguous()
            and v_in.device == torch.device("cpu")
        ), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.update_one_block_fp16(
            k_in.data_ptr(),
            v_in.data_ptr(),
            layer_id,
            block_idx,
        )

    def get_kvcache_one_block_fp16(
        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int
    ):
        assert (
            k_in.dim() == 3
            and k_in.size(1) == self.config.block_len
            and k_in.size(0) == self.config.kv_head_num
            and k_in.size(2) == self.config.head_dim
            and k_in.dtype == torch.float16
            and k_in.is_contiguous()
            and k_in.device == torch.device("cpu")
        ), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device
        )
        assert (
            v_in.dim() == 3
            and v_in.size(1) == self.config.block_len
            and v_in.size(0) == self.config.kv_head_num
            and v_in.size(2) == self.config.head_dim
            and v_in.dtype == torch.float16
            and v_in.is_contiguous()
            and v_in.device == torch.device("cpu")
        ), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.get_one_block_fp16(
            k_in.data_ptr(),
            v_in.data_ptr(),
            layer_id,
            block_idx,
        )

    def update_importance_one_block(
        self, importance: torch.Tensor, layer_id: int, block_idx: int
    ):
        assert (
            importance.dim() == 1
            and importance.size(0) == self.config.block_len
            and importance.dtype == torch.float16
            and importance.is_contiguous()
            and importance.device == torch.device("cpu")
        ), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            importance.dim(),
            importance.size(),
            importance.dtype,
            importance.is_contiguous(),
            importance.device,
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.update_importance_one_block(
            importance.data_ptr(),
            layer_id,
            block_idx,
        )

    def get_importance_one_block(
        self, importance: torch.Tensor, layer_id: int, block_idx: int
    ):
        assert (
            importance.dim() == 1
            and importance.size(0) == self.config.block_len
            and importance.dtype == torch.float16
            and importance.is_contiguous()
            and importance.device == torch.device("cpu")
        ), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            importance.dim(),
            importance.size(),
            importance.dtype,
            importance.is_contiguous(),
            importance.device,
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.get_importance_one_block(
            importance.data_ptr(),
            layer_id,
            block_idx,
        )

    def get_anchor_one_block(self, anchor: torch.Tensor, layer_id: int, block_idx: int):
        assert (
            anchor.dim() == 3
            and anchor.size(0) == self.config.kv_head_num
            and anchor.size(1) == self.config.anchor_num
            and anchor.size(2) == self.config.head_dim
            and anchor.dtype == torch.float16
            and anchor.is_contiguous()
            and anchor.device == torch.device("cpu")
        ), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            anchor.dim(),
            anchor.size(),
            anchor.dtype,
            anchor.is_contiguous(),
            anchor.device,
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.get_anchor_one_block(
            anchor.data_ptr(),
            layer_id,
            block_idx,
        )

    def update_anchor_one_block(
        self, anchor: torch.Tensor, layer_id: int, block_idx: int
    ):
        assert (
            anchor.dim() == 3
            and anchor.size(0) == self.config.kv_head_num
            and anchor.size(1) == self.config.anchor_num
            and anchor.size(2) == self.config.head_dim
            and anchor.dtype == torch.float16
            and anchor.is_contiguous()
            and anchor.device == torch.device("cpu")
        ), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            anchor.dim(),
            anchor.size(),
            anchor.dtype,
            anchor.is_contiguous(),
            anchor.device,
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.update_anchor_one_block(
            anchor.data_ptr(),
            layer_id,
            block_idx,
        )

    def calc_anchor_all_layers(
        self,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
    ):
        assert (
            block_table.dim() == 2
            and block_table.size(0) == cache_seqlens.size(0)
            and block_table.dtype == torch.int
            and block_table.is_contiguous()
            and block_table.device == torch.device("cpu")
        ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            block_table.dim(),
            block_table.size(),
            block_table.dtype,
            block_table.is_contiguous(),
            block_table.device,
        )
        assert (
            cache_seqlens.dim() == 1
            and cache_seqlens.dtype == torch.int
            and cache_seqlens.is_contiguous()
            and cache_seqlens.device == torch.device("cpu")
        ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            cache_seqlens.dim(),
            cache_seqlens.size(),
            cache_seqlens.dtype,
            cache_seqlens.is_contiguous(),
            cache_seqlens.device,
        )
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        return self.kvcache.calc_anchor_all_layers(
            block_table.data_ptr(),
            cache_seqlens.data_ptr(),
            batch_size,
            max_block_num,
        )

    def clear_importance_all_layers(
        self,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
    ):
        assert (
            block_table.dim() == 2
            and block_table.size(0) == cache_seqlens.size(0)
            and block_table.dtype == torch.int
            and block_table.is_contiguous()
            and block_table.device == torch.device("cpu")
        ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            block_table.dim(),
            block_table.size(),
            block_table.dtype,
            block_table.is_contiguous(),
            block_table.device,
        )
        assert (
            cache_seqlens.dim() == 1
            and cache_seqlens.dtype == torch.int
            and cache_seqlens.is_contiguous()
            and cache_seqlens.device == torch.device("cpu")
        ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            cache_seqlens.dim(),
            cache_seqlens.size(),
            cache_seqlens.dtype,
            cache_seqlens.is_contiguous(),
            cache_seqlens.device,
        )
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        return self.kvcache.clear_importance_all_layers(
            block_table.data_ptr(),
            cache_seqlens.data_ptr(),
            batch_size,
            max_block_num,
        )

    def get_cache_total_len(self):
        return self.kvcache.get_cache_total_len()

    def update_kvcache_q4(
        self,
        k_in: torch.Tensor,
        k_scales: torch.Tensor,
        v_in: torch.Tensor,
        v_scales: torch.Tensor,
        layer_id: int,
        seq_offset: int | None = None,
        seq_len: int | None = None,
        block_table: torch.Tensor | None = None,
    ):
        raise NotImplementedError

    def update_kvcache_fp16(
        self,
        k_in: torch.Tensor,
        v_in: torch.Tensor,
        layer_idx,
        block_table: torch.Tensor,
        max_block_num,
        past_len: torch.Tensor,
        q_len,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.get_kvcache_fp16(
            k_in.data_ptr(),
            v_in.data_ptr(),
            layer_idx,
            block_table.data_ptr(),
            batch_size,
            max_block_num,
            past_len.data_ptr(),
            q_len,
        )

    def get_kvcache_q4(
        self,
        k_in: torch.Tensor,
        k_scales: torch.Tensor,
        v_in: torch.Tensor,
        v_scales: torch.Tensor,
        layer_id: int,
        seq_offset: int | None = None,
        seq_len: int | None = None,
        block_table: torch.Tensor | None = None,
    ):
        raise NotImplementedError

    def get_kvcache_fp16(
        self,
        k_in: torch.Tensor,
        v_in: torch.Tensor,
        layer_id: int,
        layer_idx,
        block_table: torch.Tensor,
        max_block_num,
        past_len: torch.Tensor,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.get_kvcache_fp16(
            k_in.data_ptr(),
            v_in.data_ptr(),
            layer_idx,
            block_table.data_ptr(),
            batch_size,
            max_block_num,
            past_len.data_ptr(),
        )

    def get_and_update_kvcache_fp16(
        self,
        k_cache_cpu: torch.Tensor,
        v_cache_cpu: torch.Tensor,
        layer_idx,
        block_table: torch.Tensor,
        max_block_num,
        past_len: torch.Tensor,
        q_len,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.get_and_update_kvcache_fp16(
            k_cache_cpu.data_ptr(),
            v_cache_cpu.data_ptr(),
            layer_idx,
            block_table.data_ptr(),
            batch_size,
            max_block_num,
            past_len.data_ptr(),
            q_len,
        )

    def update_importance(
        self,
        importance_cache: torch.Tensor,
        layer_idx,
        block_table: torch.Tensor,
        max_block_num,
        offset: torch.Tensor,
        width,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.update_importance(
            importance_cache.data_ptr(),
            layer_idx,
            block_table.data_ptr(),
            batch_size,
            max_block_num,
            offset.data_ptr(),
            width,
        )

    # attn_sparsity: ((bsz, q_len, q_head_num), dtype = torch.float32)
    def get_attn_sparsity(
        self,
        q_in: torch.Tensor,
        attn_sparsity: torch.Tensor,
        layer_idx: int,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        block_table_origin: torch.Tensor,
        cache_seqlens_origin: torch.Tensor,
        generate_token_idx: int = 0,
        topk: int | None = None,
        local: int | None = None,
    ):
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        max_block_num_origin = block_table_origin.size(1)
        q_len = q_in.size(1)

        if topk is None or local is None or topk + local >= max_block_num:
            topk = -1
            local = -1
        return self.kvcache.get_attn_sparsity(
            q_in.data_ptr(),
            attn_sparsity.data_ptr(),
            layer_idx,
            generate_token_idx,
            q_len,
            batch_size,
            max_block_num,
            block_table.data_ptr(),
            cache_seqlens.data_ptr(),
            block_table_origin.data_ptr(),
            cache_seqlens_origin.data_ptr(),
            max_block_num_origin,
            topk,
            local,
        )

    def attn_with_kvcache(
        self,
        q_in: torch.Tensor,
        k_in: torch.Tensor,
        v_in: torch.Tensor,
        output: torch.Tensor,
        attn_lse: torch.Tensor,
        layer_idx: int,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        generate_token_idx: int = 0,
        topk: int | None = None,
        local: int | None = None,
    ):
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        q_len = q_in.size(1)

        if topk is None or local is None or topk + local >= max_block_num:
            topk = -1
            local = -1
        return self.kvcache.attn_with_kvcache(
            q_in.data_ptr(),
            k_in.data_ptr(),
            v_in.data_ptr(),
            output.data_ptr(),
            attn_lse.data_ptr(),
            layer_idx,
            generate_token_idx,
            q_len,
            batch_size,
            max_block_num,
            block_table.data_ptr(),
            cache_seqlens.data_ptr(),
            topk,
            local,
        )

    def get_all_kvcache_one_layer(
        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int
    ):
        return self.kvcache.get_all_kvcache_one_layer(
            k_in.data_ptr(),
            v_in.data_ptr(),
            layer_id,
        )

    def get_importance(
        self,
        importance: torch.Tensor,
        block_table: torch.Tensor,
    ):
        raise NotImplementedError

    def get_anchor(
        self,
        anchor: torch.Tensor,
        block_table: torch.Tensor,
    ):
        raise NotImplementedError


class CPUInfer:
    cpuinfer = None
    cur_backend_thread_num = 0

    def __init__(self, thread_num):
        if thread_num > CPUInfer.cur_backend_thread_num:
            CPUInfer.cur_backend_thread_num = thread_num
            del CPUInfer.cpuinfer
            CPUInfer.cpuinfer = cpuinfer_ext.CPUInfer(thread_num)

    def submit(self, task):
        CPUInfer.cpuinfer.submit(task)

    def submit_with_cuda_stream(self, current_cuda_stream, task):
        CPUInfer.cpuinfer.submit_with_cuda_stream(current_cuda_stream, task)

    def sync(self):
        CPUInfer.cpuinfer.sync()

    def sync_with_cuda_stream(self, current_cuda_stream):
        CPUInfer.cpuinfer.sync_with_cuda_stream(current_cuda_stream)
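
# Illustrative sketch (not executed): driving work through CPUInfer. The thread
# count below and the idea that CPUInferKVCache methods return submittable task
# handles are assumptions of this sketch, not guarantees made by this file; the
# CUDA-stream variants mirror submit/sync but order the task against a given
# stream (useful when inputs were just copied from GPU memory).
#
#   cpu_infer = CPUInfer(thread_num=32)
#   cpu_infer.submit(
#       kvcache_wrapper.attn(q_in, output, attn_lse, layer_idx=0, generate_token_idx=0,
#                            block_table=block_table, cache_seqlens=cache_seqlens)
#   )
#   cpu_infer.sync()
#
#   # With a CUDA stream:
#   #   stream = torch.cuda.current_stream().cuda_stream
#   #   cpu_infer.submit_with_cuda_stream(stream, some_task)
#   #   cpu_infer.sync_with_cuda_stream(stream)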