diff --git a/README.md b/README.md
index 76ad6eb..63728d2 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
🔥 Updates
+* **Mar 15, 2025**: Support ROCm on AMD GPUs ([Tutorial](./doc/en/ROCm.md)).
* **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM.
* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).
diff --git a/doc/README.md b/doc/README.md
index 8bd94a0..f0683a3 100644
--- a/doc/README.md
+++ b/doc/README.md
@@ -22,6 +22,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
🔥 Updates
+* **Mar 15, 2025**: Support ROCm on AMD GPUs ([Tutorial](./en/ROCm.md)).
* **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) for DeepSeek-V3 and R1 in 24GB VRAM.
* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
* **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. The detailed tutorial is [here](./en/DeepseekR1_V3_tutorial.md).
diff --git a/doc/SUMMARY.md b/doc/SUMMARY.md
index d9fa9b8..854549c 100644
--- a/doc/SUMMARY.md
+++ b/doc/SUMMARY.md
@@ -10,6 +10,7 @@
- [Injection Tutorial](en/injection_tutorial.md)
- [Multi-GPU Tutorial](en/multi-gpu-tutorial.md)
- [Use FP8 GPU Kernel](en/fp8_kernel.md)
+- [Use AMD GPU](en/ROCm.md)
# Server
- [Server](en/api/server/server.md)
- [Website](en/api/server/website.md)
diff --git a/doc/en/ROCm.md b/doc/en/ROCm.md
new file mode 100644
index 0000000..39f4890
--- /dev/null
+++ b/doc/en/ROCm.md
@@ -0,0 +1,96 @@
+# ROCm Support for ktransformers (Beta)
+
+## Introduction
+
+### Overview
+In our effort to expand GPU architecture support beyond NVIDIA, we are excited to introduce **AMD GPU support through ROCm** in ktransformers (Beta release). This implementation was developed and tested on AMD EPYC 9274F processors and AMD Radeon RX 7900 XTX GPUs.
+
+## Installation Guide
+
+### 1. Install ROCm Driver
+Begin by installing the ROCm drivers for your AMD GPU:
+- [Official ROCm Installation Guide for Radeon GPUs](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-radeon.html)
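+
+If the driver installed correctly, the standard ROCm command-line tools should be able to see the GPU. A quick sanity check (a minimal sketch, assuming `rocminfo` and `rocm-smi` were installed along with the ROCm stack and are on your `PATH`):
+
+```bash
+# List detected GPU agents; a Radeon RX 7900 XTX shows up as gfx1100
+rocminfo | grep -i gfx
+
+# Show driver status, VRAM usage, and temperatures
+rocm-smi
+```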
+
+### 2. Set Up Conda Environment
+We recommend using Miniconda3/Anaconda3 for environment management:
+
+```bash
+# Download and install Miniconda
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+bash Miniconda3-latest-Linux-x86_64.sh
+
+# Create environment
+conda create --name ktransformers python=3.11
+conda activate ktransformers
+
+# Install required libraries
+conda install -c conda-forge libstdcxx-ng
+
+# Verify GLIBCXX version (should include 3.4.32)
+strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX
+```
+
+> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`
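+
+If you are not sure where the environment lives, you can also query the active environment directly. A minimal sketch, assuming the `ktransformers` environment is currently activated (so `$CONDA_PREFIX` points at it):
+
+```bash
+# $CONDA_PREFIX resolves to the currently activated conda environment
+strings "$CONDA_PREFIX/lib/libstdc++.so.6" | grep GLIBCXX_3.4.32
+```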
+
+### 3. Install PyTorch for ROCm
+Install PyTorch with ROCm 6.2.4 support:
+
+```bash
+pip3 install torch torchvision torchaudio \
+ --index-url https://download.pytorch.org/whl/rocm6.2.4
+pip3 install packaging ninja cpufeature numpy
+```
+
+> **Tip:** For other ROCm versions, visit [PyTorch Previous Versions](https://pytorch.org/get-started/previous-versions/)
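+
+Before building ktransformers, it is worth confirming that this PyTorch build can actually see the AMD GPU. ROCm builds of PyTorch reuse the `torch.cuda` API and set `torch.version.hip` instead of `torch.version.cuda`, so a quick check looks like this (a minimal sketch):
+
+```bash
+python -c "import torch; print(torch.version.hip, torch.cuda.is_available(), torch.cuda.get_device_name(0))"
+```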
+
+### 4. Build ktransformers
+
+```bash
+# Clone repository
+git clone https://github.com/kvcache-ai/ktransformers.git
+cd ktransformers
+git submodule update --init
+
+# Optional: Compile web interface
+# See: api/server/website.md
+
+# Install dependencies
+bash install.sh
+```
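+
+Once `install.sh` finishes, a quick smoke test (a minimal sketch) confirms that the package and its compiled extensions import cleanly in the active environment:
+
+```bash
+python -c "import ktransformers; print('ktransformers imported successfully')"
+```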
+
+## Running DeepSeek-R1 Models
+
+### Configuration for 24GB VRAM GPUs
+Use our optimized configuration for constrained VRAM:
+
+```bash
+# Replace <path_to_gguf_files> with the directory holding the GGUF weights,
+# and <cpu_core_count> with the number of CPU cores to dedicate to inference.
+python ktransformers/local_chat.py \
+    --model_path deepseek-ai/DeepSeek-R1 \
+    --gguf_path <path_to_gguf_files> \
+    --optimize_config_path ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml \
+    --cpu_infer <cpu_core_count>
+```
+
+> **Beta Note:** The current Q8 linear implementation (used in place of Marlin, which is unsupported on ROCm) is not yet fully optimized. Expect performance improvements in future releases.
+
+### Configuration for 40GB+ VRAM GPUs
+For better performance on high-VRAM GPUs:
+
+1. Modify `DeepSeek-V3-Chat.yaml` (a `sed` sketch for this replacement follows this list):
+ ```yaml
+ # Replace all instances of:
+ KLinearMarlin → KLinearTorch
+ ```
+
+2. Execute with:
+   ```bash
+   # <path_to_modified_yaml> is the YAML file edited in step 1
+   python ktransformers/local_chat.py \
+     --model_path deepseek-ai/DeepSeek-R1 \
+     --gguf_path <path_to_gguf_files> \
+     --optimize_config_path <path_to_modified_yaml> \
+     --cpu_infer <cpu_core_count>
+   ```
+> **Tip:** If you have two 24GB AMD GPUs, you can apply the same modification to `ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` and use that file instead.
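+
+As referenced in step 1, the `KLinearMarlin` → `KLinearTorch` swap can be applied in one pass with `sed`. Below is a minimal sketch; it assumes the file being edited is the standard `ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml` rule set and works on a copy (the copy path is just an example) so the shipped file stays untouched:
+
+```bash
+# Work on a copy so the original rule file is preserved
+cp ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml /tmp/DeepSeek-V3-Chat-highvram.yaml
+
+# Swap every Marlin linear kernel for the plain torch implementation
+sed -i 's/KLinearMarlin/KLinearTorch/g' /tmp/DeepSeek-V3-Chat-highvram.yaml
+```
+
+Pass the copy's path as `--optimize_config_path` in the command above.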
+
+## Known Limitations
+- Marlin operations are not supported on the ROCm platform
+- The current Q8 linear implementation shows reduced performance (Beta limitation)
diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py
index 0d3fdb3..b05cd08 100644
--- a/ktransformers/operators/linear.py
+++ b/ktransformers/operators/linear.py
@@ -187,8 +187,6 @@ class KLinearQ8(KLinearBase):
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
- group_size: int = 128, # 增大分组大小,减少量化噪声
- percentile: float = 99.99, # 新增:对异常值进行截断的百分位数
**kwargs,
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
@@ -199,8 +197,6 @@ class KLinearQ8(KLinearBase):
self.weight_zero_point = None
self.bias = None
self.loaded = False
- self.group_size = group_size
- self.percentile = percentile
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
@@ -246,16 +242,9 @@ class KLinearQ8(KLinearBase):
# For Q4, ensure the values stay within 4-bit range
if bits == 4:
q_matrix = torch.clamp(q_matrix, -7, 7)
-
- # Get matrix shape
rows, cols = q_matrix.shape
-
- # Convert to float32
dequant_matrix = q_matrix.to(torch.float32)
-
- # Create broadcasted scales: reshape scales to [1, cols] for broadcasting
scales_broadcast = scales.view(1, cols)
-
# Apply dequantization to all columns at once using matrix multiplication
dequant_matrix = dequant_matrix * scales_broadcast
@@ -285,21 +274,14 @@ class KLinearQ8(KLinearBase):
# Determine quantization parameters based on bits
if bits == 8:
- # Q8: range is -127 to 127
max_int = 127
qtype = torch.int8
elif bits == 4:
- # Q4: range is -7 to 7 (using 4-bit signed integers)
max_int = 7
- qtype = torch.int8 # We'll still use int8 storage but limit to 4-bit range
+ qtype = torch.int8 # We'll still use int8 storage but limit to 4-bit range, wait for native support
else:
raise ValueError("Quantization bits must be either 8 or 4")
-
- # Initialize results and scale factors
- q_matrix = torch.zeros_like(matrix, dtype=qtype)
- scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
-
- # Initialize scale factors
+
scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
# Calculate max absolute value for each column
@@ -370,13 +352,8 @@ class KLinearQ8(KLinearBase):
class KLinearFP8(KLinearBase):
# this kernel requires special handling for weight
# Please load the weight file downloaded from KVCache.AI
- marlin_q_w: torch.Tensor
- marlin_s: torch.Tensor
- g_idx: torch.Tensor
- sort_indices: torch.Tensor
has_bias: bool
weight: torch.Tensor
- scale_w: torch.Tensor
bias: torch.Tensor
def __init__(
self,
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
index 3458090..c20973d 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
@@ -13,7 +13,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearQ8"
+ generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
@@ -24,7 +24,7 @@
kwargs:
generate_device: "cpu"
prefill_device: "cuda"
- generate_op: "KLinearTorch"
+ generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
index a92e988..d28e016 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
@@ -14,7 +14,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearQ8"
+ generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
@@ -23,9 +23,9 @@
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
- generate_device: "cpu"
+ generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearCPUInfer"
+ generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
diff --git a/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml
new file mode 100644
index 0000000..628a952
--- /dev/null
+++ b/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml
@@ -0,0 +1,76 @@
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+ replace:
+ class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+
+- match:
+ name: "^lm_head$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+ generate_op: "KLinearCPUInfer"
+ prefill_op: "KLinearTorch"
+
+- match:
+ name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
+ class: torch.nn.Linear # only match modules matching name and class simultaneously
+ replace:
+ class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cuda"
+ generate_op: "KLinearQ8"
+ prefill_op: "KLinearTorch"
+- match:
+ name: "^model\\.layers\\..*\\.mlp$"
+ class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+ replace:
+ class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+- match:
+ class: ktransformers.models.modeling_deepseek_v3.MoEGate
+ replace:
+ class: ktransformers.operators.gate.KMoEGate
+ kwargs:
+ generate_device: "cuda:0"
+ prefill_device: "cuda:0"
+- match:
+ name: "^model\\.layers\\..*\\.mlp\\.experts$"
+ replace:
+    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism
+ kwargs:
+ prefill_device: "cuda"
+ prefill_op: "KExpertsTorch"
+ generate_device: "cpu"
+ generate_op: "KExpertsCPU"
+ out_device: "cuda"
+ recursive: False # don't recursively inject submodules of this module
+- match:
+ name: "^model\\.layers\\..*\\.self_attn$"
+ replace:
+ class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+ kwargs:
+ generate_device: "cuda"
+ prefill_device: "cuda"
+      absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
+- match:
+ name: "^model$"
+ replace:
+ class: "ktransformers.operators.models.KDeepseekV2Model"
+ kwargs:
+      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
+- match:
+ name: "^model.embed_tokens"
+ replace:
+ class: "default"
+ kwargs:
+ generate_device: "cpu"
+ prefill_device: "cpu"
\ No newline at end of file