mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 12:09:48 +00:00
⚡ ready to publish
This commit is contained in:
parent
f892d22849
commit
83401dbb3b
6 changed files with 157 additions and 19 deletions
|
@ -31,7 +31,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
|
|||
* **Aug 12, 2024**: Support multiple GPU; Support new model: mixtral 8\*7B and 8\*22B; Support q2k, q3k, q5k dequant on gpu.
|
||||
* **Aug 9, 2024**: Support windows native.
|
||||
|
||||
<h2 id="show-cases">🔥 Show Cases</h2>
|
||||
<h2 id="show-cases">🌟 Show Cases</h2>
|
||||
|
||||
<div>
|
||||
<h3>GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM</h3>
|
||||
|
|
|
@ -1,7 +1,14 @@
|
|||
# GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM
|
||||
# SUMMARY
|
||||
|
||||
- **Fed 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~64x speedup.<br>
|
||||
> **Fed 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~64x speedup.<br>
|
||||
|
||||
Hi, we're the KTransformers team (formerly known for our local CPU/GPU hybrid inference open source project with DeepSeek-V2).
|
||||
|
||||
We've heard your requests for DeepSeek-R1/V3 support—and we're excited to finally deliver!
|
||||
Apologies for the wait, but we've been cooking up something truly amazing!
|
||||
|
||||
Today, we're proud to announce that we not only support DeepSeek-R1/V3, as showcased in the video below:
|
||||
|
||||
https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
|
||||
|
||||
|
@ -14,9 +21,10 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
|
|||
- Decode Speed (tokens/s):
|
||||
- KTransfermor: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only)
|
||||
- Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**.
|
||||
- Upcoming Open Source Release:
|
||||
- AMX optimizations and selective expert activation will be open-sourced in V0.3.
|
||||
- Currently available only in preview binary distribution, which can be found [here](xxx).
|
||||
|
||||
|
||||
But we're also previewing our upcoming optimizations, including an Intel AMX-accelerated kernel and a selective expert activation method, which will significantly enhance performance. With V0.3-preview, we achieve up to 286 tokens/s for prefill, making it up to **64× faster than llama.cpp** for local inference.
|
||||
The binary distribution is available now and the source code will come ASAP! Check out the details [here](xxx)
|
||||
|
||||
|
||||
## Prerequisites
|
||||
|
@ -98,11 +106,32 @@ python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path
|
|||
<when you see chat, then press enter to load the text prompt_file>
|
||||
```
|
||||
The parameters' meaning is the same. But As we use dual socket, we set cpu_infer to 65
|
||||
|
||||
### V0.3 Showcase
|
||||
#### Dual socket version (64 cores)
|
||||
Our local_chat test command is:
|
||||
``` shell
|
||||
python -m ktransformers.local_chat --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 65 --cache_lens 1536
|
||||
<when you see chat, then press enter to load the text prompt_file>
|
||||
```
|
||||
The parameters' meaning is the same with V0.2. But As we use dual socket, we set cpu_infer to 65
|
||||
|
||||
## Some Explanations
|
||||
1. Also we want to make further use of our two NUMA nodes on Xeon Gold cpu.
|
||||
To avoid the cost of data transfer between nodes, we "copy" the critical matrix on
|
||||
both nodes which takes more memory consumption but accelerates the prefill and decoding process.
|
||||
But this method takes huge memory and slow when loading weights, So be patient when loading
|
||||
and monitor the memory usage. (we are considering to make this method as an option). We are going to optimize this huge memory overhead. Stay tuned~ <br>
|
||||
and monitor the memory usage. We are going to optimize this huge memory overhead. Stay tuned~ <br>
|
||||
2. The command args `--cpu_infer 65` specifies how many cores to use (it's ok that it exceeds the physical number,
|
||||
but it's not the more the better. Adjust it slightly lower to your actual number of cores)<br>
|
||||
|
||||
3. Why CPU/GPU Hybrid Inference?
|
||||
DeepSeek's MLA operators are highly computationally intensive. While running everything on CPU is possible, offloading the heavy computations to the GPU results in a massive performance boost.
|
||||
|
||||
4. Where Does the Speedup Come From?
|
||||
|
||||
- Expert Offload: Unlike traditional layer-based or KVCache offloading (as seen in llama.cpp), we offload the expert computation to the CPU and MLA/KVCache to GPU, aligning perfectly with DeepSeek’s architecture for optimal efficiency.
|
||||
- Intel AMX Optimization – Our AMX-accelerated kernel is meticulously tuned, running several times faster than existing llama.cpp implementations. We plan to open-source this kernel after cleansing and are considering upstream contributions to llama.cpp.
|
||||
|
||||
5. Why Intel CPUs?
|
||||
Intel is currently the only CPU vendor that supports AMX-like instructions, which delivers significantly better performance compared to AVX-only alternatives.
|
|
@ -18,6 +18,9 @@ from ktransformers.models.modeling_deepseek_v3 import (
|
|||
from ktransformers.models.modeling_deepseek import (
|
||||
DeepseekV2YarnRotaryEmbedding,
|
||||
DeepseekV2RotaryEmbedding,
|
||||
yarn_get_mscale,
|
||||
yarn_linear_ramp_mask,
|
||||
yarn_find_correction_range
|
||||
)
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.util.custom_gguf import GGUFLoader
|
||||
|
@ -188,7 +191,33 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
|
|||
self.orig_module.mscale_all_dim,
|
||||
)
|
||||
|
||||
class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):
|
||||
# class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):
|
||||
# def __init__(
|
||||
# self,
|
||||
# key: str,
|
||||
# gguf_loader: GGUFLoader,
|
||||
# config: PretrainedConfig,
|
||||
# orig_module: nn.Module,
|
||||
# # device: str = "cuda",
|
||||
# generate_device: str = "cuda",
|
||||
# prefill_device: str = "cuda",
|
||||
# **kwargs,
|
||||
# ):
|
||||
# BaseInjectedModule.__init__(
|
||||
# self, key, gguf_loader, config, orig_module, generate_device, **kwargs
|
||||
# )
|
||||
# self.generate_device = generate_device
|
||||
# self.prefill_device = prefill_device
|
||||
|
||||
# def load(self):
|
||||
# # TODO support perlayer prefill
|
||||
# self.orig_module.__init__(
|
||||
# self.config,
|
||||
# device=self.generate_device
|
||||
# )
|
||||
# return
|
||||
|
||||
class YarnRotaryEmbeddingV3(BaseInjectedModule):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
|
@ -207,12 +236,92 @@ class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbeddin
|
|||
self.prefill_device = prefill_device
|
||||
|
||||
def load(self):
|
||||
# TODO support perlayer prefill
|
||||
self.orig_module.__init__(
|
||||
self.config,
|
||||
device=self.generate_device
|
||||
kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in self.config.rope_scaling
|
||||
}
|
||||
self._init(
|
||||
dim=self.config.qk_rope_head_dim,
|
||||
max_position_embeddings=self.config.max_position_embeddings,
|
||||
base=self.config.rope_theta,
|
||||
device=self.device,
|
||||
scaling_factor=self.config.rope_scaling["factor"],
|
||||
**kwargs,
|
||||
)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 since bfloat16 loses precision on long contexts
|
||||
# See https://github.com/huggingface/transformers/pull/29285
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()* self._mscale
|
||||
sin = emb.sin()* self._mscale
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
def _init(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
original_max_position_embeddings=4096,
|
||||
beta_fast=32,
|
||||
beta_slow=1,
|
||||
mscale=1,
|
||||
mscale_all_dim=0,
|
||||
):
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
self.mscale = mscale
|
||||
self.mscale_all_dim = mscale_all_dim
|
||||
self.scaling_factor = scaling_factor
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
freq_extra = 1.0 / (
|
||||
self.base
|
||||
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
||||
)
|
||||
freq_inter = 1.0 / (
|
||||
self.scaling_factor
|
||||
* self.base
|
||||
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
||||
)
|
||||
|
||||
low, high = yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.original_max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
|
||||
device=device, dtype=torch.float32
|
||||
)
|
||||
self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self._mscale = float(
|
||||
yarn_get_mscale(self.scaling_factor, self.mscale)
|
||||
/ yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
|
||||
)
|
||||
# For BC we register cos and sin cached
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
|
||||
class DynamicNTKScalingRotaryEmbedding(
|
||||
BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
|
@ -18,7 +18,7 @@
|
|||
name: "^model\\.layers\\.([3456][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
|
@ -18,7 +18,7 @@
|
|||
name: "^model\\.layers\\.([3456][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
- match:
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
|
Loading…
Add table
Reference in a new issue