Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 23:34:35 +00:00)
support AMX

commit f3d842a0ca (parent b90362b5e6)
15 changed files with 1799 additions and 62 deletions
.clang-format (new file, 5 lines)
@@ -0,0 +1,5 @@
---
BasedOnStyle: LLVM
ColumnLimit: 120 # set the maximum line width to 120
IndentWidth: 2
---
@@ -293,9 +293,10 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
 aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
+aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)

-set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5})
+set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6})

 file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h")
csrc/ktransformers_ext/bench/bench_moe_amx.py (new file, 107 lines)
@@ -0,0 +1,107 @@
#!/usr/bin/env python
# coding=utf-8
'''
Description  :
Author       : chenht2022
Date         : 2025-04-25 18:28:12
Version      : 1.0.0
LastEditors  : chenht2022
LastEditTime : 2025-04-25 18:28:12
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import torch

expert_num = 8
hidden_size = 7168
intermediate_size = 2048
max_len = 25600
n_routed_experts = 8
layer_num = 10
qlen = 1024
CPUInfer = cpuinfer_ext.CPUInfer(65)
warm_up_iter = 100
test_iter = 100

def bench_moe(quant_mode: str):
    with torch.inference_mode(mode=True):
        if quant_mode == "bf16":
            bytes_per_elem = 2.000000
        elif quant_mode == "int8":
            bytes_per_elem = 1.000000
        else:
            assert(False)

        moes = []
        gate_projs = []
        up_projs = []
        down_projs = []
        for _ in range(layer_num):
            gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
            up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
            down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
            config = cpuinfer_ext.moe.AMX_MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr())
            if quant_mode == "bf16":
                moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
                CPUInfer.submit(moe.load_weights())
                CPUInfer.sync()
            elif quant_mode == "int8":
                moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
                CPUInfer.submit(moe.load_weights())
                CPUInfer.sync()
            gate_projs.append(gate_proj)
            up_projs.append(up_proj)
            down_projs.append(down_proj)
            moes.append(moe)
        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = "cuda")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to("cpu").contiguous()
        weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()
        output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()
        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)

        # warm up
        for i in range(warm_up_iter):
            CPUInfer.submit(
                moes[i % layer_num].forward(
                    qlen,
                    n_routed_experts,
                    expert_ids[i % layer_num].data_ptr(),
                    weights[i % layer_num].data_ptr(),
                    input[i % layer_num].data_ptr(),
                    output[i % layer_num].data_ptr(),
                    qlen_tensor.data_ptr()
                )
            )
            CPUInfer.sync()

        # test
        start = time.perf_counter()
        for i in range(test_iter):
            CPUInfer.submit(
                moes[i % layer_num].forward(
                    qlen,
                    n_routed_experts,
                    expert_ids[i % layer_num].data_ptr(),
                    weights[i % layer_num].data_ptr(),
                    input[i % layer_num].data_ptr(),
                    output[i % layer_num].data_ptr(),
                    qlen_tensor.data_ptr()
                )
            )
            CPUInfer.sync()
        end = time.perf_counter()
        total_time = end - start
        print('Quant mode: ', quant_mode)
        print('Time(s): ', total_time)
        print('Iteration: ', test_iter)
        print('Time(us) per iteration: ', total_time / test_iter * 1000000)
        print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
        print('Flops: ', hidden_size * intermediate_size * qlen * 3 * n_routed_experts * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GFLOPS')
        print('')

bench_moe("bf16")
bench_moe("int8")
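For reference, the bandwidth and FLOPS figures printed above follow directly from the benchmark parameters; below is a minimal back-of-the-envelope sketch (not part of the commit; total_time is a stand-in for the measured value, and the factor 3 counts the gate, up and down projections).

# Back-of-the-envelope model of the numbers printed by bench_moe().
# Uses the same parameters as the benchmark above; `total_time` is hypothetical.
hidden_size = 7168
intermediate_size = 2048
n_routed_experts = 8
qlen = 1024
test_iter = 100
bytes_per_elem = 2.0   # bf16; 1.0 for int8

total_time = 1.0       # seconds, stand-in for the measured value

# Three weight matrices (gate, up, down) are streamed per routed expert and iteration.
weight_bytes = hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem
bandwidth_gb_s = weight_bytes * test_iter / total_time / 1e9

# Each weight element contributes one multiply-add (2 flops) per token.
flops_per_iter = hidden_size * intermediate_size * 3 * n_routed_experts * qlen * 2
gflops = flops_per_iter * test_iter / total_time / 1e9

print(bandwidth_gb_s, gflops)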
@@ -30,7 +30,8 @@ void SharedMemBuffer::alloc(void* object, std::vector<std::pair<void**, uint64_t
   if (buffer_) {
     free(buffer_);
   }
-  buffer_ = malloc(size);
+  buffer_ = std::aligned_alloc(64, size);
+
   size_ = size;
   for (auto& obj_requests : hist_requests_) {
     for (auto& requests : obj_requests.second) {
csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.h (new file, 37 lines)
@@ -0,0 +1,37 @@
/**
 * @Description  :
 * @Author       : chenht2022
 * @Date         : 2024-08-05 04:49:08
 * @Version      : 1.0.0
 * @LastEditors  : chenht2022
 * @LastEditTime : 2024-08-05 06:36:41
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/

#ifndef CPUINFER_SHAREDMEMBUFFER_H
#define CPUINFER_SHAREDMEMBUFFER_H

#include <cstdint>
#include <cstdlib>
#include <map>
#include <vector>

class SharedMemBuffer {
 public:
  SharedMemBuffer();
  ~SharedMemBuffer();

  void alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests);
  void dealloc(void* object);

 private:
  void* buffer_;
  uint64_t size_;
  std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> hist_requests_;

  void arrange(std::vector<std::pair<void**, uint64_t>> requests);
};

static SharedMemBuffer shared_mem_buffer;

#endif
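The header only declares the interface; as a rough illustration of the idea behind alloc/arrange (collect every (pointer, size) request from an object and lay all of them out in one contiguous allocation), here is a hedged Python sketch. The per-request 64-byte rounding is an illustrative assumption, not taken from the implementation.

# Minimal sketch of laying out several (name, size) requests in one buffer,
# mirroring the idea behind SharedMemBuffer::alloc/arrange. The 64-byte
# rounding of each request is an illustrative assumption.
def arrange(requests, alignment=64):
    offsets, cursor = {}, 0
    for name, size in requests:
        cursor = (cursor + alignment - 1) // alignment * alignment
        offsets[name] = cursor
        cursor += size
    return offsets, cursor  # per-request offsets and total bytes to allocate

offsets, total = arrange([("input", 1024), ("gate_out", 4096), ("up_out", 4096)])
print(total, offsets)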
@@ -17,6 +17,7 @@
 #include "operators/llamafile/linear.h"
 #include "operators/llamafile/mlp.h"
 #include "operators/llamafile/moe.h"
+#include "operators/amx/moe.hpp"
 #include "pybind11/functional.h"
 #include "pybind11/operators.h"
 #include "pybind11/pybind11.h"
@ -563,6 +564,75 @@ class MOEBindings {
|
|||
};
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class AMX_MOEBindings {
|
||||
public:
|
||||
class WarmUpBindings {
|
||||
public:
|
||||
struct Args {
|
||||
CPUInfer *cpuinfer;
|
||||
AMX_MOE<T> *moe;
|
||||
};
|
||||
static void inner(void *args) {
|
||||
Args *args_ = (Args *)args;
|
||||
args_->cpuinfer->enqueue(&AMX_MOE<T>::warm_up, args_->moe);
|
||||
}
|
||||
static std::pair<intptr_t, intptr_t> cpuinfer_interface(AMX_MOE<T> &moe) {
|
||||
Args *args = new Args{nullptr, &moe};
|
||||
return std::make_pair((intptr_t)&inner, (intptr_t)args);
|
||||
}
|
||||
};
|
||||
class LoadWeightsBindings {
|
||||
public:
|
||||
struct Args {
|
||||
CPUInfer *cpuinfer;
|
||||
AMX_MOE<T> *moe;
|
||||
};
|
||||
static void inner(void *args) {
|
||||
Args *args_ = (Args *)args;
|
||||
args_->cpuinfer->enqueue(&AMX_MOE<T>::load_weights, args_->moe);
|
||||
}
|
||||
static std::pair<intptr_t, intptr_t> cpuinfer_interface(AMX_MOE<T> &moe) {
|
||||
Args *args = new Args{nullptr, &moe};
|
||||
return std::make_pair((intptr_t)&inner, (intptr_t)args);
|
||||
}
|
||||
};
|
||||
class ForwardBindings {
|
||||
public:
|
||||
struct Args {
|
||||
CPUInfer *cpuinfer;
|
||||
AMX_MOE<T> *moe;
|
||||
int qlen;
|
||||
int k;
|
||||
const uint64_t *expert_ids;
|
||||
const float *weights;
|
||||
const void *input;
|
||||
void *output;
|
||||
int *batch_size_tensor;
|
||||
};
|
||||
static void inner(void *args) {
|
||||
Args *args_ = (Args *)args;
|
||||
args_->cpuinfer->enqueue(
|
||||
&AMX_MOE<T>::forward, args_->moe, args_->qlen, args_->k,
|
||||
args_->expert_ids, args_->weights, args_->input, args_->output, args_->batch_size_tensor);
|
||||
}
|
||||
static std::pair<intptr_t, intptr_t>
|
||||
cpuinfer_interface(AMX_MOE<T> &moe, int qlen, int k, intptr_t expert_ids,
|
||||
intptr_t weights, intptr_t input, intptr_t output, intptr_t batch_size_tensor) {
|
||||
Args *args = new Args{nullptr,
|
||||
&moe,
|
||||
qlen,
|
||||
k,
|
||||
(const uint64_t *)expert_ids,
|
||||
(const float *)weights,
|
||||
(const void *)input,
|
||||
(void *)output,
|
||||
(int *)batch_size_tensor};
|
||||
return std::make_pair((intptr_t)&inner, (intptr_t)args);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
PYBIND11_MODULE(cpuinfer_ext, m) {
|
||||
py::class_<CPUInfer>(m, "CPUInfer")
|
||||
.def(py::init<int>())
|
||||
|
@@ -621,6 +691,27 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
       .def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
       .def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);

+  py::class_<AMX_MOEConfig>(moe_module, "AMX_MOEConfig")
+      .def(py::init([](int expert_num, int routed_expert_num, int hidden_size,
+                       int intermediate_size,
+                       int max_len, intptr_t gate_proj,
+                       intptr_t up_proj, intptr_t down_proj) {
+        return AMX_MOEConfig(expert_num, routed_expert_num, hidden_size,
+                             intermediate_size,
+                             max_len, (void *)gate_proj,
+                             (void *)up_proj, (void *)down_proj);
+      }));
+  py::class_<AMX_MOE<amx::GemmKernel224BF>>(moe_module, "AMXBF16_MOE")
+      .def(py::init<AMX_MOEConfig>())
+      .def("warm_up", &AMX_MOEBindings<amx::GemmKernel224BF>::WarmUpBindings::cpuinfer_interface)
+      .def("load_weights", &AMX_MOEBindings<amx::GemmKernel224BF>::LoadWeightsBindings::cpuinfer_interface)
+      .def("forward", &AMX_MOEBindings<amx::GemmKernel224BF>::ForwardBindings::cpuinfer_interface);
+  py::class_<AMX_MOE<amx::GemmKernel224Int8>>(moe_module, "AMXInt8_MOE")
+      .def(py::init<AMX_MOEConfig>())
+      .def("warm_up", &AMX_MOEBindings<amx::GemmKernel224Int8>::WarmUpBindings::cpuinfer_interface)
+      .def("load_weights", &AMX_MOEBindings<amx::GemmKernel224Int8>::LoadWeightsBindings::cpuinfer_interface)
+      .def("forward", &AMX_MOEBindings<amx::GemmKernel224Int8>::ForwardBindings::cpuinfer_interface);

   auto kvcache_module = m.def_submodule("kvcache");

   py::enum_<AnchorType>(kvcache_module, "AnchorType")
csrc/ktransformers_ext/operators/amx/la/amx.hpp (new file, 974 lines)
|
@ -0,0 +1,974 @@
|
|||
/**
|
||||
* @Description :
|
||||
* @Author : chenht2022
|
||||
* @Date : 2025-04-25 18:28:12
|
||||
* @Version : 1.0.0
|
||||
* @LastEditors : chenht2022
|
||||
* @LastEditTime : 2025-04-25 18:28:12
|
||||
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
**/
|
||||
|
||||
#pragma once
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <immintrin.h>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include <stdlib.h>
|
||||
#include <sys/syscall.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "utils.hpp"
|
||||
#include <memory>
|
||||
|
||||
#if (defined(_WIN32) || defined(_WIN64))
|
||||
#define RESTRICT __restrict
|
||||
#else
|
||||
#define RESTRICT __restrict__
|
||||
#endif
|
||||
|
||||
#if (defined(_WIN32) || defined(_WIN64))
|
||||
#define ALWAYS_INLINE __forceinline
|
||||
#elif __has_attribute(always_inline) || defined(__GNUC__)
|
||||
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
|
||||
#else
|
||||
#define ALWAYS_INLINE inline
|
||||
#endif
|
||||
|
||||
namespace amx {
|
||||
|
||||
#define ARCH_GET_XCOMP_PERM 0x1022
|
||||
#define ARCH_REQ_XCOMP_PERM 0x1023
|
||||
#define XFEATURE_XTILECFG 17
|
||||
#define XFEATURE_XTILEDATA 18
|
||||
|
||||
const int TMMCount = 8;
|
||||
const int MaxTileHeight = 16;
|
||||
const int MaxTileWidth = 64;
|
||||
|
||||
const int AMX_BLK_SIZE = 32;
|
||||
|
||||
#define TMM0 0
|
||||
#define TMM1 1
|
||||
#define TMM2 2
|
||||
#define TMM3 3
|
||||
#define TMM4 4
|
||||
#define TMM5 5
|
||||
#define TMM6 6
|
||||
#define TMM7 7
|
||||
|
||||
inline bool enable_amx() {
|
||||
static thread_local bool initialized = false;
|
||||
if (initialized) {
|
||||
return true;
|
||||
}
|
||||
initialized = true;
|
||||
|
||||
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
|
||||
printf("\n Failed to enable XFEATURE_XTILEDATA \n\n");
|
||||
return false;
|
||||
} else {
|
||||
// printf("\n TILE DATA USE SET - OK \n\n");
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
}
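enable_amx() asks the kernel for the XTILEDATA component once per thread. If you want to probe from Python whether the running kernel will grant AMX tile state before launching the benchmark above, a small ctypes sketch along these lines can work (assumes x86-64 Linux, where arch_prctl is syscall 158; the constants mirror the ones defined above).

# Probe ARCH_REQ_XCOMP_PERM / XFEATURE_XTILEDATA from Python via the raw syscall.
# Assumes x86-64 Linux; syscall number 158 is arch_prctl there.
import ctypes

ARCH_REQ_XCOMP_PERM = 0x1023
XFEATURE_XTILEDATA = 18
SYS_arch_prctl = 158

libc = ctypes.CDLL(None, use_errno=True)
ret = libc.syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)
print("AMX tile data permission granted" if ret == 0 else "AMX not available")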
|
||||
|
||||
struct alignas(64) TileConfig {
|
||||
uint8_t palette;
|
||||
uint8_t start_row;
|
||||
std::array<uint8_t, 14> __0 = {};
|
||||
std::array<uint16_t, 8> colsb;
|
||||
std::array<uint8_t, 16> __1 = {};
|
||||
std::array<uint8_t, 8> rows;
|
||||
std::array<uint8_t, 8> __2 = {};
|
||||
|
||||
TileConfig() {
|
||||
palette = 1;
|
||||
start_row = 0;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
set_row_col(i, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void set_row_col(int i, uint8_t row, uint16_t col) {
|
||||
colsb[i] = col;
|
||||
rows[i] = row;
|
||||
}
|
||||
|
||||
void set_config() { _tile_loadconfig(this); }
|
||||
|
||||
static void load_data(int to, void *from, size_t stride) {
|
||||
switch (to) {
|
||||
case 0:
|
||||
_tile_loadd(0, from, stride);
|
||||
break;
|
||||
case 1:
|
||||
_tile_loadd(1, from, stride);
|
||||
break;
|
||||
case 2:
|
||||
_tile_loadd(2, from, stride);
|
||||
break;
|
||||
case 3:
|
||||
_tile_loadd(3, from, stride);
|
||||
break;
|
||||
case 4:
|
||||
_tile_loadd(4, from, stride);
|
||||
break;
|
||||
case 5:
|
||||
_tile_loadd(5, from, stride);
|
||||
break;
|
||||
case 6:
|
||||
_tile_loadd(6, from, stride);
|
||||
break;
|
||||
case 7:
|
||||
_tile_loadd(7, from, stride);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("no such tile");
|
||||
}
|
||||
}
|
||||
|
||||
static void store_data(int from, void *to, size_t stride) {
|
||||
switch (from) {
|
||||
case 0:
|
||||
_tile_stored(0, to, stride);
|
||||
break;
|
||||
case 1:
|
||||
_tile_stored(1, to, stride);
|
||||
break;
|
||||
case 2:
|
||||
_tile_stored(2, to, stride);
|
||||
break;
|
||||
case 3:
|
||||
_tile_stored(3, to, stride);
|
||||
break;
|
||||
case 4:
|
||||
_tile_stored(4, to, stride);
|
||||
break;
|
||||
case 5:
|
||||
_tile_stored(5, to, stride);
|
||||
break;
|
||||
case 6:
|
||||
_tile_stored(6, to, stride);
|
||||
break;
|
||||
case 7:
|
||||
_tile_stored(7, to, stride);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("no such tile");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static_assert(sizeof(TileConfig) == 64);
|
||||
|
||||
inline void debug_tile(int t) {
|
||||
printf("Tile %d\n", t);
|
||||
uint8_t data[16][64] = {};
|
||||
TileConfig::store_data(t, data, 64);
|
||||
for (int i = 0; i < 16; i++) {
|
||||
for (int j = 0; j < 64; j++) {
|
||||
printf("%3d ", data[i][j]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
inline void debug_tiles(int to = 8) {
|
||||
for (int i = 0; i < to; i++) {
|
||||
debug_tile(i);
|
||||
}
|
||||
}
|
||||
|
||||
inline void debug_m512(__m512 x) {
|
||||
float data[16];
|
||||
_mm512_storeu_ps(data, x);
|
||||
for (int i = 0; i < 16; i++) {
|
||||
printf("%f ", data[i]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
// transpose utils
|
||||
inline void transpose_16x16_32bit(__m512i *v) {
|
||||
__m512i v1[16];
|
||||
v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
|
||||
v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
|
||||
v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
|
||||
v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
|
||||
v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
|
||||
v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
|
||||
v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
|
||||
v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
|
||||
v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
|
||||
v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
|
||||
v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
|
||||
v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
|
||||
v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
|
||||
v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
|
||||
v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
|
||||
v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
|
||||
|
||||
v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
|
||||
v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
|
||||
v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
|
||||
v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
|
||||
v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
|
||||
v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
|
||||
v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
|
||||
v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
|
||||
v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
|
||||
v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
|
||||
v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
|
||||
v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
|
||||
v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
|
||||
v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
|
||||
v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
|
||||
v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
|
||||
|
||||
v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
|
||||
v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
|
||||
v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
|
||||
v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
|
||||
v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
|
||||
v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
|
||||
v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
|
||||
v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
|
||||
v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
|
||||
v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
|
||||
v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
|
||||
v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
|
||||
v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
|
||||
v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
|
||||
v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
|
||||
v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
|
||||
|
||||
v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
|
||||
v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
|
||||
v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
|
||||
v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
|
||||
v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
|
||||
v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
|
||||
v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
|
||||
v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
|
||||
v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
|
||||
v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
|
||||
v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
|
||||
v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
|
||||
v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
|
||||
v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
|
||||
v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
|
||||
v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
|
||||
}
|
||||
|
||||
/*
|
||||
Transpose 16x16 32-bit elements
|
||||
Note that v must be 64 byte aligned
|
||||
*/
|
||||
inline void transpose_16x16_32bit(__m512i *v, size_t stride) {
|
||||
assert(reinterpret_cast<intptr_t>(v) % 64 == 0 && "v must be 64 aligned");
|
||||
|
||||
auto stride_v = [=](int i) { return offset_pointer(v, i * stride); };
|
||||
__m512i v1[16];
|
||||
|
||||
v1[0] = _mm512_unpacklo_epi32(*stride_v(0), *stride_v(1));
|
||||
v1[1] = _mm512_unpackhi_epi32(*stride_v(0), *stride_v(1));
|
||||
v1[2] = _mm512_unpacklo_epi32(*stride_v(2), *stride_v(3));
|
||||
v1[3] = _mm512_unpackhi_epi32(*stride_v(2), *stride_v(3));
|
||||
v1[4] = _mm512_unpacklo_epi32(*stride_v(4), *stride_v(5));
|
||||
v1[5] = _mm512_unpackhi_epi32(*stride_v(4), *stride_v(5));
|
||||
v1[6] = _mm512_unpacklo_epi32(*stride_v(6), *stride_v(7));
|
||||
v1[7] = _mm512_unpackhi_epi32(*stride_v(6), *stride_v(7));
|
||||
v1[8] = _mm512_unpacklo_epi32(*stride_v(8), *stride_v(9));
|
||||
v1[9] = _mm512_unpackhi_epi32(*stride_v(8), *stride_v(9));
|
||||
v1[10] = _mm512_unpacklo_epi32(*stride_v(10), *stride_v(11));
|
||||
v1[11] = _mm512_unpackhi_epi32(*stride_v(10), *stride_v(11));
|
||||
v1[12] = _mm512_unpacklo_epi32(*stride_v(12), *stride_v(13));
|
||||
v1[13] = _mm512_unpackhi_epi32(*stride_v(12), *stride_v(13));
|
||||
v1[14] = _mm512_unpacklo_epi32(*stride_v(14), *stride_v(15));
|
||||
v1[15] = _mm512_unpackhi_epi32(*stride_v(14), *stride_v(15));
|
||||
|
||||
*stride_v(0) = _mm512_unpacklo_epi64(v1[0], v1[2]);
|
||||
*stride_v(1) = _mm512_unpackhi_epi64(v1[0], v1[2]);
|
||||
*stride_v(2) = _mm512_unpacklo_epi64(v1[1], v1[3]);
|
||||
*stride_v(3) = _mm512_unpackhi_epi64(v1[1], v1[3]);
|
||||
*stride_v(4) = _mm512_unpacklo_epi64(v1[4], v1[6]);
|
||||
*stride_v(5) = _mm512_unpackhi_epi64(v1[4], v1[6]);
|
||||
*stride_v(6) = _mm512_unpacklo_epi64(v1[5], v1[7]);
|
||||
*stride_v(7) = _mm512_unpackhi_epi64(v1[5], v1[7]);
|
||||
*stride_v(8) = _mm512_unpacklo_epi64(v1[8], v1[10]);
|
||||
*stride_v(9) = _mm512_unpackhi_epi64(v1[8], v1[10]);
|
||||
*stride_v(10) = _mm512_unpacklo_epi64(v1[9], v1[11]);
|
||||
*stride_v(11) = _mm512_unpackhi_epi64(v1[9], v1[11]);
|
||||
*stride_v(12) = _mm512_unpacklo_epi64(v1[12], v1[14]);
|
||||
*stride_v(13) = _mm512_unpackhi_epi64(v1[12], v1[14]);
|
||||
*stride_v(14) = _mm512_unpacklo_epi64(v1[13], v1[15]);
|
||||
*stride_v(15) = _mm512_unpackhi_epi64(v1[13], v1[15]);
|
||||
|
||||
v1[0] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0x88);
|
||||
v1[1] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0x88);
|
||||
v1[2] = _mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0x88);
|
||||
v1[3] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0x88);
|
||||
v1[4] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0xdd);
|
||||
v1[5] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0xdd);
|
||||
v1[6] = _mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0xdd);
|
||||
v1[7] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0xdd);
|
||||
v1[8] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0x88);
|
||||
v1[9] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0x88);
|
||||
v1[10] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0x88);
|
||||
v1[11] = _mm512_shuffle_i32x4(*stride_v(11), *stride_v(15), 0x88);
|
||||
v1[12] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0xdd);
|
||||
v1[13] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0xdd);
|
||||
v1[14] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0xdd);
|
||||
v1[15] = _mm512_shuffle_i32x4(*stride_v(11), *stride_v(15), 0xdd);
|
||||
|
||||
*stride_v(0) = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
|
||||
*stride_v(1) = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
|
||||
*stride_v(2) = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
|
||||
*stride_v(3) = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
|
||||
*stride_v(4) = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
|
||||
*stride_v(5) = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
|
||||
*stride_v(6) = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
|
||||
*stride_v(7) = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
|
||||
*stride_v(8) = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
|
||||
*stride_v(9) = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
|
||||
*stride_v(10) = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
|
||||
*stride_v(11) = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
|
||||
*stride_v(12) = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
|
||||
*stride_v(13) = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
|
||||
*stride_v(14) = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
|
||||
*stride_v(15) = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
|
||||
}
|
||||
|
||||
struct GemmKernel224BF {
|
||||
using dt = ggml_bf16_t;
|
||||
using output_t = float;
|
||||
static const int TILE_M = 16;
|
||||
static const int TILE_K = 32;
|
||||
static const int TILE_N = 16;
|
||||
static const int VNNI_BLK = 2;
|
||||
|
||||
static const int M_STEP = TILE_M * 2;
|
||||
static const int N_STEP = TILE_N * 2;
|
||||
static const int K_STEP = TILE_K;
|
||||
|
||||
static inline const int N_BLOCK = 256;
|
||||
static inline const int K_BLOCK = 1792;
|
||||
|
||||
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
|
||||
|
||||
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
|
||||
int n_start = N_BLOCK * ith;
|
||||
int n_end = std::min(n, N_BLOCK * (ith + 1));
|
||||
return {n_start, n_end};
|
||||
}
|
||||
|
||||
static void config() {
|
||||
enable_amx();
|
||||
TileConfig tile_config;
|
||||
|
||||
// size is 16 x 32
|
||||
for (int i = 0; i < 2; i++)
|
||||
tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));
|
||||
|
||||
// size is 16 x 32
|
||||
for (int i = 2; i < 4; i++)
|
||||
tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));
|
||||
|
||||
// size is 16 x 16
|
||||
for (int i = 4; i < 8; i++)
|
||||
tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));
|
||||
|
||||
tile_config.set_config();
|
||||
}
|
||||
|
||||
static void load_a(dt *a, size_t lda) {
|
||||
_tile_loadd(0, a, lda);
|
||||
_tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);
|
||||
}
|
||||
|
||||
static void load_b(dt *b, size_t ldb) {
|
||||
_tile_loadd(2, b, ldb);
|
||||
_tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);
|
||||
}
|
||||
|
||||
static void clean_c() {
|
||||
_tile_zero(4);
|
||||
_tile_zero(5);
|
||||
_tile_zero(6);
|
||||
_tile_zero(7);
|
||||
}
|
||||
|
||||
static void load_c(output_t *c, size_t ldc) {
|
||||
_tile_loadd(4, c, ldc);
|
||||
_tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);
|
||||
_tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);
|
||||
_tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);
|
||||
}
|
||||
|
||||
static void store_c(output_t *c, size_t ldc) {
|
||||
_tile_stored(4, c, ldc);
|
||||
_tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);
|
||||
_tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);
|
||||
_tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);
|
||||
}
|
||||
|
||||
static void run_tile() {
|
||||
_tile_dpbf16ps(4, 0, 2);
|
||||
_tile_dpbf16ps(5, 0, 3);
|
||||
_tile_dpbf16ps(6, 1, 2);
|
||||
_tile_dpbf16ps(7, 1, 3);
|
||||
}
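The four _tile_dpbf16ps calls above accumulate a 2x2 grid of 16x16 output tiles, i.e. one 32x32 fp32 block of C from a 32x32 bf16 panel of A and a 32x32 panel of B per K_STEP. A NumPy sketch of that block update, ignoring the VNNI packing (which only changes the memory layout, not the arithmetic):

# NumPy reference for the 2x2-tile micro-kernel: C[32x32] += A[32x32] @ B[32x32],
# split into four 16x16 tiles the way TMM4..TMM7 are used above.
import numpy as np

TILE = 16
A = np.random.rand(2 * TILE, 2 * TILE).astype(np.float32)  # stands in for the bf16 A panel
B = np.random.rand(2 * TILE, 2 * TILE).astype(np.float32)  # stands in for the packed B panel
C = np.zeros((2 * TILE, 2 * TILE), dtype=np.float32)

for mi in range(2):          # A tiles (TMM0 / TMM1)
    for ni in range(2):      # B tiles (TMM2 / TMM3)
        C[mi*TILE:(mi+1)*TILE, ni*TILE:(ni+1)*TILE] += (
            A[mi*TILE:(mi+1)*TILE, :] @ B[:, ni*TILE:(ni+1)*TILE]
        )

assert np.allclose(C, A @ B)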
|
||||
|
||||
struct BufferA {
|
||||
ggml_bf16_t *a;
|
||||
int max_m, k;
|
||||
|
||||
static size_t required_size(int max_m, int k) { return max_m * k * sizeof(ggml_bf16_t); }
|
||||
|
||||
BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) {
|
||||
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
assert(max_m % M_STEP == 0);
|
||||
assert(k % K_STEP == 0);
|
||||
a = reinterpret_cast<ggml_bf16_t *>(ptr);
|
||||
}
|
||||
|
||||
void from_mat(int m, ggml_bf16_t *src, int ith, int nth) {
|
||||
assert(m <= max_m);
|
||||
assert(ith == 0 && nth == 1);
|
||||
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
|
||||
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
|
||||
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
|
||||
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
|
||||
for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
|
||||
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
|
||||
__m512i *s = (__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin);
|
||||
__m512i *d = (__m512i *)(a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP +
|
||||
i * K_STEP);
|
||||
avx512_copy_32xbf16(s, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_bf16_t *get_submat(int m, int k, int m_begin, int k_begin) {
|
||||
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
|
||||
int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
|
||||
k_begin -= k_block_begin;
|
||||
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
|
||||
return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;
|
||||
}
|
||||
};
|
||||
|
||||
struct BufferB {
|
||||
ggml_bf16_t *b;
|
||||
int n, k;
|
||||
|
||||
static size_t required_size(int n, int k) { return n * k * sizeof(ggml_bf16_t); }
|
||||
|
||||
BufferB(int n, int k, void *ptr) : n(n), k(k) {
|
||||
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
assert(n % N_STEP == 0);
|
||||
assert(k % K_STEP == 0);
|
||||
b = reinterpret_cast<ggml_bf16_t *>(ptr);
|
||||
}
|
||||
|
||||
void from_mat(ggml_bf16_t *src, int ith, int nth) {
|
||||
auto [n_start, n_end] = split_range_n(n, ith, nth);
|
||||
int n_block_begin = n_start;
|
||||
int n_block_size = n_end - n_block_begin;
|
||||
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
|
||||
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
|
||||
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
|
||||
for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
|
||||
for (int i = 0; i < N_STEP; i++) {
|
||||
__m512i *s = (__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin);
|
||||
__m512i *d = (__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +
|
||||
k_begin * N_STEP + i * K_STEP);
|
||||
avx512_copy_32xbf16(s, d);
|
||||
}
|
||||
transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +
|
||||
n_begin * k_block_size + k_begin * N_STEP));
|
||||
transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +
|
||||
n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_bf16_t *get_submat(int n, int k, int n_begin, int k_begin) {
|
||||
int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
|
||||
n_begin -= n_block_begin;
|
||||
int n_block_size = std::min(N_BLOCK, n - n_block_begin);
|
||||
int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
|
||||
k_begin -= k_block_begin;
|
||||
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
|
||||
return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP;
|
||||
}
|
||||
};
|
||||
|
||||
struct BufferC {
|
||||
float *c;
|
||||
int max_m, n;
|
||||
|
||||
static size_t required_size(int max_m, int n) { return max_m * n * sizeof(float); }
|
||||
|
||||
BufferC(int max_m, int n, void *ptr) : max_m(max_m), n(n) {
|
||||
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
assert(max_m % M_STEP == 0);
|
||||
assert(n % N_STEP == 0);
|
||||
c = reinterpret_cast<float *>(ptr);
|
||||
}
|
||||
|
||||
void to_mat(int m, ggml_bf16_t *dst, int ith, int nth) {
|
||||
assert(m <= max_m);
|
||||
auto [n_start, n_end] = split_range_n(n, ith, nth);
|
||||
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
|
||||
int n_block_begin = n_start;
|
||||
int n_block_size = n_end - n_block_begin;
|
||||
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
|
||||
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
|
||||
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
|
||||
__m512 *x0 =
|
||||
(__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);
|
||||
__m512 *x1 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP +
|
||||
i * N_STEP + 16);
|
||||
avx512_32xfp32_to_32xbf16(x0, x1, (__m512i *)(dst + (m_begin + i) * n + n_block_begin + n_begin));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float *get_submat(int m, int n, int m_begin, int n_begin) {
|
||||
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
|
||||
int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
|
||||
int n_block_size = std::min(N_BLOCK, n - n_block_begin);
|
||||
n_begin -= n_block_begin;
|
||||
return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct GemmKernel224Int8 {
|
||||
using dt = int8_t;
|
||||
using output_t = int32_t;
|
||||
static const int TILE_M = 16;
|
||||
static const int TILE_K = 64;
|
||||
static const int TILE_N = 16;
|
||||
static const int VNNI_BLK = 4;
|
||||
|
||||
static const int M_STEP = TILE_M * 2;
|
||||
static const int N_STEP = TILE_N * 2;
|
||||
static const int K_STEP = TILE_K;
|
||||
|
||||
static inline const int N_BLOCK = 256;
|
||||
static inline const int K_BLOCK = 3584;
|
||||
|
||||
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
|
||||
|
||||
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
|
||||
int n_start = N_BLOCK * ith;
|
||||
int n_end = std::min(n, N_BLOCK * (ith + 1));
|
||||
return {n_start, n_end};
|
||||
}
|
||||
|
||||
static void config() {
|
||||
enable_amx();
|
||||
TileConfig tile_config;
|
||||
|
||||
// size is 16 x 64
|
||||
for (int i = 0; i < 2; i++)
|
||||
tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));
|
||||
|
||||
// size is 16 x 64
|
||||
for (int i = 2; i < 4; i++)
|
||||
tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));
|
||||
|
||||
// size is 16 x 16
|
||||
for (int i = 4; i < 8; i++)
|
||||
tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));
|
||||
|
||||
tile_config.set_config();
|
||||
}
|
||||
|
||||
static void load_a(dt *a, size_t lda) {
|
||||
_tile_loadd(0, a, lda);
|
||||
_tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);
|
||||
}
|
||||
|
||||
static void load_b(dt *b, size_t ldb) {
|
||||
_tile_loadd(2, b, ldb);
|
||||
_tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);
|
||||
}
|
||||
|
||||
static void clean_c() {
|
||||
_tile_zero(4);
|
||||
_tile_zero(5);
|
||||
_tile_zero(6);
|
||||
_tile_zero(7);
|
||||
}
|
||||
|
||||
static void load_c(output_t *c, size_t ldc) {
|
||||
_tile_loadd(4, c, ldc);
|
||||
_tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);
|
||||
_tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);
|
||||
_tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);
|
||||
}
|
||||
|
||||
static void store_c(output_t *c, size_t ldc) {
|
||||
_tile_stored(4, c, ldc);
|
||||
_tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);
|
||||
_tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);
|
||||
_tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);
|
||||
}
|
||||
|
||||
static void run_tile() {
|
||||
_tile_dpbssd(4, 0, 2);
|
||||
_tile_dpbssd(5, 0, 3);
|
||||
_tile_dpbssd(6, 1, 2);
|
||||
_tile_dpbssd(7, 1, 3);
|
||||
}
|
||||
|
||||
struct BufferA {
|
||||
int8_t *a;
|
||||
float *d;
|
||||
int max_m, k;
|
||||
|
||||
static size_t required_size(int max_m, int k) { return max_m * k * sizeof(int8_t) + max_m * sizeof(float); }
|
||||
|
||||
BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) {
|
||||
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
assert(max_m % M_STEP == 0);
|
||||
assert(k % K_STEP == 0);
|
||||
a = reinterpret_cast<int8_t *>(ptr);
|
||||
d = reinterpret_cast<float *>(a + max_m * k);
|
||||
}
|
||||
|
||||
void from_mat(int m, ggml_bf16_t *src, int ith, int nth) {
|
||||
assert(m <= max_m);
|
||||
assert(ith == 0 && nth == 1);
|
||||
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
|
||||
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
|
||||
float amax = 0.0f;
|
||||
for (int j = 0; j < k; j += 32) {
|
||||
__m512 f0, f1;
|
||||
avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + j), &f0, &f1);
|
||||
amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));
|
||||
amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));
|
||||
}
|
||||
d[m_begin + i] = amax / ((1 << 7) - 1);
|
||||
}
|
||||
}
|
||||
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
|
||||
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
|
||||
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
|
||||
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
|
||||
for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
|
||||
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
|
||||
__m512 id = _mm512_set1_ps(d[m_begin + i] ? 1.0f / d[m_begin + i] : 0.0f);
|
||||
int8_t *dst = a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP;
|
||||
__m512 f0, f1, f2, f3;
|
||||
avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin), &f0, &f1);
|
||||
avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);
|
||||
__m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));
|
||||
__m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));
|
||||
__m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));
|
||||
__m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));
|
||||
__m128i s0 = _mm512_cvtsepi32_epi8(i0);
|
||||
__m128i s1 = _mm512_cvtsepi32_epi8(i1);
|
||||
__m128i s2 = _mm512_cvtsepi32_epi8(i2);
|
||||
__m128i s3 = _mm512_cvtsepi32_epi8(i3);
|
||||
_mm_storeu_si128((__m128i *)dst, s0);
|
||||
_mm_storeu_si128((__m128i *)(dst + 16), s1);
|
||||
_mm_storeu_si128((__m128i *)(dst + 32), s2);
|
||||
_mm_storeu_si128((__m128i *)(dst + 48), s3);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
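from_mat above quantizes each row symmetrically: the scale is the row's absolute maximum divided by 127, and values are multiplied by the reciprocal scale, rounded and saturated to int8. A NumPy sketch of the same scheme without the AVX-512 packing and blocking (illustrative only):

# Per-row symmetric int8 quantization as used by BufferA/BufferB::from_mat:
# scale = absmax / 127, q = clip(round(x / scale)), dequant = q * scale.
import numpy as np

def quantize_rows_int8(x: np.ndarray):
    absmax = np.abs(x).max(axis=1, keepdims=True)
    scale = absmax / 127.0
    inv = np.where(scale > 0, 1.0 / scale, 0.0)   # guard against all-zero rows
    q = np.clip(np.rint(x * inv), -127, 127).astype(np.int8)
    return q, scale.squeeze(1)

x = np.random.randn(4, 64).astype(np.float32)
q, scale = quantize_rows_int8(x)
print(np.max(np.abs(q.astype(np.float32) * scale[:, None] - x)))  # quantization error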
|
||||
|
||||
int8_t *get_submat(int m, int k, int m_begin, int k_begin) {
|
||||
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
|
||||
int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
|
||||
k_begin -= k_block_begin;
|
||||
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
|
||||
return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;
|
||||
}
|
||||
|
||||
float *get_scale(int m, int m_begin) { return d + m_begin; }
|
||||
};
|
||||
|
||||
struct BufferB {
|
||||
int8_t *b;
|
||||
float *d;
|
||||
int n, k;
|
||||
|
||||
static size_t required_size(int n, int k) { return n * k * sizeof(int8_t) + n * sizeof(float); }
|
||||
|
||||
BufferB(int n, int k, void *ptr) : n(n), k(k) {
|
||||
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
assert(n % N_STEP == 0);
|
||||
assert(k % K_STEP == 0);
|
||||
b = reinterpret_cast<int8_t *>(ptr);
|
||||
d = reinterpret_cast<float *>(b + n * k);
|
||||
}
|
||||
|
||||
void from_mat(ggml_bf16_t *src, int ith, int nth) {
|
||||
auto [n_start, n_end] = split_range_n(n, ith, nth);
|
||||
int n_block_begin = n_start;
|
||||
int n_block_size = n_end - n_block_begin;
|
||||
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
|
||||
for (int i = 0; i < N_STEP; i++) {
|
||||
float amax = 0.0f;
|
||||
for (int j = 0; j < k; j += 32) {
|
||||
__m512 f0, f1;
|
||||
avx512_32xbf16_to_32xfp32((__m512i *)(src + (n_block_begin + n_begin + i) * k + j), &f0, &f1);
|
||||
amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));
|
||||
amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));
|
||||
}
|
||||
d[n_block_begin + n_begin + i] = amax / ((1 << 7) - 1);
|
||||
}
|
||||
}
|
||||
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
|
||||
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
|
||||
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
|
||||
for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
|
||||
for (int i = 0; i < N_STEP; i++) {
|
||||
__m512 id = _mm512_set1_ps(d[n_block_begin + n_begin + i] ? 1.0f / d[n_block_begin + n_begin + i] : 0.0f);
|
||||
int8_t *dst = b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +
|
||||
k_begin * N_STEP + i * K_STEP;
|
||||
__m512 f0, f1, f2, f3;
|
||||
avx512_32xbf16_to_32xfp32((__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin),
|
||||
&f0, &f1);
|
||||
avx512_32xbf16_to_32xfp32(
|
||||
(__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);
|
||||
__m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));
|
||||
__m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));
|
||||
__m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));
|
||||
__m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));
|
||||
__m128i s0 = _mm512_cvtsepi32_epi8(i0);
|
||||
__m128i s1 = _mm512_cvtsepi32_epi8(i1);
|
||||
__m128i s2 = _mm512_cvtsepi32_epi8(i2);
|
||||
__m128i s3 = _mm512_cvtsepi32_epi8(i3);
|
||||
_mm_storeu_si128((__m128i *)dst, s0);
|
||||
_mm_storeu_si128((__m128i *)(dst + 16), s1);
|
||||
_mm_storeu_si128((__m128i *)(dst + 32), s2);
|
||||
_mm_storeu_si128((__m128i *)(dst + 48), s3);
|
||||
}
|
||||
transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +
|
||||
n_begin * k_block_size + k_begin * N_STEP));
|
||||
transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +
|
||||
n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int8_t *get_submat(int n, int k, int n_begin, int k_begin) {
|
||||
int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
|
||||
n_begin -= n_block_begin;
|
||||
int n_block_size = std::min(N_BLOCK, n - n_block_begin);
|
||||
int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
|
||||
k_begin -= k_block_begin;
|
||||
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
|
||||
return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP;
|
||||
}
|
||||
|
||||
float *get_scale(int n, int n_begin) { return d + n_begin; }
|
||||
};
|
||||
|
||||
struct BufferC {
|
||||
float *c;
|
||||
int max_m, n;
|
||||
|
||||
static size_t required_size(int max_m, int n) { return max_m * n * sizeof(float); }
|
||||
|
||||
BufferC(int max_m, int n, void *ptr) : max_m(max_m), n(n) {
|
||||
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
assert(max_m % M_STEP == 0);
|
||||
assert(n % N_STEP == 0);
|
||||
c = reinterpret_cast<float *>(ptr);
|
||||
}
|
||||
|
||||
void to_mat(int m, ggml_bf16_t *dst, int ith, int nth) {
|
||||
assert(m <= max_m);
|
||||
auto [n_start, n_end] = split_range_n(n, ith, nth);
|
||||
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
|
||||
int n_block_begin = n_start;
|
||||
int n_block_size = n_end - n_block_begin;
|
||||
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
|
||||
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
|
||||
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
|
||||
__m512 *x0 =
|
||||
(__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);
|
||||
__m512 *x1 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP +
|
||||
i * N_STEP + 16);
|
||||
avx512_32xfp32_to_32xbf16(x0, x1, (__m512i *)(dst + (m_begin + i) * n + n_block_begin + n_begin));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float *get_submat(int m, int n, int m_begin, int n_begin) {
|
||||
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
|
||||
int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
|
||||
int n_block_size = std::min(N_BLOCK, n - n_block_begin);
|
||||
n_begin -= n_block_begin;
|
||||
return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
inline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224BF::BufferA> ba,
|
||||
std::shared_ptr<GemmKernel224BF::BufferB> bb, std::shared_ptr<GemmKernel224BF::BufferC> bc, int ith,
|
||||
int nth, bool use_amx) {
|
||||
using K = GemmKernel224BF;
|
||||
assert(n % K::N_STEP == 0);
|
||||
assert(k % K::K_STEP == 0);
|
||||
|
||||
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
|
||||
|
||||
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {
|
||||
for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {
|
||||
for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {
|
||||
|
||||
float *c = bc->get_submat(m, n, m_begin, n_begin);
|
||||
if (!use_amx) {
|
||||
__m512 *c512 = (__m512 *)c;
|
||||
if (k_block_begin == 0) {
|
||||
for (int m_i = 0; m_i < m; m_i++) {
|
||||
c512[m_i * 2] = _mm512_setzero_ps();
|
||||
c512[m_i * 2 + 1] = _mm512_setzero_ps();
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {
|
||||
int32_t *a32 = (int32_t *)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
|
||||
__m512bh *b512 = (__m512bh *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
|
||||
for (int m_i = 0; m_i < m; m_i++) {
|
||||
for (int k_i = 0; k_i < 16; k_i++) {
|
||||
__m512bh ma = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]);
|
||||
for (int n_i = 0; n_i < 2; n_i++) {
|
||||
c512[m_i * 2 + n_i] = _mm512_dpbf16_ps(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
if (k_block_begin == 0) {
|
||||
K::clean_c();
|
||||
} else {
|
||||
K::load_c(c, K::N_STEP * sizeof(float));
|
||||
}
|
||||
for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {
|
||||
K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t));
|
||||
K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t));
|
||||
K::run_tile();
|
||||
}
|
||||
K::store_c(c, K::N_STEP * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline __m512i _mm512_dpbssd_epi32(__m512i src, __m512i a, __m512i b) {
|
||||
__m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
|
||||
__m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
|
||||
__m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
|
||||
__m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
|
||||
|
||||
b_lo = _mm256_sign_epi8(b_lo, a_lo);
|
||||
b_hi = _mm256_sign_epi8(b_hi, a_hi);
|
||||
|
||||
b = _mm512_inserti64x4(b, b_lo, 0);
|
||||
b = _mm512_inserti64x4(b, b_hi, 1);
|
||||
|
||||
a = _mm512_abs_epi8(a);
|
||||
|
||||
return _mm512_dpbusd_epi32(src, a, b);
|
||||
}
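This helper emulates a signed-by-signed int8 dot product on top of _mm512_dpbusd_epi32 (unsigned-by-signed) by moving the sign of a onto b and taking |a|, i.e. a*b = |a| * (sign(a)*b). A quick NumPy check of that identity:

# Verify the sign-transfer identity used by _mm512_dpbssd_epi32:
# a * b == abs(a) * (sign(a) * b), so a signed*signed dot product can be fed
# to an unsigned*signed instruction after adjusting both operands.
import numpy as np

a = np.random.randint(-128, 128, size=1024, dtype=np.int8).astype(np.int32)
b = np.random.randint(-128, 128, size=1024, dtype=np.int8).astype(np.int32)

lhs = np.dot(a, b)
rhs = np.dot(np.abs(a), np.sign(a) * b)
assert lhs == rhs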
|
||||
|
||||
inline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224Int8::BufferA> ba,
|
||||
std::shared_ptr<GemmKernel224Int8::BufferB> bb, std::shared_ptr<GemmKernel224Int8::BufferC> bc,
|
||||
int ith, int nth, bool use_amx) {
|
||||
using K = GemmKernel224Int8;
|
||||
assert(n % K::N_STEP == 0);
|
||||
assert(k % K::K_STEP == 0);
|
||||
|
||||
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
|
||||
|
||||
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {
|
||||
for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {
|
||||
for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {
|
||||
float *c = bc->get_submat(m, n, m_begin, n_begin);
|
||||
|
||||
if (!use_amx) {
|
||||
__m512i *c512 = (__m512i *)c;
|
||||
if (k_block_begin == 0) {
|
||||
for (int m_i = 0; m_i < m; m_i++) {
|
||||
c512[m_i * 2] = _mm512_setzero_si512();
|
||||
c512[m_i * 2 + 1] = _mm512_setzero_si512();
|
||||
}
|
||||
}
|
||||
|
||||
for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {
|
||||
static_assert(K::K_STEP * sizeof(int8_t) == sizeof(__m512i));
|
||||
static_assert(K::N_STEP / K::TILE_N == 2, "Must be like this");
|
||||
|
||||
int32_t *a32 = (int32_t *)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
|
||||
__m512i *b512 = (__m512i *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
|
||||
for (int m_i = 0; m_i < m; m_i++) {
|
||||
for (int k_i = 0; k_i < 16; k_i++) {
|
||||
__m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]);
|
||||
for (int n_i = 0; n_i < 2; n_i++) {
|
||||
c512[m_i * 2 + n_i] = _mm512_dpbssd_epi32(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (k_block_begin == 0) {
|
||||
K::clean_c();
|
||||
} else {
|
||||
K::load_c((int32_t *)c, K::N_STEP * sizeof(int32_t));
|
||||
}
|
||||
for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {
|
||||
K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t));
|
||||
K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t));
|
||||
K::run_tile();
|
||||
}
|
||||
K::store_c((int32_t *)c, K::N_STEP * sizeof(int32_t));
|
||||
}
|
||||
|
||||
if (k_block_begin + K::K_BLOCK >= k) {
|
||||
int to = m - m_begin;
|
||||
if (m - m_begin > K::M_STEP) {
|
||||
to = K::M_STEP;
|
||||
}
|
||||
for (int i = 0; i < to; i++) {
|
||||
__m512 as = _mm512_set1_ps(*ba->get_scale(m, m_begin + i));
|
||||
__m512 bs = _mm512_load_ps(bb->get_scale(n, n_begin));
|
||||
__m512i now = _mm512_load_si512((__m512i *)(c + i * K::N_STEP));
|
||||
__m512 result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));
|
||||
_mm512_store_ps((__m512 *)(c + i * K::N_STEP), result);
|
||||
bs = _mm512_load_ps(bb->get_scale(n, n_begin) + K::TILE_N);
|
||||
now = _mm512_load_si512((__m512i *)(c + i * K::N_STEP + K::TILE_N));
|
||||
result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));
|
||||
_mm512_store_ps((__m512 *)(c + i * K::N_STEP + K::TILE_N), result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace amx
|
csrc/ktransformers_ext/operators/amx/la/utils.hpp (new file, 46 lines)
@@ -0,0 +1,46 @@
/**
 * @Description  :
 * @Author       : chenht2022
 * @Date         : 2025-04-25 18:28:12
 * @Version      : 1.0.0
 * @LastEditors  : chenht2022
 * @LastEditTime : 2025-04-25 18:28:12
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/

#pragma once
#include <cstdint>

template <typename T>
T* offset_pointer(T* ptr, std::size_t byte_offset) {
  return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr) + byte_offset);
}

template <typename T>
const T* offset_pointer(const T* ptr, std::size_t byte_offset) {
  return reinterpret_cast<const T*>(reinterpret_cast<const char*>(ptr) + byte_offset);
}

template <typename T>
T* offset_pointer_row_major(T* t, int row, int col, std::size_t ld) {
  return offset_pointer(t, row * ld) + col;
}

template <typename T>
T* offset_pointer_col_major(T* t, int row, int col, std::size_t ld) {
  return offset_pointer(t, col * ld) + row;
}

static inline void avx512_copy_32xbf16(__m512i* src, __m512i* dst) {
  _mm512_storeu_si512(dst, _mm512_loadu_si512(src));
}

static inline void avx512_32xfp32_to_32xbf16(__m512* src0, __m512* src1, __m512i* dst) {
  _mm512_storeu_si512(dst, __m512i(_mm512_cvtne2ps_pbh(*src1, *src0)));
}

static inline void avx512_32xbf16_to_32xfp32(__m512i* src, __m512* dst0, __m512* dst1) {
  _mm512_storeu_ps(dst0, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src))), 16)));
  _mm512_storeu_ps(dst1, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src) + 1)), 16)));
}
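These helpers rely on bf16 being the upper 16 bits of an fp32 value: widening is a 16-bit left shift into an fp32 bit pattern, and narrowing drops the low 16 bits (_mm512_cvtne2ps_pbh additionally rounds to nearest-even). A NumPy sketch of the same bit manipulation, using plain truncation:

# bf16 <-> fp32 via bit shifts, mirroring avx512_32xbf16_to_32xfp32 /
# avx512_32xfp32_to_32xbf16 (this sketch truncates instead of rounding to nearest-even).
import numpy as np

def bf16_to_fp32(bits_u16: np.ndarray) -> np.ndarray:
    return (bits_u16.astype(np.uint32) << 16).view(np.float32)

def fp32_to_bf16(x: np.ndarray) -> np.ndarray:
    return (x.view(np.uint32) >> 16).astype(np.uint16)

x = np.random.randn(32).astype(np.float32)
roundtrip = bf16_to_fp32(fp32_to_bf16(x))
print(np.max(np.abs(roundtrip - x)))  # small error from dropping the low 16 mantissa bits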
csrc/ktransformers_ext/operators/amx/moe.hpp (new file, 398 lines)
|
@ -0,0 +1,398 @@
|
|||
/**
|
||||
* @Description :
|
||||
* @Author : chenht2022
|
||||
* @Date : 2025-04-25 18:28:12
|
||||
* @Version : 1.0.0
|
||||
* @LastEditors : chenht2022
|
||||
* @LastEditTime : 2025-04-25 18:28:12
|
||||
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
**/
|
||||
#ifndef CPUINFER_OPERATOR_AMX_MOE_H
|
||||
#define CPUINFER_OPERATOR_AMX_MOE_H
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#include "../../cpu_backend/backend.h"
|
||||
#include "../../cpu_backend/shared_mem_buffer.h"
|
||||
#include "llama.cpp/ggml-impl.h"
|
||||
#include "llama.cpp/ggml-quants.h"
|
||||
#include "llama.cpp/ggml.h"
|
||||
#include "llamafile/sgemm.h"
|
||||
|
||||
#include "la/amx.hpp"
|
||||
|
||||
#ifdef USE_NUMA
|
||||
#include <numa.h>
|
||||
#include <numaif.h>
|
||||
void *numa_alloc_aligned(size_t size, int node, size_t alignment) {
|
||||
void *ptr = numa_alloc_onnode(size, node);
|
||||
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
|
||||
return ptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline __m512 exp_avx512(__m512 x) {
|
||||
const __m512 log2e = _mm512_set1_ps(1.44269504089f);
|
||||
const __m512 c1 = _mm512_set1_ps(0.69314718056f);
|
||||
|
||||
__m512 y = _mm512_mul_ps(x, log2e);
|
||||
__m512i int_part = _mm512_cvtps_epi32(y);
|
||||
__m512 frac_part = _mm512_sub_ps(y, _mm512_cvtepi32_ps(int_part));
|
||||
|
||||
const __m512 poly_1 = _mm512_set1_ps(0.9999999995f);
|
||||
const __m512 poly_2 = _mm512_set1_ps(0.6931471805f);
|
||||
const __m512 poly_3 = _mm512_set1_ps(0.2402265069f);
|
||||
const __m512 poly_4 = _mm512_set1_ps(0.0555041087f);
|
||||
const __m512 poly_5 = _mm512_set1_ps(0.0096181291f);
|
||||
const __m512 poly_6 = _mm512_set1_ps(0.0013333558f);
|
||||
|
||||
__m512 frac_exp = _mm512_fmadd_ps(
|
||||
frac_part, poly_6,
|
||||
_mm512_fmadd_ps(frac_part, poly_5,
|
||||
_mm512_fmadd_ps(frac_part, poly_4,
|
||||
_mm512_fmadd_ps(frac_part, poly_3, _mm512_fmadd_ps(frac_part, poly_2, poly_1)))));
|
||||
|
||||
__m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part));
|
||||
return _mm512_mul_ps(two_pow_i, frac_exp);
|
||||
}
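exp_avx512 reduces the range by writing x*log2(e) = i + f with integer i, approximates e^(f*ln2) from the Taylor coefficients ln2^k / k! (the poly_* constants above), and scales by 2^i via _mm512_scalef_ps. Below is a scalar Python sketch of that scheme, written with a nested Horner evaluation of the same coefficients (a minimal reference, not a bit-exact reproduction of the fmadd chain above).

# Scalar sketch of exp_avx512's range reduction: x*log2(e) = i + f with i integer,
# exp(x) = 2**i * e**(f*ln2), with e**(f*ln2) built from the coefficients ln2**k / k!.
import math

LOG2E = 1.44269504089
COEFFS = [0.9999999995, 0.6931471805, 0.2402265069, 0.0555041087, 0.0096181291, 0.0013333558]

def exp_approx(x: float) -> float:
    y = x * LOG2E
    i = round(y)                 # integer part (the vector code rounds to nearest too)
    f = y - i                    # fractional part in roughly [-0.5, 0.5]
    p = 0.0
    for c in reversed(COEFFS):   # Horner: c0 + f*(c1 + f*(c2 + ...))
        p = p * f + c
    return math.ldexp(p, i)      # scale by 2**i, like _mm512_scalef_ps

print(exp_approx(1.0), math.exp(1.0))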
|
||||
|
||||
static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
|
||||
__m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val);
|
||||
__m512 exp_neg_gate = exp_avx512(neg_gate_val);
|
||||
__m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate);
|
||||
__m512 act_val = _mm512_div_ps(gate_val, denom);
|
||||
|
||||
return _mm512_mul_ps(act_val, up_val);
|
||||
}
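act_fn is the SiLU gating used by the MoE MLP: silu(gate) * up with silu(x) = x / (1 + e^(-x)). A one-line scalar reference for comparison:

# Reference for act_fn: SiLU(gate) * up, with SiLU(x) = x / (1 + exp(-x)).
import math

def act_fn_ref(gate: float, up: float) -> float:
    return gate / (1.0 + math.exp(-gate)) * up

print(act_fn_ref(1.0, 2.0))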
|
||||
|
||||
struct AMX_MOEConfig {
|
||||
int expert_num;
|
||||
int routed_expert_num;
|
||||
int hidden_size;
|
||||
int intermediate_size;
|
||||
int max_len;
|
||||
void *gate_proj;
|
||||
void *up_proj;
|
||||
void *down_proj;
|
||||
|
||||
AMX_MOEConfig() {}
|
||||
|
||||
AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len,
|
||||
void *gate_proj, void *up_proj, void *down_proj)
|
||||
: expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size),
|
||||
intermediate_size(intermediate_size), max_len(max_len), gate_proj(gate_proj), up_proj(up_proj),
|
||||
down_proj(down_proj) {}
|
||||
};

template <class T> class AMX_MOE {
private:
  AMX_MOEConfig config_;
  void *gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
  void *up_proj_;   // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]
  void *down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]

  ggml_bf16_t *m_local_input_;       // [routed_expert_num * max_len * hidden_size]
  ggml_bf16_t *m_local_gate_output_; // [routed_expert_num * max_len * intermediate_size]
  ggml_bf16_t *m_local_up_output_;   // [routed_expert_num * max_len * intermediate_size]
  ggml_bf16_t *m_local_down_output_; // [routed_expert_num * max_len * hidden_size]

  std::vector<std::vector<int>> m_local_pos_;          // [max_len, routed_expert_num]
  std::vector<int> m_local_num_;                       // [expert_num]
  std::vector<int> m_expert_id_map_;                   // [expert_num]
  std::vector<ggml_bf16_t *> m_local_input_ptr_;       // [expert_num]
  std::vector<ggml_bf16_t *> m_local_gate_output_ptr_; // [expert_num]
  std::vector<ggml_bf16_t *> m_local_up_output_ptr_;   // [expert_num]
  std::vector<ggml_bf16_t *> m_local_down_output_ptr_; // [expert_num]

  std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
  std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
  std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
  std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
  std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;

#ifdef USE_NUMA
  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> gate_bb_numa_;
  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> up_bb_numa_;
  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> down_bb_numa_;
#else
  std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
  std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
  std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
#endif

public:
  AMX_MOE(AMX_MOEConfig config) {
    config_ = config;
    gate_proj_ = config_.gate_proj;
    up_proj_ = config_.up_proj;
    down_proj_ = config_.down_proj;

    std::vector<std::pair<void **, uint64_t>> m_mem_requests;
    m_mem_requests.push_back({(void **)&m_local_input_,
                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});
    m_mem_requests.push_back({(void **)&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *
                                                                  config_.max_len * config_.intermediate_size});
    m_mem_requests.push_back({(void **)&m_local_up_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *
                                                                config_.max_len * config_.intermediate_size});
    m_mem_requests.push_back({(void **)&m_local_down_output_,
                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});
    std::vector<void *> gate_up_ba_ptr(config_.expert_num);
    std::vector<void *> gate_bc_ptr(config_.expert_num);
    std::vector<void *> up_bc_ptr(config_.expert_num);
    std::vector<void *> down_ba_ptr(config_.expert_num);
    std::vector<void *> down_bc_ptr(config_.expert_num);
    for (int i = 0; i < config_.expert_num; i++) {
      m_mem_requests.push_back(
          {(void **)&gate_up_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)});
      m_mem_requests.push_back(
          {(void **)&gate_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});
      m_mem_requests.push_back(
          {(void **)&up_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});
      m_mem_requests.push_back(
          {(void **)&down_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)});
      m_mem_requests.push_back(
          {(void **)&down_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)});
    }
    shared_mem_buffer.alloc(this, m_mem_requests);

    m_local_pos_.resize(config_.max_len);
    for (int i = 0; i < config_.max_len; i++) {
      m_local_pos_[i].resize(config_.routed_expert_num);
    }
    m_expert_id_map_.resize(config_.expert_num);
    m_local_num_.resize(config_.expert_num);
    m_local_input_ptr_.resize(config_.expert_num);
    m_local_gate_output_ptr_.resize(config_.expert_num);
    m_local_up_output_ptr_.resize(config_.expert_num);
    m_local_down_output_ptr_.resize(config_.expert_num);

    for (uint64_t i = 0; i < config_.expert_num; i++) {
      gate_up_ba_.push_back(
          std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, gate_up_ba_ptr[i]));
      gate_bc_.push_back(
          std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, gate_bc_ptr[i]));
      up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, up_bc_ptr[i]));
      down_ba_.push_back(
          std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, down_ba_ptr[i]));
      down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, down_bc_ptr[i]));

#ifdef USE_NUMA
      int numa_nodes = numa_num_configured_nodes();
      gate_bb_numa_.resize(numa_nodes);
      up_bb_numa_.resize(numa_nodes);
      down_bb_numa_.resize(numa_nodes);
      for (int j = 0; j < numa_nodes; j++) {
        void *gate_bb_ptr =
            numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);
        gate_bb_numa_[j].push_back(
            std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));
        void *up_bb_ptr =
            numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);
        up_bb_numa_[j].push_back(
            std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));
        void *down_bb_ptr =
            numa_alloc_aligned(T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64);
        down_bb_numa_[j].push_back(
            std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
      }
#else
      void *gate_bb_ptr =
          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
      gate_bb_.push_back(
          std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));

      void *up_bb_ptr =
          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
      up_bb_.push_back(
          std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));

      void *down_bb_ptr =
          std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
      down_bb_.push_back(
          std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
#endif
    }
  }

  ~AMX_MOE() { shared_mem_buffer.dealloc(this); }

  void load_weights(Backend *backend) {
    int nth = T::recommended_nth(config_.intermediate_size);
    backend->do_work_stealing_job(
        nth * config_.expert_num, nullptr,
        [&](int task_id) {
          uint64_t expert_idx = task_id / nth;
          int ith = task_id % nth;
#ifdef USE_NUMA
          int numa_nodes = numa_num_configured_nodes();
          for (int j = 0; j < numa_nodes; j++) {
            gate_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +
                                                       expert_idx * config_.intermediate_size * config_.hidden_size,
                                                   ith, nth);
            up_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.up_proj +
                                                     expert_idx * config_.intermediate_size * config_.hidden_size,
                                                 ith, nth);
          }
#else
          gate_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +
                                             expert_idx * config_.intermediate_size * config_.hidden_size,
                                         ith, nth);
          up_bb_[expert_idx]->from_mat(
              (ggml_bf16_t *)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth);
#endif
        },
        nullptr);
    nth = T::recommended_nth(config_.hidden_size);
    backend->do_work_stealing_job(
        nth * config_.expert_num, nullptr,
        [&](int task_id) {
          uint64_t expert_idx = task_id / nth;
          int ith = task_id % nth;
#ifdef USE_NUMA
          int numa_nodes = numa_num_configured_nodes();
          for (int j = 0; j < numa_nodes; j++) {
            down_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +
                                                       expert_idx * config_.hidden_size * config_.intermediate_size,
                                                   ith, nth);
          }
#else
          down_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +
                                             expert_idx * config_.hidden_size * config_.intermediate_size,
                                         ith, nth);
#endif
        },
        nullptr);
  }

  void warm_up(Backend *backend) {}

  void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output,
               int *batch_size_tensor, Backend *backend) {
    bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);
    qlen = batch_size_tensor[0];
    int activated_expert = 0;
    for (int i = 0; i < config_.expert_num; i++) {
      m_local_num_[i] = 0;
    }
    for (int i = 0; i < qlen; i++) {
      for (int j = 0; j < k; j++) {
        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
      }
    }
    for (int i = 0; i < config_.expert_num; i++) {
      if (m_local_num_[i] > 0) {
        m_expert_id_map_[activated_expert] = i;
        activated_expert++;
      }
    }
    uint64_t offset = 0;
    for (int i = 0; i < config_.expert_num; i++) {
      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
      offset += m_local_num_[i];
    }
    backend->do_work_stealing_job(
        qlen, nullptr,
        [&](int i) {
          for (int j = 0; j < k; j++) {
            memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
                   (ggml_bf16_t *)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
          }
        },
        nullptr);
    backend->do_work_stealing_job(
        activated_expert, nullptr,
        [&](int task_id) {
          int expert_idx = m_expert_id_map_[task_id];

          gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
        },
        nullptr);
    int nth = T::recommended_nth(config_.intermediate_size);
    backend->do_work_stealing_job(
        nth * activated_expert, [&](int _) { T::config(); },
        [&](int task_id) {
          int expert_idx = m_expert_id_map_[task_id / nth];
          int ith = task_id % nth;
#ifdef USE_NUMA
          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                       gate_up_ba_[expert_idx], gate_bb_numa_[Backend::numa_node][expert_idx], gate_bc_[expert_idx],
                       ith, nth, use_amx);
          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                       gate_up_ba_[expert_idx], up_bb_numa_[Backend::numa_node][expert_idx], up_bc_[expert_idx], ith,
                       nth, use_amx);
#else
          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                       gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx);
          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                       gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx);
#endif
          gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
          up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
          auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
          for (int i = 0; i < m_local_num_[expert_idx]; i++) {
            ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
            ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
            for (int j = n_start; j < n_end; j += 32) {
              __m512 gate_val0, gate_val1, up_val0, up_val1;
              avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
              avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
              __m512 result0 = act_fn(gate_val0, up_val0);
              __m512 result1 = act_fn(gate_val1, up_val1);
              avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
            }
          }
        },
        nullptr);
    backend->do_work_stealing_job(
        activated_expert, nullptr,
        [&](int task_id) {
          int expert_idx = m_expert_id_map_[task_id];
          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
        },
        nullptr);
    nth = T::recommended_nth(config_.hidden_size);
    backend->do_work_stealing_job(
        nth * activated_expert, [&](int _) { T::config(); },
        [&](int task_id) {
          int expert_idx = m_expert_id_map_[task_id / nth];
          int ith = task_id % nth;
#ifdef USE_NUMA
          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],
                       down_bb_numa_[Backend::numa_node][expert_idx], down_bc_[expert_idx], ith, nth, use_amx);
#else
          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],
                       down_bb_[expert_idx], down_bc_[expert_idx], ith, nth, use_amx);
#endif
          down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
        },
        nullptr);
    backend->do_work_stealing_job(
        qlen, nullptr,
        [&](int i) {
          for (int e = 0; e < config_.hidden_size; e += 32) {
            __m512 x0 = _mm512_setzero_ps();
            __m512 x1 = _mm512_setzero_ps();
            for (int j = 0; j < k; j++) {
              __m512 weight = _mm512_set1_ps(weights[i * k + j]);
              __m512 down_output0, down_output1;
              avx512_32xbf16_to_32xfp32((__m512i *)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
                                                    m_local_pos_[i][j] * config_.hidden_size + e),
                                        &down_output0, &down_output1);
              x0 = _mm512_fmadd_ps(down_output0, weight, x0);
              x1 = _mm512_fmadd_ps(down_output1, weight, x1);
            }
            avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)output + i * config_.hidden_size + e));
          }
        },
        nullptr);
  }
};

#endif
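
At the top of AMX_MOE::forward, tokens are bucketed per expert before any matmul is launched. A minimal Python sketch of that bookkeeping (illustrative only; the names mirror the C++ members above):

def bucket_tokens(expert_ids, qlen, k, expert_num):
    # m_local_num_: how many (token, slot) pairs each expert receives
    local_num = [0] * expert_num
    # m_local_pos_[i][j]: position of token i's j-th routed copy inside its expert's bucket
    local_pos = [[0] * k for _ in range(qlen)]
    for i in range(qlen):
        for j in range(k):
            e = expert_ids[i * k + j]
            local_pos[i][j] = local_num[e]
            local_num[e] += 1
    # m_expert_id_map_: compact list of experts that actually received tokens
    expert_id_map = [e for e in range(expert_num) if local_num[e] > 0]
    return local_pos, local_num, expert_id_map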

@ -17,12 +17,12 @@
#include <vector>

#include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"

struct LinearConfig {
  int input_size;

@ -17,12 +17,12 @@
#include <vector>

#include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"

struct MLPConfig {
  int hidden_size;

@ -17,12 +17,12 @@
#include <vector>

#include "../../cpu_backend/backend.h"
#include "../../cpu_backend/shared_mem_buffer.h"
#include "conversion.h"
#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"

struct MOEConfig {
  int expert_num;

@ -1,37 +0,0 @@
/**
 * @Description :
 * @Author : chenht2022
 * @Date : 2024-08-05 04:49:08
 * @Version : 1.0.0
 * @LastEditors : chenht2022
 * @LastEditTime : 2024-08-05 06:36:41
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/

#ifndef CPUINFER_SHAREDMEMBUFFER_H
#define CPUINFER_SHAREDMEMBUFFER_H

#include <cstdint>
#include <cstdlib>
#include <map>
#include <vector>

class SharedMemBuffer {
public:
  SharedMemBuffer();
  ~SharedMemBuffer();

  void alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests);
  void dealloc(void* object);

private:
  void* buffer_;
  uint64_t size_;
  std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> hist_requests_;

  void arrange(std::vector<std::pair<void**, uint64_t>> requests);
};

static SharedMemBuffer shared_mem_buffer;

#endif
@ -25,8 +25,9 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, MOE
from cpuinfer_ext.moe import AMX_MOEConfig, AMXBF16_MOE, AMXInt8_MOE
import ctypes
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.custom_gguf import GGMLQuantizationType, GGUFLoader
from ktransformers.util.utils import InferenceState
from ktransformers.server.config.config import Config
from transformers.activations import ACT2FN

@ -141,6 +142,7 @@ class KExpertsCPU(KExpertsBase):
        assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
        self.n_routed_experts = n_routed_experts
        self.out_device = out_device
        self.backend = kwargs.get("backend", "llamafile")

    def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
        if device:
@ -163,7 +165,9 @@ class KExpertsCPU(KExpertsBase):
        )
        # print(self.gate_qtype, self.up_qtype, self.down_qtype)
        n_routed_experts = self.n_routed_experts
        self.cpu_infer = KExpertsCPU.CPU_INFER
        # n_routed_experts = len(self.orig_module)
        if self.backend == "llamafile":
            moe_config = MOEConfig(
                n_routed_experts,
                self.config.num_experts_per_tok,

@ -180,10 +184,43 @@
                self.down_type,
                30, # TODO: get from model.dtype
            )
            self.moe = MOE(moe_config)
        elif self.backend == "AMXBF16":
            assert self.gate_type == GGMLQuantizationType.BF16
            assert self.up_type == GGMLQuantizationType.BF16
            assert self.down_type == GGMLQuantizationType.BF16
            moe_config = AMX_MOEConfig(
                n_routed_experts,
                self.config.num_experts_per_tok,
                self.config.hidden_size,
                self.config.moe_intermediate_size,
                25600,
                gate_ptr,
                up_ptr,
                down_ptr,
            )
            self.moe = AMXBF16_MOE(moe_config)
            self.cpu_infer.submit(self.moe.load_weights())
            self.cpu_infer.sync()
        elif self.backend == "AMXInt8":
            assert self.gate_type == GGMLQuantizationType.BF16
            assert self.up_type == GGMLQuantizationType.BF16
            assert self.down_type == GGMLQuantizationType.BF16
            moe_config = AMX_MOEConfig(
                n_routed_experts,
                self.config.num_experts_per_tok,
                self.config.hidden_size,
                self.config.moe_intermediate_size,
                25600,
                gate_ptr,
                up_ptr,
                down_ptr,
            )
            self.moe = AMXInt8_MOE(moe_config)
            self.cpu_infer.submit(self.moe.load_weights())
            self.cpu_infer.sync()
        # print(n_routed_experts, hidden_size, moe_intermediate_size)
        num_experts_per_tok = self.config.num_experts_per_tok
        self.moe = MOE(moe_config)
        self.cpu_infer = KExpertsCPU.CPU_INFER
        if warmup:
            self.cpu_infer.submit(self.moe.warm_up())
            self.cpu_infer.sync()
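
As the asserts above show, both AMX backends currently require the expert tensors to be stored as BF16 in the GGUF file (load_weights repacks them, and the Int8 path additionally quantizes at load time). A hypothetical helper, not part of this commit, that mirrors that check when picking a backend string could look like:

from ktransformers.util.custom_gguf import GGMLQuantizationType

def pick_cpu_moe_backend(gate_type, up_type, down_type, prefer="AMXInt8"):
    # The AMX paths only accept BF16 expert weights; anything else falls back
    # to the llamafile kernels, which handle quantized GGML types directly.
    all_bf16 = all(t == GGMLQuantizationType.BF16 for t in (gate_type, up_type, down_type))
    if prefer in ("AMXBF16", "AMXInt8") and all_bf16:
        return prefer
    return "llamafile"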
@ -0,0 +1,77 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
      backend: "AMXInt8"  # or "AMXBF16" or "llamafile" (default)
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False  # change this to True to enable long context (prefill may be slower)
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"