mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-04 22:51:51 +00:00

[feature] release 0.1.3

This commit is contained in parent 67f8b370c3, commit 4d1d561d28.
58 changed files with 11709 additions and 374 deletions
@@ -1,18 +1,746 @@
#!/usr/bin/env python
# coding=utf-8
"""
Description  : This script defines the `CPUInferKVCache` and `CPUInfer` classes for performing
               inference with a key-value cache on the CPU. The `CPUInferKVCache` class is
               responsible for configuring and managing key-value caches, updating and retrieving
               cache data, and handling attention operations. It supports different cache types
               (e.g., Q4_0, FP16) and retrieval strategies (e.g., shared, separate). The `CPUInfer`
               class handles task submission and synchronization on the CPU, with optional CUDA
               stream integration for tasks involving GPU acceleration. These classes facilitate
               efficient caching and memory management for deep learning models that leverage
               key-value attention mechanisms, particularly on CPU-based systems.
Author       : djw
Date         : 2024-08-26 23:25:24
Version      : 1.0.0
LastEditors  : djw
LastEditTime : 2024-08-26 23:25:24
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import sys, os
from typing import Any

import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.server.config.config import Config


class CPUInferKVCache:
    def __init__(
        self,
        layer_num: int = 32,
        kv_head_num: int = 8,
        q_head_num: int = 32,
        head_dim: int = 128,
        block_len: int = 256,
        anchor_num: int = 4,
        anchor_type: str = "FIXED",
        kv_type: str = "Q4_0",
        retrieval_type: str = "SHARED",
        layer_step: int = 1,
        token_step: int = 1,
        layer_offset: int = 0,
        max_thread_num: int = 32,
        max_batch_size: int = 4,
        max_block_num: int = 512,
    ):
        if anchor_type == "FIXED":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.FIXED
        elif anchor_type == "QUEST":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.QUEST
        elif anchor_type == "DYNAMIC":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
        elif anchor_type == "BLOCK_MEAN":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MEAN
        elif anchor_type == "BLOCK_MAX":
            anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MAX
        else:
            raise ValueError(f"Unknown anchor type: {anchor_type}")

        if kv_type == "FP16":
            kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
        elif kv_type == "FP32":
            assert False, "FP32 is not supported yet."
            kv_type = cpuinfer_ext.kvcache.ggml_type.FP32
        elif kv_type == "Q4_0":
            kv_type = cpuinfer_ext.kvcache.ggml_type.Q4_0
        elif kv_type == "Q8_0":
            kv_type = cpuinfer_ext.kvcache.ggml_type.Q8_0
        else:
            raise ValueError(f"Unknown kv type: {kv_type}")

        if retrieval_type == "SHARED":
            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
        elif retrieval_type == "INDIVIDUAL":
            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.QHEAD
        elif retrieval_type == "SEPARATE":
            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.KVHEAD
        else:
            raise ValueError(f"Unknown retrieval type: {retrieval_type}")

        self.config = cpuinfer_ext.kvcache.KVCacheConfig(
            layer_num,
            kv_head_num,
            q_head_num,
            head_dim,
            block_len,
            anchor_num,
            anchor_type,
            kv_type,
            retrieval_type,
            layer_step,
            token_step,
            layer_offset,
            max_block_num,
            max_batch_size,
            max_thread_num,
        )
        self.kvcache = cpuinfer_ext.kvcache.KVCache(self.config)

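    # Illustrative sketch (not part of the original file): constructing a cache for a
    # 32-layer GQA model (8 KV heads, 32 query heads) with 4-bit quantized KV storage:
    #
    #   kvcache = CPUInferKVCache(layer_num=32, kv_head_num=8, q_head_num=32,
    #                             head_dim=128, block_len=256, kv_type="Q4_0",
    #                             retrieval_type="SHARED")
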
    def load_kvcache(self, tensor_file_path: str):
        if not os.path.exists(tensor_file_path):
            raise FileNotFoundError(f"The file {tensor_file_path} does not exist.")
        return self.kvcache.load_kvcache(tensor_file_path)

    def dump_kvcache(
        self, block_table: torch.Tensor, cache_total_len: int, tensor_file_path: str
    ):
        assert (
            block_table.dim() == 1
            and block_table.dtype == torch.int
            and block_table.is_contiguous()
            and block_table.device == torch.device("cpu")
        ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            block_table.dim(),
            block_table.size(),
            block_table.dtype,
            block_table.is_contiguous(),
            block_table.device,
        )

        assert (
            cache_total_len > 0
            and cache_total_len <= self.config.block_len * block_table.size(0)
        ), "cache_total_len: {}".format(cache_total_len)

        if not os.path.exists(os.path.dirname(tensor_file_path)):
            os.makedirs(os.path.dirname(tensor_file_path))

        return self.kvcache.dump_kvcache(
            block_table.data_ptr(),
            cache_total_len,
            tensor_file_path,
        )

    def update_cache_total_len(self, cache_total_len: int):
        assert cache_total_len > 0, "cache_total_len: {}".format(cache_total_len)
        self.kvcache.update_cache_total_len(cache_total_len)

    # q_in:        (bsz, q_len, q_head_num, head_dim)
    # output:      (bsz, q_len, q_head_num, head_dim)
    # attn_lse:    (bsz, q_len, q_head_num)
    # block_table: (bsz, max_block_num)
    def attn(
        self,
        q_in: torch.Tensor,
        output: torch.Tensor,
        attn_lse: torch.Tensor,
        layer_idx: int,
        generate_token_idx: int,
        block_table: torch.Tensor | None = None,
        cache_seqlens: torch.Tensor | None = None,
        pick_block_num: int | None = None,
        init_block_num: int | None = None,
        local_block_num: int | None = None,
    ):
        assert (
            q_in.dim() == 4
            and q_in.size(2) == self.config.q_head_num
            and q_in.size(3) == self.config.head_dim
            and q_in.dtype == torch.float16
            and q_in.is_contiguous()
            and q_in.device == torch.device("cpu")
        ), "q_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            q_in.dim(), q_in.size(), q_in.dtype, q_in.is_contiguous(), q_in.device
        )

        batch_size = q_in.size(0)
        q_len = q_in.size(1)

        assert (block_table is None) or (
            block_table.dim() == 2
            and block_table.size(0) == batch_size
            and block_table.dtype == torch.int
            and block_table.is_contiguous()
            and block_table.device == torch.device("cpu")
        ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            block_table.dim(),
            block_table.size(),
            block_table.dtype,
            block_table.is_contiguous(),
            block_table.device,
        )

        max_block_num = block_table.size(1) if block_table is not None else 0

        assert (
            output.dim() == 4
            and output.size(0) == batch_size
            and output.size(1) == q_len
            and output.size(2) == self.config.q_head_num
            and output.size(3) == self.config.head_dim
            and output.dtype == torch.float16
            and output.is_contiguous()
            and output.device == torch.device("cpu")
        ), "output dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            output.dim(),
            output.size(),
            output.dtype,
            output.is_contiguous(),
            output.device,
        )

        assert (
            attn_lse.dim() == 3
            and attn_lse.size(0) == batch_size
            and attn_lse.size(1) == q_len
            and attn_lse.size(2) == self.config.q_head_num
            and attn_lse.dtype == torch.float32
            and attn_lse.is_contiguous()
            and attn_lse.device == torch.device("cpu")
        ), "attn_lse dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            attn_lse.dim(),
            attn_lse.size(),
            attn_lse.dtype,
            attn_lse.is_contiguous(),
            attn_lse.device,
        )

        assert (
            layer_idx >= 0 and layer_idx < self.config.layer_num
        ), "layer_idx: {}".format(layer_idx)

        assert (cache_seqlens is None) or (
            cache_seqlens.dim() == 1
            and cache_seqlens.size(0) == batch_size
            and cache_seqlens.dtype == torch.int
            and cache_seqlens.is_contiguous()
            and cache_seqlens.device == torch.device("cpu")
        ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            cache_seqlens.dim(),
            cache_seqlens.size(),
            cache_seqlens.dtype,
            cache_seqlens.is_contiguous(),
            cache_seqlens.device,
        )

        return self.kvcache.attn(
            q_in.data_ptr(),
            output.data_ptr(),
            attn_lse.data_ptr(),
            layer_idx,
            generate_token_idx,
            q_len,
            batch_size,
            max_block_num,
            block_table.data_ptr() if block_table is not None else 0,
            cache_seqlens.data_ptr() if cache_seqlens is not None else 0,
            pick_block_num,
            init_block_num,
            local_block_num,
        )

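    # Illustrative sketch (not part of the original file): one decode step on layer 0
    # for bsz=1, q_len=1, shaped to satisfy the asserts above. The return value is a
    # task object intended for CPUInfer.submit()/sync() (see the class below).
    #
    #   q = torch.randn(1, 1, 32, 128).half()
    #   out = torch.empty_like(q)
    #   lse = torch.empty(1, 1, 32, dtype=torch.float32)
    #   block_table = torch.arange(4, dtype=torch.int).reshape(1, 4)
    #   cache_seqlens = torch.tensor([1024], dtype=torch.int)
    #   task = kvcache.attn(q, out, lse, layer_idx=0, generate_token_idx=0,
    #                       block_table=block_table, cache_seqlens=cache_seqlens)
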
    # k_in: (kv_head_num, block_len, head_dim)
    # v_in: (kv_head_num, block_len, head_dim)
    def update_kvcache_one_block_fp16(
        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int
    ):
        assert (
            k_in.dim() == 3
            and k_in.size(0) == self.config.kv_head_num
            and k_in.size(1) == self.config.block_len
            and k_in.size(2) == self.config.head_dim
            and k_in.dtype == torch.float16
            and k_in.is_contiguous()
            and k_in.device == torch.device("cpu")
        ), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device
        )
        assert (
            v_in.dim() == 3
            and v_in.size(0) == self.config.kv_head_num
            and v_in.size(1) == self.config.block_len
            and v_in.size(2) == self.config.head_dim
            and v_in.dtype == torch.float16
            and v_in.is_contiguous()
            and v_in.device == torch.device("cpu")
        ), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.update_one_block_fp16(
            k_in.data_ptr(),
            v_in.data_ptr(),
            layer_id,
            block_idx,
        )

    def get_kvcache_one_block_fp16(
        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int
    ):
        assert (
            k_in.dim() == 3
            and k_in.size(0) == self.config.kv_head_num
            and k_in.size(1) == self.config.block_len
            and k_in.size(2) == self.config.head_dim
            and k_in.dtype == torch.float16
            and k_in.is_contiguous()
            and k_in.device == torch.device("cpu")
        ), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device
        )
        assert (
            v_in.dim() == 3
            and v_in.size(0) == self.config.kv_head_num
            and v_in.size(1) == self.config.block_len
            and v_in.size(2) == self.config.head_dim
            and v_in.dtype == torch.float16
            and v_in.is_contiguous()
            and v_in.device == torch.device("cpu")
        ), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.get_one_block_fp16(
            k_in.data_ptr(),
            v_in.data_ptr(),
            layer_id,
            block_idx,
        )

    def update_importance_one_block(
        self, importance: torch.Tensor, layer_id: int, block_idx: int
    ):
        assert (
            importance.dim() == 1
            and importance.size(0) == self.config.block_len
            and importance.dtype == torch.float16
            and importance.is_contiguous()
            and importance.device == torch.device("cpu")
        ), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            importance.dim(),
            importance.size(),
            importance.dtype,
            importance.is_contiguous(),
            importance.device,
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.update_importance_one_block(
            importance.data_ptr(),
            layer_id,
            block_idx,
        )

    def get_importance_one_block(
        self, importance: torch.Tensor, layer_id: int, block_idx: int
    ):
        assert (
            importance.dim() == 1
            and importance.size(0) == self.config.block_len
            and importance.dtype == torch.float16
            and importance.is_contiguous()
            and importance.device == torch.device("cpu")
        ), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            importance.dim(),
            importance.size(),
            importance.dtype,
            importance.is_contiguous(),
            importance.device,
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.get_importance_one_block(
            importance.data_ptr(),
            layer_id,
            block_idx,
        )

    def get_anchor_one_block(self, anchor: torch.Tensor, layer_id: int, block_idx: int):
        assert (
            anchor.dim() == 3
            and anchor.size(0) == self.config.kv_head_num
            and anchor.size(1) == self.config.anchor_num
            and anchor.size(2) == self.config.head_dim
            and anchor.dtype == torch.float16
            and anchor.is_contiguous()
            and anchor.device == torch.device("cpu")
        ), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            anchor.dim(),
            anchor.size(),
            anchor.dtype,
            anchor.is_contiguous(),
            anchor.device,
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.get_anchor_one_block(
            anchor.data_ptr(),
            layer_id,
            block_idx,
        )

    def update_anchor_one_block(
        self, anchor: torch.Tensor, layer_id: int, block_idx: int
    ):
        assert (
            anchor.dim() == 3
            and anchor.size(0) == self.config.kv_head_num
            and anchor.size(1) == self.config.anchor_num
            and anchor.size(2) == self.config.head_dim
            and anchor.dtype == torch.float16
            and anchor.is_contiguous()
            and anchor.device == torch.device("cpu")
        ), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            anchor.dim(),
            anchor.size(),
            anchor.dtype,
            anchor.is_contiguous(),
            anchor.device,
        )
        assert (
            layer_id >= 0 and layer_id < self.config.layer_num
        ), "layer_id: {}".format(layer_id)
        assert block_idx >= 0, "block_idx: {}".format(block_idx)
        return self.kvcache.update_anchor_one_block(
            anchor.data_ptr(),
            layer_id,
            block_idx,
        )

    def calc_anchor_all_layers(
        self,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
    ):
        assert (
            block_table.dim() == 2
            and block_table.size(0) == cache_seqlens.size(0)
            and block_table.dtype == torch.int
            and block_table.is_contiguous()
            and block_table.device == torch.device("cpu")
        ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            block_table.dim(),
            block_table.size(),
            block_table.dtype,
            block_table.is_contiguous(),
            block_table.device,
        )
        assert (
            cache_seqlens.dim() == 1
            and cache_seqlens.dtype == torch.int
            and cache_seqlens.is_contiguous()
            and cache_seqlens.device == torch.device("cpu")
        ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            cache_seqlens.dim(),
            cache_seqlens.size(),
            cache_seqlens.dtype,
            cache_seqlens.is_contiguous(),
            cache_seqlens.device,
        )
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        return self.kvcache.calc_anchor_all_layers(
            block_table.data_ptr(),
            cache_seqlens.data_ptr(),
            batch_size,
            max_block_num,
        )

    def clear_importance_all_layers(
        self,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
    ):
        assert (
            block_table.dim() == 2
            and block_table.size(0) == cache_seqlens.size(0)
            and block_table.dtype == torch.int
            and block_table.is_contiguous()
            and block_table.device == torch.device("cpu")
        ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            block_table.dim(),
            block_table.size(),
            block_table.dtype,
            block_table.is_contiguous(),
            block_table.device,
        )
        assert (
            cache_seqlens.dim() == 1
            and cache_seqlens.dtype == torch.int
            and cache_seqlens.is_contiguous()
            and cache_seqlens.device == torch.device("cpu")
        ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
            cache_seqlens.dim(),
            cache_seqlens.size(),
            cache_seqlens.dtype,
            cache_seqlens.is_contiguous(),
            cache_seqlens.device,
        )
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        return self.kvcache.clear_importance_all_layers(
            block_table.data_ptr(),
            cache_seqlens.data_ptr(),
            batch_size,
            max_block_num,
        )

    def get_cache_total_len(self):
        return self.kvcache.get_cache_total_len()

    def update_kvcache_q4(
        self,
        k_in: torch.Tensor,
        k_scales: torch.Tensor,
        v_in: torch.Tensor,
        v_scales: torch.Tensor,
        layer_id: int,
        seq_offset: int | None = None,
        seq_len: int | None = None,
        block_table: torch.Tensor | None = None,
    ):
        raise NotImplementedError

    def update_kvcache_fp16(
        self,
        k_in: torch.Tensor,
        v_in: torch.Tensor,
        layer_idx,
        block_table: torch.Tensor,
        max_block_num,
        past_len: torch.Tensor,
        q_len,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.update_kvcache_fp16(
            k_in.data_ptr(),
            v_in.data_ptr(),
            layer_idx,
            block_table.data_ptr(),
            batch_size,
            max_block_num,
            past_len.data_ptr(),
            q_len,
        )

    def get_kvcache_q4(
        self,
        k_in: torch.Tensor,
        k_scales: torch.Tensor,
        v_in: torch.Tensor,
        v_scales: torch.Tensor,
        layer_id: int,
        seq_offset: int | None = None,
        seq_len: int | None = None,
        block_table: torch.Tensor | None = None,
    ):
        raise NotImplementedError

    def get_kvcache_fp16(
        self,
        k_in: torch.Tensor,
        v_in: torch.Tensor,
        layer_idx,
        block_table: torch.Tensor,
        max_block_num,
        past_len: torch.Tensor,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.get_kvcache_fp16(
            k_in.data_ptr(),
            v_in.data_ptr(),
            layer_idx,
            block_table.data_ptr(),
            batch_size,
            max_block_num,
            past_len.data_ptr(),
        )

    def get_and_update_kvcache_fp16(
        self,
        k_cache_cpu: torch.Tensor,
        v_cache_cpu: torch.Tensor,
        layer_idx,
        block_table: torch.Tensor,
        max_block_num,
        past_len: torch.Tensor,
        q_len,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.get_and_update_kvcache_fp16(
            k_cache_cpu.data_ptr(),
            v_cache_cpu.data_ptr(),
            layer_idx,
            block_table.data_ptr(),
            batch_size,
            max_block_num,
            past_len.data_ptr(),
            q_len,
        )

    def update_importance(
        self,
        importance_cache: torch.Tensor,
        layer_idx,
        block_table: torch.Tensor,
        max_block_num,
        offset: torch.Tensor,
        width,
    ):
        batch_size = block_table.size(0)
        return self.kvcache.update_importance(
            importance_cache.data_ptr(),
            layer_idx,
            block_table.data_ptr(),
            batch_size,
            max_block_num,
            offset.data_ptr(),
            width,
        )

    # attn_sparsity: (bsz, q_len, q_head_num), dtype=torch.float32
    def get_attn_sparsity(
        self,
        q_in: torch.Tensor,
        attn_sparsity: torch.Tensor,
        layer_idx: int,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        block_table_origin: torch.Tensor,
        cache_seqlens_origin: torch.Tensor,
        generate_token_idx: int = 0,
        topk: int | None = None,
        local: int | None = None,
    ):
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        max_block_num_origin = block_table_origin.size(1)
        q_len = q_in.size(1)

        if topk is None or local is None or topk + local >= max_block_num:
            topk = -1
            local = -1
        return self.kvcache.get_attn_sparsity(
            q_in.data_ptr(),
            attn_sparsity.data_ptr(),
            layer_idx,
            generate_token_idx,
            q_len,
            batch_size,
            max_block_num,
            block_table.data_ptr(),
            cache_seqlens.data_ptr(),
            block_table_origin.data_ptr(),
            cache_seqlens_origin.data_ptr(),
            max_block_num_origin,
            topk,
            local,
        )

    def attn_with_kvcache(
        self,
        q_in: torch.Tensor,
        k_in: torch.Tensor,
        v_in: torch.Tensor,
        output: torch.Tensor,
        attn_lse: torch.Tensor,
        layer_idx: int,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        generate_token_idx: int = 0,
        topk: int | None = None,
        local: int | None = None,
    ):
        batch_size = block_table.size(0)
        max_block_num = block_table.size(1)
        q_len = q_in.size(1)

        if topk is None or local is None or topk + local >= max_block_num:
            topk = -1
            local = -1
        return self.kvcache.attn_with_kvcache(
            q_in.data_ptr(),
            k_in.data_ptr(),
            v_in.data_ptr(),
            output.data_ptr(),
            attn_lse.data_ptr(),
            layer_idx,
            generate_token_idx,
            q_len,
            batch_size,
            max_block_num,
            block_table.data_ptr(),
            cache_seqlens.data_ptr(),
            topk,
            local,
        )

    def get_all_kvcache_one_layer(
        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int
    ):
        return self.kvcache.get_all_kvcache_one_layer(
            k_in.data_ptr(),
            v_in.data_ptr(),
            layer_id,
        )

    def get_importance(
        self,
        importance: torch.Tensor,
        block_table: torch.Tensor,
    ):
        raise NotImplementedError

    def get_anchor(
        self,
        anchor: torch.Tensor,
        block_table: torch.Tensor,
    ):
        raise NotImplementedError


class CPUInfer:
    cpu_infer = None

    def __init__(self, cpu_infer: int = Config().cpu_infer):
        # Shared singleton: the first construction fixes the worker thread count;
        # later constructions reuse the same underlying cpuinfer_ext.CPUInfer.
        if CPUInfer.cpu_infer is None:
            CPUInfer.cpu_infer = cpuinfer_ext.CPUInfer(cpu_infer)

    def __getattribute__(self, __name: str) -> Any:
        # Delegate everything (submit, submit_with_cuda_stream, sync,
        # sync_with_cuda_stream, ...) to the shared extension object.
        return CPUInfer.cpu_infer.__getattribute__(__name)

    def __setattr__(self, __name: str, __value: Any) -> None:
        return CPUInfer.cpu_infer.__setattr__(__name, __value)
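
A minimal usage sketch, assuming the module lives at ktransformers.operators.cpuinfer and that each CPUInferKVCache wrapper returns a task object for CPUInfer.submit()/sync(), as the class above suggests; the thread count and tensor contents are placeholders:

import torch
from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache  # module path assumed

kvcache = CPUInferKVCache(kv_type="FP16")  # defaults: 32 layers, 8 KV heads, 32 Q heads
cpu_infer = CPUInfer(48)                   # worker thread count (assumed)

# Shapes follow the asserts in CPUInferKVCache.attn for bsz=1, q_len=1.
q = torch.randn(1, 1, 32, 128).half()
out = torch.empty_like(q)
lse = torch.empty(1, 1, 32, dtype=torch.float32)
block_table = torch.arange(4, dtype=torch.int).reshape(1, 4)  # 4 blocks x 256 tokens
cache_seqlens = torch.tensor([1024], dtype=torch.int)

task = kvcache.attn(q, out, lse, layer_idx=0, generate_token_idx=0,
                    block_table=block_table, cache_seqlens=cache_seqlens)
cpu_infer.submit(task)  # enqueue on the CPU worker pool
cpu_infer.sync()        # block until the attention task completes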