mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-11 07:44:35 +00:00)

Update readme; Format code; Add example yaml.

This commit is contained in:
parent c38e77de6b
commit e5b001d76f

8 changed files with 182 additions and 30 deletions
@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

 <h2 id="Updates">🔥 Updates</h2>

+* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)).
 * **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM.
 * **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
 * **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).

@@ -22,6 +22,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin

 <h2 id="Updates">🔥 Updates</h2>

+* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)).
 * **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM.
 * **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
 * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. The detailed tutorial is [here](./en/DeepseekR1_V3_tutorial.md).

@@ -10,6 +10,7 @@
 - [Injection Tutorial](en/injection_tutorial.md)
 - [Multi-GPU Tutorial](en/multi-gpu-tutorial.md)
 - [Use FP8 GPU Kernel](en/fp8_kernel.md)
+- [Use AMD GPU](en/ROCm.md)
 # Server
 - [Server](en/api/server/server.md)
 - [Website](en/api/server/website.md)

doc/en/ROCm.md (new file, 96 lines)
@@ -0,0 +1,96 @@
# ROCm Support for ktransformers (Beta)

## Introduction

### Overview

In our effort to expand GPU architecture support beyond NVIDIA, we are excited to introduce **AMD GPU support through ROCm** in ktransformers (Beta release). This implementation has been tested and developed using EPYC 9274F processors and AMD Radeon 7900xtx GPUs.

## Installation Guide

### 1. Install ROCm Driver

Begin by installing the ROCm drivers for your AMD GPU:

- [Official ROCm Installation Guide for Radeon GPUs](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-radeon.html)

### 2. Set Up Conda Environment

We recommend using Miniconda3/Anaconda3 for environment management:

```bash
# Download Miniconda
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh

# Create environment
conda create --name ktransformers python=3.11
conda activate ktransformers

# Install required libraries
conda install -c conda-forge libstdcxx-ng

# Verify GLIBCXX version (should include 3.4.32)
strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
```

> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`

### 3. Install PyTorch for ROCm

Install PyTorch with ROCm 6.2.4 support:

```bash
pip3 install torch torchvision torchaudio \
  --index-url https://download.pytorch.org/whl/rocm6.2.4
pip3 install packaging ninja cpufeature numpy
```

> **Tip:** For other ROCm versions, visit [PyTorch Previous Versions](https://pytorch.org/get-started/previous-versions/)
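
Before building, it can help to confirm that the ROCm build of PyTorch is the one actually installed. This quick check is a minimal sketch and not part of the original tutorial; `torch.version.hip` is only populated on ROCm builds:

```bash
python -c "import torch; print(torch.__version__, torch.version.hip, torch.cuda.is_available())"
# Expected on a working ROCm setup: a rocm-tagged version, a HIP version string, and True
```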

### 4. Build ktransformers

```bash
# Clone repository
git clone https://github.com/kvcache-ai/ktransformers.git
cd ktransformers
git submodule update --init

# Optional: Compile web interface
# See: api/server/website.md

# Install dependencies
bash install.sh
```

## Running DeepSeek-R1 Models

### Configuration for 24GB VRAM GPUs

Use our optimized configuration for constrained VRAM:

```bash
python ktransformers/local_chat.py \
  --model_path deepseek-ai/DeepSeek-R1 \
  --gguf_path <path_to_gguf_files> \
  --optimize_config_path ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml \
  --cpu_infer <cpu_cores + 1>
```

> **Beta Note:** The current Q8 linear implementation (the Marlin alternative used on ROCm) shows suboptimal performance; optimizations are expected in future releases.

### Configuration for 40GB+ VRAM GPUs

For better performance on high-VRAM GPUs:

1. Modify `DeepSeek-V3-Chat.yaml` (a command-line sketch of this edit follows the list):

   ```yaml
   # Replace all instances of:
   KLinearMarlin → KLinearTorch
   ```

2. Execute with:

   ```bash
   python ktransformers/local_chat.py \
     --model_path deepseek-ai/DeepSeek-R1 \
     --gguf_path <path_to_gguf_files> \
     --optimize_config_path <modified_yaml_path> \
     --cpu_infer <cpu_cores + 1>
   ```

> **Tip:** If you have 2 × 24GB AMD GPUs, you can apply the same modification to `ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` and run with that file instead.
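
As referenced in step 1 above, the substitution can be scripted. A minimal sketch, assuming the stock rule file under `ktransformers/optimize/optimize_rules/` is the starting point; the copied file name is purely illustrative:

```bash
# Copy the stock rule file so the original stays untouched, then swap the linear kernel
cp ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml DeepSeek-V3-Chat-torch.yaml
sed -i 's/KLinearMarlin/KLinearTorch/g' DeepSeek-V3-Chat-torch.yaml
```

Pass the resulting file to `--optimize_config_path` in the command from step 2.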

## Known Limitations

- Marlin operations are not supported on the ROCm platform
- The current Q8 linear implementation shows reduced performance (Beta limitation)

@@ -187,8 +187,6 @@ class KLinearQ8(KLinearBase):
         config: PretrainedConfig,
         orig_module: nn.Module = None,
         device: str = "cuda",
-        group_size: int = 128,  # increase the group size to reduce quantization noise
-        percentile: float = 99.99,  # new: percentile at which outliers are clipped
         **kwargs,
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)

@@ -199,8 +197,6 @@ class KLinearQ8(KLinearBase):
         self.weight_zero_point = None
         self.bias = None
         self.loaded = False
-        self.group_size = group_size
-        self.percentile = percentile

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         orig_dtype = x.dtype

@@ -246,16 +242,9 @@ class KLinearQ8(KLinearBase):
         # For Q4, ensure the values stay within 4-bit range
         if bits == 4:
             q_matrix = torch.clamp(q_matrix, -7, 7)

-        # Get matrix shape
         rows, cols = q_matrix.shape

-        # Convert to float32
         dequant_matrix = q_matrix.to(torch.float32)

-        # Create broadcasted scales: reshape scales to [1, cols] for broadcasting
         scales_broadcast = scales.view(1, cols)

         # Apply dequantization to all columns at once using matrix multiplication
         dequant_matrix = dequant_matrix * scales_broadcast

@@ -285,21 +274,14 @@ class KLinearQ8(KLinearBase):

         # Determine quantization parameters based on bits
         if bits == 8:
-            # Q8: range is -127 to 127
             max_int = 127
             qtype = torch.int8
         elif bits == 4:
-            # Q4: range is -7 to 7 (using 4-bit signed integers)
             max_int = 7
-            qtype = torch.int8  # We'll still use int8 storage but limit to 4-bit range
+            qtype = torch.int8  # We'll still use int8 storage but limit to 4-bit range, wait for native support
         else:
             raise ValueError("Quantization bits must be either 8 or 4")

-        # Initialize results and scale factors
-        q_matrix = torch.zeros_like(matrix, dtype=qtype)
-        scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
-
-        # Initialize scale factors
         scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)

         # Calculate max absolute value for each column
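
For readers following the diff, the hunks above keep KLinearQ8's symmetric per-column scheme: one float32 scale per column, int8 storage, and a maximum integer of 127 for Q8 or 7 for Q4. Below is a minimal, self-contained sketch of that round trip; the function names are illustrative and not part of the ktransformers API:

```python
import torch

def quantize_per_column(matrix: torch.Tensor, bits: int = 8):
    # Symmetric per-column quantization: one scale per column, int8 storage.
    if bits == 8:
        max_int = 127
    elif bits == 4:
        max_int = 7  # still stored as int8, clamped to the 4-bit range
    else:
        raise ValueError("Quantization bits must be either 8 or 4")
    abs_max = matrix.abs().amax(dim=0).clamp(min=1e-8)  # max absolute value per column
    scales = (abs_max / max_int).to(torch.float32)      # shape [cols]
    q_matrix = torch.round(matrix / scales).clamp(-max_int, max_int).to(torch.int8)
    return q_matrix, scales

def dequantize_per_column(q_matrix: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    rows, cols = q_matrix.shape
    dequant_matrix = q_matrix.to(torch.float32)
    return dequant_matrix * scales.view(1, cols)        # broadcast scales across all rows

w = torch.randn(64, 16)
q, s = quantize_per_column(w, bits=8)
print((w - dequantize_per_column(q, s)).abs().max())    # small reconstruction error
```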

@@ -370,13 +352,8 @@ class KLinearQ8(KLinearBase):
 class KLinearFP8(KLinearBase):
     # this kernel requires special handling for weight
     # Please load the weight file downloaded from KVCache.AI
-    marlin_q_w: torch.Tensor
-    marlin_s: torch.Tensor
-    g_idx: torch.Tensor
-    sort_indices: torch.Tensor
     has_bias: bool
     weight: torch.Tensor
-    scale_w: torch.Tensor
     bias: torch.Tensor
     def __init__(
         self,

@@ -13,7 +13,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "KLinearQ8"
+      generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"

 - match:

@@ -24,7 +24,7 @@
     kwargs:
       generate_device: "cpu"
       prefill_device: "cuda"
-      generate_op: "KLinearTorch"
+      generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"

 - match:

@@ -14,7 +14,7 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "KLinearQ8"
+      generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"

 - match:

@@ -23,9 +23,9 @@
   replace:
     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
     kwargs:
-      generate_device: "cpu"
+      generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "KLinearCPUInfer"
+      generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"

@@ -0,0 +1,76 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearCPUInfer"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cpu"
      prefill_device: "cuda"
      generate_op: "KLinearQ8"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # MLP module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False  # change this to True to enable long context (prefill may be slower)
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
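
Assuming this 76-line file is the `rocm/DeepSeek-V3-Chat.yaml` example referenced in the ROCm tutorial above, it is consumed via `--optimize_config_path`; the values in angle brackets are placeholders:

```bash
python ktransformers/local_chat.py \
  --model_path deepseek-ai/DeepSeek-R1 \
  --gguf_path <path_to_gguf_files> \
  --optimize_config_path ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml \
  --cpu_infer <cpu_cores + 1>
```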