kvcache-ai-ktransformers/kt-kernel/operators/moe_kernel/test/utils_test.cpp
2025-11-03 15:19:52 +08:00

63 lines
2.4 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// #pragma once
#ifdef TEST_UTIL
#include <arm_neon.h>
#include <arm_sve.h>
#include <stdio.h>
static inline void sve_32xbf16_to_32xfp32(const bfloat16_t* src, float* dst0, float* dst1) {
#ifdef __ARM_FEATURE_SVE
// 全真谓词,对应每个 16bit 元素
#else
// fallback: scalar or NEON
#endif
}
static inline void neon_32xbf16_to_32xfp32(const uint16_t* src, float* dst0, float* dst1) {
// src 指向 32 个连续的 BF16uint16_t
// dst0、dst1 各指向 16 个 float 的缓冲
for (int block = 0; block < 4; ++block) {
// 每次处理 8 个 BF16 → 8 个 FP32拆为两次 4→4 存储)
uint16x8_t v_bf16 = vld1q_u16(src + block * 8); // load 8×BF16 :contentReference[oaicite:6]{index=6}
// 拆低半、高半各 4 个到 u32
uint32x4_t lo_u32 = vmovl_u16(vget_low_u16(v_bf16)); // lower 4 → u32 :contentReference[oaicite:7]{index=7}
uint32x4_t hi_u32 = vmovl_u16(vget_high_u16(v_bf16)); // upper 4 → u32 :contentReference[oaicite:8]{index=8}
// 左移 16 位,相当于将 BF16 的 16 位 mantissa+exp 放到 FP32 高位
lo_u32 = vshlq_n_u32(lo_u32, 16); // shift left 16 :contentReference[oaicite:9]{index=9}
hi_u32 = vshlq_n_u32(hi_u32, 16); // shift left 16 :contentReference[oaicite:10]{index=10}
// 重新解释为 float32x4_t
float32x4_t lo_f32 = vreinterpretq_f32_u32(lo_u32); // bits → FP32 :contentReference[oaicite:11]{index=11}
float32x4_t hi_f32 = vreinterpretq_f32_u32(hi_u32); // bits → FP32 :contentReference[oaicite:12]{index=12}
// 存储到 dst0 或 dst1每次存 8 个
if (block < 2) {
vst1q_f32(dst0 + block * 4, lo_f32); // store 4 floats :contentReference[oaicite:13]{index=13}
vst1q_f32(dst0 + block * 4 + 4, hi_f32); // store next 4 floats :contentReference[oaicite:14]{index=14}
} else {
int b = block - 2;
vst1q_f32(dst1 + b * 4, lo_f32); // store 4 floats :contentReference[oaicite:15]{index=15}
vst1q_f32(dst1 + b * 4 + 4, hi_f32); // store next 4 floats :contentReference[oaicite:16]{index=16}
}
}
}
int main() {
// 测试代码
uint16_t bf16_data[32] = {0}; // 假设这里填充了一些 BF16 数据
float f32_data0[16] = {0};
float f32_data1[16] = {0};
neon_32xbf16_to_32xfp32(bf16_data, f32_data0, f32_data1);
// 打印结果
for (int i = 0; i < 16; ++i) {
printf("f32_data0[%d]: %f\n", i, f32_data0[i]);
printf("f32_data1[%d]: %f\n", i, f32_data1[i]);
}
return 0;
}
#endif