#!/usr/bin/env python
# coding=utf-8
"""
Description : This script defines the `CPUInferKVCache` and `CPUInfer` classes for performing inference
with a Key-Value Cache on the CPU. The `CPUInferKVCache` class is responsible for configuring
and managing key-value caches, updating and retrieving cache data, and handling attention
operations. It supports different cache types (e.g., Q4_0, FP16) and retrieval strategies
(e.g., shared, separate). The `CPUInfer` class handles task submission and synchronization
on the CPU, with optional CUDA stream integration for tasks involving GPU acceleration.
These classes facilitate efficient caching and memory management for deep learning models
that leverage key-value attention mechanisms, particularly on CPU-based systems.
Author : djw
Date : 2024-08-26 23:25:24
Version : 1.0.0
LastEditors : djw
LastEditTime : 2024-08-26 23:25:24
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import sys, os
from typing import Any
import torch
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.server.config.config import Config
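# Illustrative usage sketch (a minimal, hedged example; it is not executed on import and
# all argument values below are placeholders that depend on the model configuration).
# The CPUInferKVCache wrapper methods forward to the cpuinfer_ext backend and are
# typically driven through a CPUInfer task queue:
#
#   cpu_infer = CPUInfer(thread_num=48)
#   kvcache = CPUInferKVCache(layer_num=32, kv_head_num=8, q_head_num=32, head_dim=128)
#   cpu_infer.submit(
#       kvcache.attn_with_kvcache(q, k, v, output, attn_lse, layer_idx,
#                                 block_table, cache_seqlens)
#   )
#   cpu_infer.sync()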
class CPUInferKVCache:
def __init__(
self,
layer_num: int = 32,
kv_head_num: int = 8,
q_head_num: int = 32,
head_dim: int = 128,
block_len: int = 256,
anchor_num: int = 4,
anchor_type: str = "FIXED",
kv_type: str = "Q4_0",
retrieval_type: str = "SHARED",
layer_step: int = 1,
token_step: int = 1,
layer_offset: int = 0,
max_thread_num: int = 32,
max_batch_size: int = 4,
max_block_num: int = 512,
):
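        """Configure and create the backing cpuinfer_ext KV cache.

        `block_len` is the number of tokens stored per cache block, `anchor_num` the
        number of anchor vectors kept per block for retrieval, and `max_block_num`,
        `max_batch_size`, and `max_thread_num` bound the backend's preallocated
        resources. `anchor_type`, `kv_type`, and `retrieval_type` are translated to
        the corresponding cpuinfer_ext enums below; `layer_step`, `token_step`, and
        `layer_offset` control how often block retrieval is refreshed across layers
        and generated tokens (exact semantics live in the C++ backend).
        """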
if anchor_type == "FIXED":
anchor_type = cpuinfer_ext.kvcache.AnchorType.FIXED
elif anchor_type == "QUEST":
anchor_type = cpuinfer_ext.kvcache.AnchorType.QUEST
elif anchor_type == "DYNAMIC":
anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
elif anchor_type == "BLOCK_MEAN":
anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MEAN
elif anchor_type == "BLOCK_MAX":
anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MAX
else:
raise ValueError(f"Unknown anchor type: {anchor_type}")
if kv_type == "FP16":
kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
elif kv_type == "FP32":
assert False, "FP32 is not supported yet."
kv_type = cpuinfer_ext.kvcache.ggml_type.FP32
elif kv_type == "Q4_0":
kv_type = cpuinfer_ext.kvcache.ggml_type.Q4_0
elif kv_type == "Q8_0":
kv_type = cpuinfer_ext.kvcache.ggml_type.Q8_0
else:
raise ValueError(f"Unknown kv type: {kv_type}")
if retrieval_type == "SHARED":
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
elif retrieval_type == "INDIVIDUAL":
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.QHEAD
        elif retrieval_type == "SEPARATE":
            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.KVHEAD
        else:
            raise ValueError(f"Unknown retrieval type: {retrieval_type}")
self.config = cpuinfer_ext.kvcache.KVCacheConfig(
layer_num,
kv_head_num,
q_head_num,
head_dim,
block_len,
anchor_num,
anchor_type,
kv_type,
retrieval_type,
layer_step,
token_step,
layer_offset,
max_block_num,
max_batch_size,
max_thread_num,
)
self.kvcache = cpuinfer_ext.kvcache.KVCache(self.config)
def load_kvcache(self, tensor_file_path: str):
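        """Load a previously dumped KV cache from `tensor_file_path` into the backend."""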
if not os.path.exists(tensor_file_path):
raise FileNotFoundError(f"The file {tensor_file_path} does not exist.")
        return self.kvcache.load_kvcache(tensor_file_path)
def dump_kvcache(
self, block_table: torch.Tensor, cache_total_len: int, tensor_file_path: str
):
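        """Serialize the cache blocks listed in `block_table` (spanning `cache_total_len`
        tokens) to `tensor_file_path`, creating the parent directory if needed."""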
assert (
block_table.dim() == 1
and block_table.dtype == torch.int
and block_table.is_contiguous()
and block_table.device == torch.device("cpu")
), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
block_table.dim(),
block_table.size(),
block_table.dtype,
block_table.is_contiguous(),
block_table.device,
)
assert (
cache_total_len > 0
and cache_total_len <= self.config.block_len * block_table.size(0)
), "cache_total_len: {}".format(cache_total_len)
        dir_name = os.path.dirname(tensor_file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
return self.kvcache.dump_kvcache(
block_table.data_ptr(),
cache_total_len,
tensor_file_path,
)
def update_cache_total_len(self, cache_total_len: int):
assert cache_total_len > 0, "cache_total_len: {}".format(cache_total_len)
self.kvcache.update_cache_total_len(cache_total_len)
# q_in: (bsz, q_len, q_head_num, head_dim)
# output: (bsz, q_len, q_head_num, head_dim)
# attn_lse: (bsz, q_len, q_head_num)
# block_table: (bsz, max_block_num)
def attn(
self,
q_in: torch.Tensor,
output: torch.Tensor,
attn_lse: torch.Tensor,
layer_idx: int,
generate_token_idx: int,
block_table: torch.Tensor | None = None,
cache_seqlens: torch.Tensor | None = None,
pick_block_num: int | None = None,
init_block_num: int | None = None,
local_block_num: int | None = None,
):
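        """Run attention for layer `layer_idx` against the CPU KV cache.

        `q_in`, `output`, and `attn_lse` follow the shapes in the comments above. When
        `block_table`/`cache_seqlens` are omitted, null pointers (0) are passed and the
        backend uses its internal cache layout. `pick_block_num`, `init_block_num`, and
        `local_block_num` bound how many retrieved, leading, and most-recent blocks the
        backend considers (exact semantics live in the C++ backend).
        """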
assert (
q_in.dim() == 4
and q_in.size(2) == self.config.q_head_num
and q_in.size(3) == self.config.head_dim
and q_in.dtype == torch.float16
and q_in.is_contiguous()
and q_in.device == torch.device("cpu")
), "q_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
q_in.dim(), q_in.size(), q_in.dtype, q_in.is_contiguous(), q_in.device
)
batch_size = q_in.size(0)
q_len = q_in.size(1)
assert (block_table is None) or (
block_table.dim() == 2
and block_table.size(0) == batch_size
and block_table.dtype == torch.int
and block_table.is_contiguous()
and block_table.device == torch.device("cpu")
), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
block_table.dim(),
block_table.size(),
block_table.dtype,
block_table.is_contiguous(),
block_table.device,
)
max_block_num = block_table.size(1) if block_table is not None else 0
assert (
output.dim() == 4
and output.size(0) == batch_size
and output.size(2) == self.config.q_head_num
and output.size(1) == q_len
and output.size(3) == self.config.head_dim
and output.dtype == torch.float16
and output.is_contiguous()
and output.device == torch.device("cpu")
), "output dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
output.dim(),
output.size(),
output.dtype,
output.is_contiguous(),
output.device,
)
assert (
attn_lse.dim() == 3
and attn_lse.size(0) == batch_size
and attn_lse.size(1) == q_len
and attn_lse.size(2) == self.config.q_head_num
and attn_lse.dtype == torch.float32
and attn_lse.is_contiguous()
and attn_lse.device == torch.device("cpu")
), "attn_lse dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
attn_lse.dim(),
attn_lse.size(),
attn_lse.dtype,
attn_lse.is_contiguous(),
attn_lse.device,
)
assert (
layer_idx >= 0 and layer_idx < self.config.layer_num
), "layer_idx: {}".format(layer_idx)
assert (cache_seqlens is None) or (
cache_seqlens.dim() == 1
and cache_seqlens.size(0) == batch_size
and cache_seqlens.dtype == torch.int
and cache_seqlens.is_contiguous()
and cache_seqlens.device == torch.device("cpu")
), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
cache_seqlens.dim(),
cache_seqlens.size(),
cache_seqlens.dtype,
cache_seqlens.is_contiguous(),
cache_seqlens.device,
)
return self.kvcache.attn(
q_in.data_ptr(),
output.data_ptr(),
attn_lse.data_ptr(),
layer_idx,
generate_token_idx,
q_len,
batch_size,
max_block_num,
block_table.data_ptr() if block_table is not None else 0,
cache_seqlens.data_ptr() if cache_seqlens is not None else 0,
pick_block_num,
init_block_num,
local_block_num,
)
    # k_in: (kv_head_num, block_len, head_dim)
    # v_in: (kv_head_num, block_len, head_dim)
def update_kvcache_one_block_fp16(
self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int
):
assert (
k_in.dim() == 3
and k_in.size(1) == self.config.block_len
and k_in.size(0) == self.config.kv_head_num
and k_in.size(2) == self.config.head_dim
and k_in.dtype == torch.float16
and k_in.is_contiguous()
and k_in.device == torch.device("cpu")
), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device
)
assert (
v_in.dim() == 3
and v_in.size(1) == self.config.block_len
and v_in.size(0) == self.config.kv_head_num
and v_in.size(2) == self.config.head_dim
and v_in.dtype == torch.float16
and v_in.is_contiguous()
and v_in.device == torch.device("cpu")
), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.update_one_block_fp16(
k_in.data_ptr(),
v_in.data_ptr(),
layer_id,
block_idx,
)
def get_kvcache_one_block_fp16(
self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int
):
assert (
k_in.dim() == 3
and k_in.size(1) == self.config.block_len
and k_in.size(0) == self.config.kv_head_num
and k_in.size(2) == self.config.head_dim
and k_in.dtype == torch.float16
and k_in.is_contiguous()
and k_in.device == torch.device("cpu")
), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device
)
assert (
v_in.dim() == 3
and v_in.size(1) == self.config.block_len
and v_in.size(0) == self.config.kv_head_num
and v_in.size(2) == self.config.head_dim
and v_in.dtype == torch.float16
and v_in.is_contiguous()
and v_in.device == torch.device("cpu")
), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.get_one_block_fp16(
k_in.data_ptr(),
v_in.data_ptr(),
layer_id,
block_idx,
)
def update_importance_one_block(
self, importance: torch.Tensor, layer_id: int, block_idx: int
):
assert (
importance.dim() == 1
and importance.size(0) == self.config.block_len
and importance.dtype == torch.float16
and importance.is_contiguous()
and importance.device == torch.device("cpu")
), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
importance.dim(),
importance.size(),
importance.dtype,
importance.is_contiguous(),
importance.device,
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.update_importance_one_block(
importance.data_ptr(),
layer_id,
block_idx,
)
def get_importance_one_block(
self, importance: torch.Tensor, layer_id: int, block_idx: int
):
assert (
importance.dim() == 1
and importance.size(0) == self.config.block_len
and importance.dtype == torch.float16
and importance.is_contiguous()
and importance.device == torch.device("cpu")
), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
importance.dim(),
importance.size(),
importance.dtype,
importance.is_contiguous(),
importance.device,
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.get_importance_one_block(
importance.data_ptr(),
layer_id,
block_idx,
)
def get_anchor_one_block(self, anchor: torch.Tensor, layer_id: int, block_idx: int):
assert (
anchor.dim() == 3
and anchor.size(0) == self.config.kv_head_num
and anchor.size(1) == self.config.anchor_num
and anchor.size(2) == self.config.head_dim
and anchor.dtype == torch.float16
and anchor.is_contiguous()
and anchor.device == torch.device("cpu")
), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
anchor.dim(),
anchor.size(),
anchor.dtype,
anchor.is_contiguous(),
anchor.device,
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.get_anchor_one_block(
anchor.data_ptr(),
layer_id,
block_idx,
)
def update_anchor_one_block(
self, anchor: torch.Tensor, layer_id: int, block_idx: int
):
assert (
anchor.dim() == 3
and anchor.size(0) == self.config.kv_head_num
and anchor.size(1) == self.config.anchor_num
and anchor.size(2) == self.config.head_dim
and anchor.dtype == torch.float16
and anchor.is_contiguous()
and anchor.device == torch.device("cpu")
), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
anchor.dim(),
anchor.size(),
anchor.dtype,
anchor.is_contiguous(),
anchor.device,
)
assert (
layer_id >= 0 and layer_id < self.config.layer_num
), "layer_id: {}".format(layer_id)
assert block_idx >= 0, "block_idx: {}".format(block_idx)
return self.kvcache.update_anchor_one_block(
anchor.data_ptr(),
layer_id,
block_idx,
)
def calc_anchor_all_layers(
self,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
):
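        """Recompute retrieval anchors for every layer over the blocks referenced by
        `block_table`, up to the per-sequence lengths in `cache_seqlens`."""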
assert (
block_table.dim() == 2
and block_table.size(0) == cache_seqlens.size(0)
and block_table.dtype == torch.int
and block_table.is_contiguous()
and block_table.device == torch.device("cpu")
), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
block_table.dim(),
block_table.size(),
block_table.dtype,
block_table.is_contiguous(),
block_table.device,
)
assert (
cache_seqlens.dim() == 1
and cache_seqlens.dtype == torch.int
and cache_seqlens.is_contiguous()
and cache_seqlens.device == torch.device("cpu")
), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
cache_seqlens.dim(),
cache_seqlens.size(),
cache_seqlens.dtype,
cache_seqlens.is_contiguous(),
cache_seqlens.device,
)
batch_size = block_table.size(0)
max_block_num = block_table.size(1)
return self.kvcache.calc_anchor_all_layers(
block_table.data_ptr(),
cache_seqlens.data_ptr(),
batch_size,
max_block_num,
)
def clear_importance_all_layers(
self,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
):
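        """Reset the accumulated importance scores of the blocks referenced by
        `block_table` for all layers."""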
assert (
block_table.dim() == 2
and block_table.size(0) == cache_seqlens.size(0)
and block_table.dtype == torch.int
and block_table.is_contiguous()
and block_table.device == torch.device("cpu")
), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
block_table.dim(),
block_table.size(),
block_table.dtype,
block_table.is_contiguous(),
block_table.device,
)
assert (
cache_seqlens.dim() == 1
and cache_seqlens.dtype == torch.int
and cache_seqlens.is_contiguous()
and cache_seqlens.device == torch.device("cpu")
), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format(
cache_seqlens.dim(),
cache_seqlens.size(),
cache_seqlens.dtype,
cache_seqlens.is_contiguous(),
cache_seqlens.device,
)
batch_size = block_table.size(0)
max_block_num = block_table.size(1)
return self.kvcache.clear_importance_all_layers(
block_table.data_ptr(),
cache_seqlens.data_ptr(),
batch_size,
max_block_num,
)
def get_cache_total_len(self):
return self.kvcache.get_cache_total_len()
def update_kvcache_q4(
self,
k_in: torch.Tensor,
k_scales: torch.Tensor,
v_in: torch.Tensor,
v_scales: torch.Tensor,
layer_id: int,
seq_offset: int | None = None,
seq_len: int | None = None,
block_table: torch.Tensor | None = None,
):
raise NotImplementedError
def update_kvcache_fp16(
self,
k_in: torch.Tensor,
v_in: torch.Tensor,
layer_idx,
block_table: torch.Tensor,
max_block_num,
past_len: torch.Tensor,
q_len,
):
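        """Write `q_len` new FP16 key/value tokens per sequence into the paged cache,
        starting at each sequence's `past_len` offset, for the blocks in `block_table`."""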
batch_size = block_table.size(0)
        return self.kvcache.update_kvcache_fp16(
k_in.data_ptr(),
v_in.data_ptr(),
layer_idx,
block_table.data_ptr(),
batch_size,
max_block_num,
past_len.data_ptr(),
q_len
)
def get_kvcache_q4(
self,
k_in: torch.Tensor,
k_scales: torch.Tensor,
v_in: torch.Tensor,
v_scales: torch.Tensor,
layer_id: int,
seq_offset: int | None = None,
seq_len: int | None = None,
block_table: torch.Tensor | None = None,
):
raise NotImplementedError
def get_kvcache_fp16(
self,
k_in: torch.Tensor,
v_in: torch.Tensor,
        layer_idx: int,
block_table: torch.Tensor,
max_block_num,
past_len: torch.Tensor,
):
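        """Copy the cached FP16 key/value data for the blocks in `block_table` into
        `k_in`/`v_in`, using the per-sequence lengths in `past_len`."""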
batch_size = block_table.size(0)
return self.kvcache.get_kvcache_fp16(
k_in.data_ptr(),
v_in.data_ptr(),
layer_idx,
block_table.data_ptr(),
batch_size,
max_block_num,
past_len.data_ptr(),
)
def get_and_update_kvcache_fp16(
self,
k_cache_cpu: torch.Tensor,
v_cache_cpu: torch.Tensor,
layer_idx,
block_table: torch.Tensor,
max_block_num,
past_len: torch.Tensor,
q_len,
):
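        """Combined read/write: copy the existing cached K/V for each sequence into
        `k_cache_cpu`/`v_cache_cpu` and append the `q_len` new tokens from those buffers
        into the cache, using the per-sequence `past_len` offsets, in one backend call."""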
batch_size = block_table.size(0)
return self.kvcache.get_and_update_kvcache_fp16(
k_cache_cpu.data_ptr(),
v_cache_cpu.data_ptr(),
layer_idx,
block_table.data_ptr(),
batch_size,
max_block_num,
past_len.data_ptr(),
q_len,
)
def update_importance(
self,
importance_cache: torch.Tensor,
layer_idx,
block_table: torch.Tensor,
max_block_num,
offset: torch.Tensor,
width,
):
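        """Write per-token importance scores from `importance_cache` into the cache
        blocks in `block_table`: `width` tokens per sequence, starting at each
        sequence's `offset`."""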
batch_size = block_table.size(0)
return self.kvcache.update_importance(
importance_cache.data_ptr(),
layer_idx,
block_table.data_ptr(),
batch_size,
max_block_num,
offset.data_ptr(),
width,
)
# attn_sparsity: ((bsz, q_len, q_head_num), dtype = torch.float32)
def get_attn_sparsity(
self,
q_in: torch.Tensor,
attn_sparsity: torch.Tensor,
layer_idx: int,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
block_table_origin: torch.Tensor,
cache_seqlens_origin: torch.Tensor,
generate_token_idx: int = 0,
topk: int | None = None,
local: int | None = None,
):
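        """Compute an attention-sparsity measure for `layer_idx` and write it to
        `attn_sparsity` (shape noted in the comment above). The retrieved blocks
        (`block_table`/`cache_seqlens`) are evaluated against the original, full block
        table (`block_table_origin`/`cache_seqlens_origin`); `topk`/`local` of -1
        (the fallback below) disable sparse block selection."""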
batch_size = block_table.size(0)
max_block_num = block_table.size(1)
max_block_num_origin = block_table_origin.size(1)
q_len = q_in.size(1)
if topk is None or local is None or topk + local >= max_block_num:
topk = -1
local = -1
return self.kvcache.get_attn_sparsity(
q_in.data_ptr(),
attn_sparsity.data_ptr(),
layer_idx,
generate_token_idx,
q_len,
batch_size,
max_block_num,
block_table.data_ptr(),
cache_seqlens.data_ptr(),
block_table_origin.data_ptr(),
cache_seqlens_origin.data_ptr(),
max_block_num_origin,
topk,
local,
)
def attn_with_kvcache(
self,
q_in: torch.Tensor,
k_in: torch.Tensor,
v_in: torch.Tensor,
output: torch.Tensor,
attn_lse: torch.Tensor,
layer_idx: int,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
generate_token_idx: int = 0,
topk: int | None = None,
local: int | None = None,
):
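        """Append the new `k_in`/`v_in` tokens to the cache and run attention in one call.

        If `topk` and `local` are set and `topk + local` is smaller than the number of
        cached blocks, only the top-k retrieved blocks plus the `local` most recent
        blocks are attended; otherwise both are passed as -1 and full attention is used.
        """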
batch_size = block_table.size(0)
max_block_num = block_table.size(1)
q_len = q_in.size(1)
if topk is None or local is None or topk + local >= max_block_num:
topk = -1
local = -1
return self.kvcache.attn_with_kvcache(
q_in.data_ptr(),
k_in.data_ptr(),
v_in.data_ptr(),
output.data_ptr(),
attn_lse.data_ptr(),
layer_idx,
generate_token_idx,
q_len,
batch_size,
max_block_num,
block_table.data_ptr(),
cache_seqlens.data_ptr(),
topk,
local,
)
def get_all_kvcache_one_layer(
self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int
):
return self.kvcache.get_all_kvcache_one_layer(
k_in.data_ptr(),
v_in.data_ptr(),
layer_id,
)
def get_importance(
self,
importance: torch.Tensor,
block_table: torch.Tensor,
):
raise NotImplementedError
def get_anchor(
self,
anchor: torch.Tensor,
block_table: torch.Tensor,
):
raise NotImplementedError
class CPUInfer:
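    """Wrapper around the shared cpuinfer_ext task backend.

    All instances share one backend thread pool; constructing a CPUInfer with a larger
    `thread_num` than any previous instance tears down the old pool and re-creates it.
    `submit`/`sync` enqueue and wait for CPU tasks, while the `*_with_cuda_stream`
    variants tie task launch and completion to the given CUDA stream so CPU work can be
    ordered against GPU kernels.
    """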
cpuinfer = None
cur_backend_thread_num = 0
def __init__(self, thread_num):
if thread_num > CPUInfer.cur_backend_thread_num:
CPUInfer.cur_backend_thread_num = thread_num
del CPUInfer.cpuinfer
CPUInfer.cpuinfer = cpuinfer_ext.CPUInfer(thread_num)
def submit(self, task):
CPUInfer.cpuinfer.submit(task)
def submit_with_cuda_stream(self, current_cuda_stream, task):
CPUInfer.cpuinfer.submit_with_cuda_stream(current_cuda_stream, task)
def sync(self):
CPUInfer.cpuinfer.sync()
def sync_with_cuda_stream(self, current_cuda_stream):
CPUInfer.cpuinfer.sync_with_cuda_stream(current_cuda_stream)