// koboldcpp/otherarch/ggml_v3.c
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC
/// Start ggml-impl.h
#include "ggml_v3.h"
// GGML internal header
#include <assert.h>
#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
#include <stddef.h>
#include <stdbool.h>
#include <string.h> // memcpy
#include <math.h> // fabsf
#ifdef __cplusplus
extern "C" {
#endif
// static_assert should be a #define, but if it's not,
// fall back to the _Static_assert C11 keyword.
// if C99 - static_assert is noop
// ref: https://stackoverflow.com/a/53923785/4039976
#ifndef static_assert
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
#define static_assert(cond, msg) _Static_assert(cond, msg)
#else
#define static_assert(cond, msg) struct global_scope_noop_trick
#endif
#endif
// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
#ifndef __FMA__
#define __FMA__
#endif
#ifndef __F16C__
#define __F16C__
#endif
#ifndef __SSE3__
#define __SSE3__
#endif
#endif
// 16-bit float
// on Arm, we use __fp16
// on x86, we use uint16_t
#if defined(__ARM_NEON) && !defined(_MSC_VER)
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
//
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
//
#include <arm_neon.h>
#define GGML_V3_COMPUTE_FP16_TO_FP32(x) ((float) (x))
#define GGML_V3_COMPUTE_FP32_TO_FP16(x) (x)
#define GGML_V3_FP16_TO_FP32(x) ((float) (x))
#define GGML_V3_FP32_TO_FP16(x) (x)
#else
#ifdef __wasm_simd128__
#include <wasm_simd128.h>
#else
#ifdef __POWER9_VECTOR__
#include <altivec.h>
#undef bool
#define bool _Bool
#else
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
#else
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
#if !defined(__riscv)
#include <immintrin.h>
#endif
#endif
#endif
#endif
#endif
#ifdef __riscv_v_intrinsic
#include <riscv_vector.h>
#endif
#ifdef __F16C__
#ifdef _MSC_VER
#define GGML_V3_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
#define GGML_V3_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
#else
#define GGML_V3_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
#define GGML_V3_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
#endif
#elif defined(__POWER9_VECTOR__)
#define GGML_V3_COMPUTE_FP16_TO_FP32(x) ggml_v3_compute_fp16_to_fp32(x)
#define GGML_V3_COMPUTE_FP32_TO_FP16(x) ggml_v3_compute_fp32_to_fp16(x)
/* the inline asm below is about 12% faster than the lookup method */
#define GGML_V3_FP16_TO_FP32(x) GGML_V3_COMPUTE_FP16_TO_FP32(x)
#define GGML_V3_FP32_TO_FP16(x) GGML_V3_COMPUTE_FP32_TO_FP16(x)
static inline float ggml_v3_compute_fp16_to_fp32(ggml_v3_fp16_t h) {
register float f;
register double d;
__asm__(
"mtfprd %0,%2\n"
"xscvhpdp %0,%0\n"
"frsp %1,%0\n" :
/* temp */ "=d"(d),
/* out */ "=f"(f):
/* in */ "r"(h));
return f;
}
static inline ggml_v3_fp16_t ggml_v3_compute_fp32_to_fp16(float f) {
register double d;
register ggml_v3_fp16_t r;
__asm__( /* xscvdphp can work on double or single precision */
"xscvdphp %0,%2\n"
"mffprd %1,%0\n" :
/* temp */ "=d"(d),
/* out */ "=r"(r):
/* in */ "f"(f));
return r;
}
#else
// FP16 <-> FP32
// ref: https://github.com/Maratyszcza/FP16
static inline float fp32_from_bits(uint32_t w) {
union {
uint32_t as_bits;
float as_value;
} fp32;
fp32.as_bits = w;
return fp32.as_value;
}
static inline uint32_t fp32_to_bits(float f) {
union {
float as_value;
uint32_t as_bits;
} fp32;
fp32.as_value = f;
return fp32.as_bits;
}
static inline float ggml_v3_compute_fp16_to_fp32(ggml_v3_fp16_t h) {
const uint32_t w = (uint32_t) h << 16;
const uint32_t sign = w & UINT32_C(0x80000000);
const uint32_t two_w = w + w;
const uint32_t exp_offset = UINT32_C(0xE0) << 23;
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
const float exp_scale = 0x1.0p-112f;
#else
const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
#endif
const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
const uint32_t magic_mask = UINT32_C(126) << 23;
const float magic_bias = 0.5f;
const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
const uint32_t result = sign |
(two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
return fp32_from_bits(result);
}
static inline ggml_v3_fp16_t ggml_v3_compute_fp32_to_fp16(float f) {
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
const float scale_to_inf = 0x1.0p+112f;
const float scale_to_zero = 0x1.0p-110f;
#else
const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
#endif
float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
const uint32_t w = fp32_to_bits(f);
const uint32_t shl1_w = w + w;
const uint32_t sign = w & UINT32_C(0x80000000);
uint32_t bias = shl1_w & UINT32_C(0xFF000000);
if (bias < UINT32_C(0x71000000)) {
bias = UINT32_C(0x71000000);
}
base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
const uint32_t bits = fp32_to_bits(base);
const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
const uint32_t nonsign = exp_bits + mantissa_bits;
return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
}
#define GGML_V3_COMPUTE_FP16_TO_FP32(x) ggml_v3_compute_fp16_to_fp32(x)
#define GGML_V3_COMPUTE_FP32_TO_FP16(x) ggml_v3_compute_fp32_to_fp16(x)
#endif // __F16C__
#endif // __ARM_NEON
// precomputed f32 table for f16 (256 KB)
// defined later in this file (ggml_v3.c), initialized in ggml_v3_init()
extern float ggml_v3_table_f32_f16[1 << 16];
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_v3_lookup_fp16_to_fp32,
// so we define GGML_V3_FP16_TO_FP32 and GGML_V3_FP32_TO_FP16 elsewhere for NEON.
// This is also true for POWER9.
#if !defined(GGML_V3_FP16_TO_FP32) || !defined(GGML_V3_FP32_TO_FP16)
inline static float ggml_v3_lookup_fp16_to_fp32(ggml_v3_fp16_t f) {
uint16_t s;
memcpy(&s, &f, sizeof(uint16_t));
return ggml_v3_table_f32_f16[s];
}
#define GGML_V3_FP16_TO_FP32(x) ggml_v3_lookup_fp16_to_fp32(x)
#define GGML_V3_FP32_TO_FP16(x) GGML_V3_COMPUTE_FP32_TO_FP16(x)
#endif
#define GGML_V3_HASHTABLE_FULL ((size_t)-1)
#define GGML_V3_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
bool ggml_v3_hash_contains (const struct ggml_v3_hash_set hash_set, struct ggml_v3_tensor * key);
// returns GGML_V3_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
size_t ggml_v3_hash_find (const struct ggml_v3_hash_set hash_set, struct ggml_v3_tensor * key);
// returns GGML_V3_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
size_t ggml_v3_hash_insert ( struct ggml_v3_hash_set hash_set, struct ggml_v3_tensor * key);
// return index, asserts if table is full
size_t ggml_v3_hash_find_or_insert( struct ggml_v3_hash_set hash_set, struct ggml_v3_tensor * key);
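// Illustrative usage sketch (assumes an already sized `struct ggml_v3_hash_set hs` and a tensor `t`;
// not part of the original API surface):
//   if (!ggml_v3_hash_contains(hs, t)) {
//       size_t idx = ggml_v3_hash_insert(hs, t);  // asserts if the table is full
//       (void) idx;
//   }
//   size_t pos = ggml_v3_hash_find(hs, t);        // GGML_V3_HASHTABLE_FULL if no slot is available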
//allocator stuff
#include <assert.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef __has_include
#if __has_include(<unistd.h>)
#include <unistd.h>
#if defined(_POSIX_MAPPED_FILES)
#include <sys/types.h>
#include <sys/mman.h>
#endif
#endif
#endif
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <memoryapi.h>
#endif
#define UNUSED GGML_V3_UNUSED
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define GGML_V3_MAX_CONCUR (2*GGML_V3_MAX_NODES)
//#define GGML_V3_ALLOCATOR_DEBUG
//#define AT_PRINTF printf
#define AT_PRINTF(...) ((void)0)
struct hash_node_v3 {
struct ggml_v3_tensor * t;
int n_children;
int n_views;
};
static size_t hash(void * p) {
return (size_t)p % GGML_V3_GRAPH_HASHTABLE_SIZE;
}
static struct hash_node_v3 * hash_get(struct hash_node_v3 hash_table[], struct ggml_v3_tensor * t) {
size_t h = hash(t);
// linear probing
size_t i = h;
while (hash_table[i].t != NULL) {
if (hash_table[i].t == t) {
return &hash_table[i];
}
i = (i + 1) % GGML_V3_GRAPH_HASHTABLE_SIZE;
if (i == h) {
// hash table is full
GGML_V3_ASSERT(false);
}
}
hash_table[i].t = t;
return &hash_table[i];
}
// TODO: GGML_V3_PAD ?
static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
assert(alignment && !(alignment & (alignment - 1))); // power of 2
size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
return offset + align;
}
struct free_block_v3 {
void * addr;
size_t size;
};
#define MAX_FREE_BLOCKS 256
struct ggml_v3_allocr {
void * data;
size_t size;
size_t alignment;
int n_free_block_v3s;
struct free_block_v3 free_block_v3s[MAX_FREE_BLOCKS];
struct hash_node_v3 hash_table[GGML_V3_GRAPH_HASHTABLE_SIZE];
size_t max_size;
bool measure;
int parse_seq[GGML_V3_MAX_CONCUR];
int parse_seq_len;
#ifdef GGML_V3_ALLOCATOR_DEBUG
struct ggml_v3_tensor * allocated_tensors[1024];
#endif
};
#ifdef GGML_V3_ALLOCATOR_DEBUG
static void add_allocated_tensor(struct ggml_v3_allocr * alloc, struct ggml_v3_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == NULL) {
alloc->allocated_tensors[i] = tensor;
return;
}
}
GGML_V3_ASSERT(!"out of allocated_tensors");
}
static void remove_allocated_tensor(struct ggml_v3_allocr * alloc, struct ggml_v3_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == tensor ||
(alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
alloc->allocated_tensors[i] = NULL;
return;
}
}
printf("tried to free tensor %s not found\n", tensor->name);
GGML_V3_ASSERT(!"tensor not found");
}
#endif
static size_t ggml_v3_allocr_get_alloc_size(struct ggml_v3_allocr * alloc, struct ggml_v3_tensor * tensor) {
return ggml_v3_nbytes(tensor);
UNUSED(alloc);
}
// check if a tensor is allocated by this buffer
static bool ggml_v3_allocr_is_own(struct ggml_v3_allocr * alloc, const struct ggml_v3_tensor * tensor) {
void * ptr = tensor->data;
return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
}
static bool ggml_v3_is_view(struct ggml_v3_tensor * t) {
return t->view_src != NULL;
}
void ggml_v3_allocr_alloc(struct ggml_v3_allocr * alloc, struct ggml_v3_tensor * tensor) {
#ifdef GGML_V3_ALLOCATOR_DEBUG
GGML_V3_ASSERT(!ggml_v3_is_view(tensor)); // views generally get data pointer from one of their sources
GGML_V3_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
#endif
size_t size = ggml_v3_allocr_get_alloc_size(alloc, tensor);
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
size_t max_avail = 0;
// find the best fitting free block besides the last block
int best_fit_block = -1;
size_t best_fit_size = SIZE_MAX;
for (int i = 0; i < alloc->n_free_block_v3s - 1; i++) {
struct free_block_v3 * block = &alloc->free_block_v3s[i];
max_avail = MAX(max_avail, block->size);
if (block->size >= size && block->size <= best_fit_size) {
best_fit_block = i;
best_fit_size = block->size;
}
}
AT_PRINTF("block %d\n", best_fit_block);
if (best_fit_block == -1) {
// the last block is our last resort
struct free_block_v3 * block = &alloc->free_block_v3s[alloc->n_free_block_v3s - 1];
max_avail = MAX(max_avail, block->size);
if (block->size >= size) {
best_fit_block = alloc->n_free_block_v3s - 1;
} else {
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
__func__, size, max_avail);
GGML_V3_ASSERT(!"not enough space in the buffer");
return;
}
}
struct free_block_v3 * block = &alloc->free_block_v3s[best_fit_block];
void * addr = block->addr;
block->addr = (char*)block->addr + size;
block->size -= size;
if (block->size == 0) {
// remove block if empty
alloc->n_free_block_v3s--;
for (int j = best_fit_block; j < alloc->n_free_block_v3s; j++) {
alloc->free_block_v3s[j] = alloc->free_block_v3s[j+1];
}
}
tensor->data = addr;
AT_PRINTF("%s: allocated data at %p\n", __func__, tensor->data);
#ifdef GGML_V3_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, tensor);
size_t cur_max = (char*)addr - (char*)alloc->data + size;
if (cur_max > alloc->max_size) {
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i]) {
printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_v3_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0);
}
}
printf("\n");
}
#endif
alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
}
// this is a very naive implementation, but for our case the number of free blocks should be very small
static void ggml_v3_allocr_free_tensor(struct ggml_v3_allocr * alloc, struct ggml_v3_tensor * tensor) {
void * ptr = tensor->data;
if (ggml_v3_allocr_is_own(alloc, tensor) == false) {
// the tensor was not allocated in this buffer
// this can happen because the graph allocator will try to free weights and other tensors from different buffers
// the easiest way to deal with this is just to ignore it
return;
}
size_t size = ggml_v3_allocr_get_alloc_size(alloc, tensor);
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_block_v3s = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_block_v3s);
AT_PRINTF("%s: alloc->data = %p alloc->data+alloc->size = %p alloc->data+alloc->max_size = %p\n", __func__, alloc->data, (char*)alloc->data + alloc->size, (char*)alloc->data + alloc->max_size);
#ifdef GGML_V3_ALLOCATOR_DEBUG
remove_allocated_tensor(alloc, tensor);
#endif
// see if we can merge with an existing block
for (int i = 0; i < alloc->n_free_block_v3s; i++) {
struct free_block_v3 * block = &alloc->free_block_v3s[i];
// check if ptr is at the end of the block
if ((char*)block->addr + block->size == ptr) {
block->size += size;
// check if we can merge with the next block
if (i < alloc->n_free_block_v3s - 1 && (char*)block->addr + block->size == alloc->free_block_v3s[i+1].addr) {
block->size += alloc->free_block_v3s[i+1].size;
alloc->n_free_block_v3s--;
for (int j = i+1; j < alloc->n_free_block_v3s; j++) {
alloc->free_block_v3s[j] = alloc->free_block_v3s[j+1];
}
}
return;
}
// check if ptr is at the beginning of the block
if ((char*)ptr + size == block->addr) {
block->addr = ptr;
block->size += size;
// check if we can merge with the previous block
if (i > 0 && (char*)alloc->free_block_v3s[i-1].addr + alloc->free_block_v3s[i-1].size == block->addr) {
alloc->free_block_v3s[i-1].size += block->size;
alloc->n_free_block_v3s--;
for (int j = i; j < alloc->n_free_block_v3s; j++) {
alloc->free_block_v3s[j] = alloc->free_block_v3s[j+1];
}
}
return;
}
}
// otherwise, add a new block
GGML_V3_ASSERT(alloc->n_free_block_v3s < MAX_FREE_BLOCKS && "out of free blocks");
// insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
int insert_pos = 0;
while (insert_pos < alloc->n_free_block_v3s && alloc->free_block_v3s[insert_pos].addr < ptr) {
insert_pos++;
}
// shift all blocks from insert_pos onward to make room for the new block
for (int i = alloc->n_free_block_v3s; i > insert_pos; i--) {
alloc->free_block_v3s[i] = alloc->free_block_v3s[i-1];
}
// insert the new block
alloc->free_block_v3s[insert_pos].addr = ptr;
alloc->free_block_v3s[insert_pos].size = size;
alloc->n_free_block_v3s++;
}
void ggml_v3_allocr_set_parse_seq(struct ggml_v3_allocr * alloc, const int * list, int n) {
for (int i = 0; i < n; i++) {
alloc->parse_seq[i] = list[i];
}
alloc->parse_seq_len = n;
}
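// Illustrative example (hypothetical list): with list = {0, 1, -1, 2, -1}, the graph allocator
// below allocates nodes 0 and 1, frees their no-longer-needed parents at the -1 barrier,
// then allocates node 2, and so on (see ggml_v3_allocr_alloc_graph_tensors_n).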
void ggml_v3_allocr_reset(struct ggml_v3_allocr * alloc) {
alloc->n_free_block_v3s = 1;
size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
alloc->free_block_v3s[0].addr = (char *)alloc->data + align_offset;
alloc->free_block_v3s[0].size = alloc->size - align_offset;
}
struct ggml_v3_allocr * ggml_v3_allocr_new(void * data, size_t size, size_t alignment) {
struct ggml_v3_allocr * alloc = (struct ggml_v3_allocr *)malloc(sizeof(struct ggml_v3_allocr) /* + n_free_block_v3s * sizeof(struct free_block_v3) */);
*alloc = (struct ggml_v3_allocr){
/*.data = */ data,
/*.size = */ size,
/*.alignment = */ alignment,
/*.n_free_block_v3s = */ 0,
/*.free_block_v3s = */ {{0}},
/*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ false,
/*.parse_seq = */ {0},
/*.parse_seq_len = */ 0,
#ifdef GGML_V3_ALLOCATOR_DEBUG
/*.allocated_tensors = */ {0},
#endif
};
ggml_v3_allocr_reset(alloc);
return alloc;
}
// OS specific functions to allocate and free uncommitted virtual memory
static void * alloc_vmem(size_t size) {
#if defined(_WIN32)
return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS);
#elif defined(_POSIX_MAPPED_FILES)
void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0);
if (ptr == MAP_FAILED) {
return NULL;
}
return ptr;
#else
// use a fixed address for other platforms
uintptr_t base_addr = (uintptr_t)-size - 0x100;
return (void *)base_addr;
#endif
}
static void free_vmem(void * base_addr, size_t size) {
#if defined(_WIN32)
VirtualFree(base_addr, 0, MEM_RELEASE);
UNUSED(size);
#elif defined(_POSIX_MAPPED_FILES)
munmap(base_addr, size);
#else
// nothing to do
UNUSED(base_addr);
UNUSED(size);
#endif
}
// allocate uncommitted virtual memory to measure the size of the graph
static void alloc_measure_vmem(void ** base_addr, size_t * size) {
// 128GB for 64-bit, 1GB for 32-bit
*size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37;
do {
*base_addr = alloc_vmem(*size);
if (*base_addr != NULL) {
AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr);
return;
}
// try again with half the size
*size /= 2;
} while (*size > 0);
GGML_V3_ASSERT(!"failed to allocate virtual memory for measure buffer");
}
static void free_measure_vmem(void * base_addr, size_t size) {
free_vmem(base_addr, size);
}
struct ggml_v3_allocr * ggml_v3_allocr_new_measure(size_t alignment) {
struct ggml_v3_allocr * alloc = (struct ggml_v3_allocr *)malloc(sizeof(struct ggml_v3_allocr) /* + n_free_block_v3s * sizeof(struct free_block_v3) */);
void * base_addr;
size_t size;
alloc_measure_vmem(&base_addr, &size);
*alloc = (struct ggml_v3_allocr){
/*.data = */ base_addr,
/*.size = */ size,
/*.alignment = */ alignment,
/*.n_free_block_v3s = */ 0,
/*.free_block_v3s = */ {{0}},
/*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ true,
/*.parse_seq = */ {0},
/*.parse_seq_len = */ 0,
#ifdef GGML_V3_ALLOCATOR_DEBUG
/*.allocated_tensors = */ {0},
#endif
};
ggml_v3_allocr_reset(alloc);
return alloc;
}
void ggml_v3_allocr_free(struct ggml_v3_allocr * alloc) {
if (alloc->measure) {
free_measure_vmem(alloc->data, alloc->size);
}
free(alloc);
}
bool ggml_v3_allocr_is_measure(struct ggml_v3_allocr * alloc) {
return alloc->measure;
}
//////////// compute graph allocator
static bool ggml_v3_are_same_layout(const struct ggml_v3_tensor * a, const struct ggml_v3_tensor * b) {
if (a->type != b->type) {
return false;
}
for (int i = 0; i < GGML_V3_MAX_DIMS; i++) {
if (a->ne[i] != b->ne[i]) {
return false;
}
if (a->nb[i] != b->nb[i]) {
return false;
}
}
return true;
}
static bool ggml_v3_op_can_inplace(enum ggml_v3_op op) {
switch (op) {
case GGML_V3_OP_SCALE:
case GGML_V3_OP_DIAG_MASK_ZERO:
case GGML_V3_OP_DIAG_MASK_INF:
case GGML_V3_OP_ADD:
case GGML_V3_OP_ADD1:
case GGML_V3_OP_SUB:
case GGML_V3_OP_MUL:
case GGML_V3_OP_DIV:
case GGML_V3_OP_SQR:
case GGML_V3_OP_SQRT:
case GGML_V3_OP_LOG:
case GGML_V3_OP_UNARY:
case GGML_V3_OP_ROPE:
case GGML_V3_OP_RMS_NORM:
case GGML_V3_OP_SOFT_MAX:
case GGML_V3_OP_CONT:
return true;
default:
return false;
}
}
static void allocate_node_v3(struct ggml_v3_allocr * alloc, struct ggml_v3_tensor * node) {
struct hash_node_v3 * ht = alloc->hash_table;
if (node->data == NULL) {
if (ggml_v3_is_view(node)) {
assert(node->view_src->data != NULL);
node->data = (char *)node->view_src->data + node->view_offs;
} else {
// see if we can reuse a parent's buffer (inplace)
if (ggml_v3_op_can_inplace(node->op)) {
for (int i = 0; i < GGML_V3_MAX_SRC; i++) {
struct ggml_v3_tensor * parent = node->src[i];
if (parent == NULL) {
break;
}
// if the node's data is external, then we cannot re-use it
if (ggml_v3_allocr_is_own(alloc, parent) == false) {
AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
continue;
}
struct hash_node_v3 * p_hn = hash_get(ht, parent);
if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_v3_are_same_layout(node, parent)) {
if (ggml_v3_is_view(parent)) {
struct ggml_v3_tensor * view_src = parent->view_src;
struct hash_node_v3 * view_src_hn = hash_get(ht, view_src);
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
// the parent's data that it will need later (same layout requirement). the problem is that then
// we cannot free the tensor because the original address of the allocation is lost.
// adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
// for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
node->data = parent->data;
return;
}
}
else {
AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
node->data = parent->data;
return;
}
}
}
}
ggml_v3_allocr_alloc(alloc, node);
}
}
}
static size_t ggml_v3_allocr_alloc_graph_tensors_n(
struct ggml_v3_allocr * alloc,
struct ggml_v3_cgraph ** graphs, int n_graphs,
struct ggml_v3_tensor *** inputs, struct ggml_v3_tensor *** outputs) {
// reset hash table
struct hash_node_v3 * ht = alloc->hash_table;
memset(ht, 0, sizeof(struct hash_node_v3) * GGML_V3_GRAPH_HASHTABLE_SIZE);
// count number of children and views
for (int g = 0; g < n_graphs; g++) {
struct ggml_v3_cgraph * gf = graphs[g];
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_v3_tensor * node = gf->nodes[i];
if (ggml_v3_is_view(node)) {
struct ggml_v3_tensor * view_src = node->view_src;
hash_get(ht, view_src)->n_views += 1;
}
for (int j = 0; j < GGML_V3_MAX_SRC; j++) {
struct ggml_v3_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
hash_get(ht, parent)->n_children += 1;
}
}
}
// allocate tensors
for (int g = 0; g < n_graphs; g++) {
struct ggml_v3_cgraph * gf = graphs[g];
AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
// graph inputs are allocated first to ensure that they are not overwritten by each other
if (inputs != NULL && inputs[g] != NULL) {
for (int i = 0; inputs[g][i] != NULL; i++) {
struct ggml_v3_tensor * input = inputs[g][i];
AT_PRINTF("input: %s\n", input->name);
allocate_node_v3(alloc, input);
}
}
// if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
int last_barrier_pos = 0;
int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes;
for (int ind = 0; ind < n_nodes; ind++) {
// allocate a node if there is no parse_seq or this is not a barrier
if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) {
int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind;
struct ggml_v3_tensor * node = gf->nodes[i];
// allocate parents (leafs)
for (int j = 0; j < GGML_V3_MAX_SRC; j++) {
struct ggml_v3_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
allocate_node_v3(alloc, parent);
}
// allocate node
allocate_node_v3(alloc, node);
AT_PRINTF("exec: %s (%s) <= ", ggml_v3_op_name(node->op), node->name);
for (int j = 0; j < GGML_V3_MAX_SRC; j++) {
struct ggml_v3_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
AT_PRINTF("%s", parent->name);
if (j < GGML_V3_MAX_SRC - 1 && node->src[j + 1] != NULL) {
AT_PRINTF(", ");
}
}
AT_PRINTF("\n");
}
// update parents
// update immediately if there is no parse_seq
// update only at barriers if there is parse_seq
if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) {
int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
int update_end = alloc->parse_seq_len ? ind : ind + 1;
for (int i = update_start; i < update_end; i++) {
int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i;
struct ggml_v3_tensor * node = gf->nodes[node_i];
for (int j = 0; j < GGML_V3_MAX_SRC; j++) {
struct ggml_v3_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
struct hash_node_v3 * p_hn = hash_get(ht, parent);
p_hn->n_children -= 1;
//AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
if (ggml_v3_is_view(parent)) {
struct ggml_v3_tensor * view_src = parent->view_src;
struct hash_node_v3 * view_src_hn = hash_get(ht, view_src);
view_src_hn->n_views -= 1;
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
ggml_v3_allocr_free_tensor(alloc, view_src);
}
}
else {
if (parent->data != node->data) {
ggml_v3_allocr_free_tensor(alloc, parent);
}
}
}
}
}
AT_PRINTF("\n");
if (alloc->parse_seq_len) {
last_barrier_pos = ind + 1;
}
}
}
// free graph outputs here that wouldn't be freed otherwise because they have no children
if (outputs != NULL && outputs[g] != NULL) {
for (int i = 0; outputs[g][i] != NULL; i++) {
struct ggml_v3_tensor * output = outputs[g][i];
AT_PRINTF("output: %s\n", output->name);
ggml_v3_allocr_free_tensor(alloc, output);
}
}
}
return alloc->max_size;
}
size_t ggml_v3_allocr_alloc_graph(struct ggml_v3_allocr * alloc, struct ggml_v3_cgraph * graph) {
return ggml_v3_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
}
size_t ggml_v3_allocr_max_size(struct ggml_v3_allocr * alloc) {
return alloc->max_size;
}
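// Typical two-pass usage (an illustrative sketch; `graph`, `alignment` and the extra headroom
// are assumptions, not taken from this file):
//   struct ggml_v3_allocr * measure = ggml_v3_allocr_new_measure(alignment);
//   size_t needed = ggml_v3_allocr_alloc_graph(measure, graph);  // dry run over reserved, uncommitted memory
//   ggml_v3_allocr_free(measure);
//   void * buf = malloc(needed + alignment);                     // real backing buffer, with some headroom
//   struct ggml_v3_allocr * alloc = ggml_v3_allocr_new(buf, needed + alignment, alignment);
//   ggml_v3_allocr_alloc_graph(alloc, graph);                    // assigns tensor->data pointers into buf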
#ifdef __cplusplus
}
#endif
/// end ggml-impl.h
#include <stddef.h>
#include <stdint.h>
#define QK4_0 32
typedef struct {
ggml_v3_fp16_t d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(ggml_v3_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
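// Illustrative decode of one q4_0 block (a sketch mirroring dequantize_row_q4_0 declared further
// below; the helper name is ours, not part of the original API). Each byte packs two 4-bit quants
// stored with an offset of 8: the low nibble is element j, the high nibble element j + QK4_0/2.
static inline void ggml_v3_example_decode_q4_0_block(const block_q4_0 * b, float * out /* QK4_0 floats */) {
    const float d = GGML_V3_FP16_TO_FP32(b->d);
    for (int j = 0; j < QK4_0/2; ++j) {
        out[j]           = d * ((b->qs[j] & 0x0F) - 8);
        out[j + QK4_0/2] = d * ((b->qs[j] >> 4)   - 8);
    }
}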
#define QK4_1 32
typedef struct {
ggml_v3_fp16_t d; // delta
ggml_v3_fp16_t m; // min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_v3_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
#define QK5_0 32
typedef struct {
ggml_v3_fp16_t d; // delta
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;
static_assert(sizeof(block_q5_0) == sizeof(ggml_v3_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
#define QK5_1 32
typedef struct {
ggml_v3_fp16_t d; // delta
ggml_v3_fp16_t m; // min
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_v3_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
#define QK8_0 32
typedef struct {
ggml_v3_fp16_t d; // delta
int8_t qs[QK8_0]; // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(ggml_v3_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
#define QK8_1 32
typedef struct {
float d; // delta
float s; // d * sum(qs[i])
int8_t qs[QK8_1]; // quants
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
//
// Super-block quantization structures
//
// Super-block size
#ifdef GGML_QKK_64
#define QK_K 64
#define K_SCALE_SIZE 4
#else
#define QK_K 256
#define K_SCALE_SIZE 12
#endif
// 2-bit quantization
// weight is represented as x = a * q + b
// 16 blocks of 16 elements each
// Effectively 2.625 bits per weight
typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K/4]; // quants
ggml_v3_fp16_t d; // super-block scale for quantized scales
ggml_v3_fp16_t dmin; // super-block scale for quantized mins
} block_q2_K;
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_v3_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
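// Size check for the stated 2.625 bpw (assuming QK_K == 256): 16 scale bytes + 64 quant bytes
// + 2*2 bytes of fp16 super-block scales = 84 bytes per 256 weights, i.e. 84*8/256 = 2.625 bits per weight.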
// 3-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 3.4375 bits per weight
#ifdef GGML_QKK_64
typedef struct {
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
uint8_t scales[2];
ggml_v3_fp16_t d; // super-block scale
} block_q3_K;
static_assert(sizeof(block_q3_K) == sizeof(ggml_v3_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
#else
typedef struct {
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
uint8_t scales[12]; // scales, quantized with 6 bits
ggml_v3_fp16_t d; // super-block scale
} block_q3_K;
static_assert(sizeof(block_q3_K) == sizeof(ggml_v3_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
#endif
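// Size check for the stated 3.4375 bpw (assuming QK_K == 256): 32 + 64 + 12 + 2 = 110 bytes
// per 256 weights, i.e. 110*8/256 = 3.4375 bits per weight.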
// 4-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
#ifdef GGML_QKK_64
typedef struct {
ggml_v3_fp16_t d[2]; // super-block scales/mins
uint8_t scales[2]; // 4-bit block scales/mins
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_v3_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
#else
typedef struct {
ggml_v3_fp16_t d; // super-block scale for quantized scales
ggml_v3_fp16_t dmin; // super-block scale for quantized mins
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_v3_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
#endif
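// Size check for the stated 4.5 bpw (assuming QK_K == 256): 2 + 2 + 12 + 128 = 144 bytes
// per 256 weights, i.e. 144*8/256 = 4.5 bits per weight.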
// 5-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 5.5 bits per weight
#ifdef GGML_QKK_64
typedef struct {
ggml_v3_fp16_t d; // super-block scale
int8_t scales[QK_K/16]; // 8-bit block scales
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == sizeof(ggml_v3_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
#else
typedef struct {
ggml_v3_fp16_t d; // super-block scale for quantized scales
ggml_v3_fp16_t dmin; // super-block scale for quantized mins
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_v3_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
#endif
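// Size check for the stated 5.5 bpw (assuming QK_K == 256): 2 + 2 + 12 + 32 + 128 = 176 bytes
// per 256 weights, i.e. 176*8/256 = 5.5 bits per weight.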
// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
ggml_v3_fp16_t d; // super-block scale
} block_q6_K;
static_assert(sizeof(block_q6_K) == sizeof(ggml_v3_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
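// Size check for the stated 6.5625 bpw (assuming QK_K == 256): 128 + 64 + 16 + 2 = 210 bytes
// per 256 weights, i.e. 210*8/256 = 6.5625 bits per weight.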
// This is only used for intermediate quantization and dot products
typedef struct {
float d; // delta
int8_t qs[QK_K]; // quants
int16_t bsums[QK_K/16]; // sum of quants in groups of 16
} block_q8_K;
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
// (Almost) "true" 2-bit quantization.
// Due to the need to use blocks as per ggml design, it ends up using
// 2.0625 bpw because of the 16-bit scale for each block of 256.
typedef struct {
ggml_v3_fp16_t d;
uint16_t qs[QK_K/8];
} block_iq2_xxs;
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_v3_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
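// Size check for the stated 2.0625 bpw (assuming QK_K == 256): 2 + 32*2 = 66 bytes
// per 256 weights, i.e. 66*8/256 = 2.0625 bits per weight.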
// 2.3125 bpw quants
typedef struct {
ggml_v3_fp16_t d;
uint16_t qs[QK_K/8];
uint8_t scales[QK_K/32];
} block_iq2_xs;
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_v3_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
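// Size check for the stated 2.3125 bpw (assuming QK_K == 256): 2 + 32*2 + 8 = 74 bytes
// per 256 weights, i.e. 74*8/256 = 2.3125 bits per weight.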
// Quantization
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k);
static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k);
static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k);
static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k);
static void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);
static void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
static void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
static void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
static void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
static void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
static void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);
static void quantize_row_iq2_xs_reference (const float * restrict x, block_iq2_xs * restrict y, int k);
static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k);
static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k);
static void quantize_row_q8_0(const float * restrict x, void * restrict y, int k);
static void quantize_row_q8_1(const float * restrict x, void * restrict y, int k);
static void quantize_row_q2_K(const float * restrict x, void * restrict y, int k);
static void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);
static void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
static void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
static void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
static void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
static void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);
static void quantize_row_iq2_xs (const float * restrict x, void * restrict y, int k);
// Dequantization
static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k);
static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k);
static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k);
static void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k);
//static void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k);
static void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
static void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
static void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
static void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
static void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
static void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
static void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
static void dequantize_row_iq2_xs (const block_iq2_xs * restrict x, float * restrict y, int k);
// Dot product
static void ggml_v3_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_v3_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_v3_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_v3_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_v3_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_v3_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_v3_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_v3_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_v3_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_v3_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_v3_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_v3_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
#include <alloca.h>
#endif
#include <assert.h>
#include <errno.h>
#include <time.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <inttypes.h>
#include <stdio.h>
#include <float.h>
#include <limits.h>
#include <stdarg.h>
#include <signal.h>
#ifdef GGML_USE_METAL
#include <unistd.h>
#endif
#if defined(_MSC_VER)
// disable "possible loss of data" to avoid hundreds of casts
// we should just be careful :)
#pragma warning(disable: 4244 4267)
// disable POSIX deprecation warnings
// these functions are never going away, anyway
#pragma warning(disable: 4996)
#endif
#if defined(_WIN32)
#include <windows.h>
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
static void atomic_store(atomic_int * ptr, LONG val) {
InterlockedExchange(ptr, val);
}
static LONG atomic_load(atomic_int * ptr) {
return InterlockedCompareExchange(ptr, 0, 0);
}
static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
return InterlockedExchangeAdd(ptr, inc);
}
static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
return atomic_fetch_add(ptr, -(dec));
}
typedef HANDLE pthread_t;
typedef DWORD thread_ret_t;
static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
(void) unused;
HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
if (handle == NULL)
{
return EAGAIN;
}
*out = handle;
return 0;
}
static int pthread_join(pthread_t thread, void * unused) {
(void) unused;
int ret = (int) WaitForSingleObject(thread, INFINITE);
CloseHandle(thread);
return ret;
}
static int sched_yield (void) {
Sleep (0);
return 0;
}
#else
#include <pthread.h>
#include <stdatomic.h>
typedef void * thread_ret_t;
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#endif
#ifdef GGML_USE_CPU_HBM
#include <hbwmalloc.h>
#endif
#if defined(__APPLE__)
#include <TargetConditionals.h>
#endif
#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \
(!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH))
#include <sys/wait.h>
void ggml_v3_print_backtrace(void) {
/*
#include <execinfo.h>
#include <dlfcn.h>
void * trace[100];
int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
*/
// backtrace_symbols does not show line numbers, use gdb instead
char attach[32];
snprintf(attach, sizeof(attach), "attach %d", getpid());
int pid = fork();
if (pid == 0) {
execlp("gdb", "gdb", "--batch",
"-ex", "set style enabled on",
"-ex", attach,
"-ex", "bt -frame-info source-and-location",
"-ex", "detach",
"-ex", "quit",
(char *) NULL);
} else {
waitpid(pid, NULL, 0);
}
}
#else
void ggml_v3_print_backtrace(void) {
// platform not supported
}
#endif
/*#define GGML_V3_PERF*/
#define GGML_V3_DEBUG 0
#define GGML_V3_GELU_FP16
#define GGML_V3_GELU_QUICK_FP16
#define GGML_V3_SILU_FP16
// #define GGML_V3_CROSS_ENTROPY_EXP_FP16
// #define GGML_V3_FLASH_ATTN_EXP_FP16
#define GGML_V3_SOFT_MAX_UNROLL 4
#define GGML_V3_VEC_DOT_UNROLL 2
#define GGML_V3_VEC_MAD_UNROLL 32
//
// logging
//
#if (GGML_V3_DEBUG >= 1)
#define GGML_V3_PRINT_DEBUG(...) printf(__VA_ARGS__)
#else
#define GGML_V3_PRINT_DEBUG(...)
#endif
#if (GGML_V3_DEBUG >= 5)
#define GGML_V3_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
#else
#define GGML_V3_PRINT_DEBUG_5(...)
#endif
#if (GGML_V3_DEBUG >= 10)
#define GGML_V3_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
#else
#define GGML_V3_PRINT_DEBUG_10(...)
#endif
#define GGML_V3_PRINT(...) printf(__VA_ARGS__)
//
// end of logging block
//
#ifdef GGML_USE_ACCELERATE
// uncomment to use vDSP for soft max computation
// note: not sure if it is actually faster
//#define GGML_V3_SOFT_MAX_ACCELERATE
#endif
#if defined(_MSC_VER) || defined(__MINGW32__)
#define GGML_V3_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_V3_MEM_ALIGN)
#define GGML_V3_ALIGNED_FREE(ptr) _aligned_free(ptr)
#else
inline static void * ggml_v3_aligned_malloc(size_t size) {
if (size == 0) {
GGML_V3_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_v3_aligned_malloc!\n");
return NULL;
}
void * aligned_memory = NULL;
#ifdef GGML_USE_CPU_HBM
int result = hbw_posix_memalign(&aligned_memory, 16, size);
#elif defined(GGML_USE_METAL)
int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size);
#else
int result = posix_memalign(&aligned_memory, GGML_V3_MEM_ALIGN, size);
#endif
if (result != 0) {
// Handle allocation failure
const char *error_desc = "unknown allocation error";
switch (result) {
case EINVAL:
error_desc = "invalid alignment value";
break;
case ENOMEM:
error_desc = "insufficient memory";
break;
}
GGML_V3_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0));
return NULL;
}
return aligned_memory;
}
#define GGML_V3_ALIGNED_MALLOC(size) ggml_v3_aligned_malloc(size)
#ifdef GGML_USE_CPU_HBM
#define GGML_V3_ALIGNED_FREE(ptr) if(NULL != ptr) hbw_free(ptr)
#else
#define GGML_V3_ALIGNED_FREE(ptr) free(ptr)
#endif
#endif
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
#if defined(GGML_USE_ACCELERATE)
#include <Accelerate/Accelerate.h>
#if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
#include "ggml_v3-opencl.h"
#endif
#elif defined(GGML_USE_OPENBLAS)
#if defined(GGML_V3_BLAS_USE_MKL)
#include <mkl.h>
#else
#include <cblas.h>
#endif
#elif defined(GGML_USE_CUDA)
#include "ggml_v3-cuda.h"
#elif defined(GGML_USE_CLBLAST)
#include "ggml_v3-opencl.h"
#endif
// floating point type used to accumulate sums
typedef double ggml_v3_float;
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
//
// global data
//
// precomputed gelu table for f16 (128 KB)
static ggml_v3_fp16_t ggml_v3_table_gelu_f16[1 << 16];
// precomputed quick gelu table for f16 (128 KB)
static ggml_v3_fp16_t ggml_v3_table_gelu_quick_f16[1 << 16];
// precomputed silu table for f16 (128 KB)
static ggml_v3_fp16_t ggml_v3_table_silu_f16[1 << 16];
// precomputed exp table for f16 (128 KB)
static ggml_v3_fp16_t ggml_v3_table_exp_f16[1 << 16];
// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
float ggml_v3_table_f32_f16[1 << 16];
// note: do not use these inside ggml_v3.c
// these are meant to be used via the ggml_v3.h API
float ggml_v3_fp16_to_fp32(ggml_v3_fp16_t x) {
return (float) GGML_V3_FP16_TO_FP32(x);
}
ggml_v3_fp16_t ggml_v3_fp32_to_fp16(float x) {
return GGML_V3_FP32_TO_FP16(x);
}
void ggml_v3_fp16_to_fp32_row(const ggml_v3_fp16_t * x, float * y, int n) {
for (int i = 0; i < n; i++) {
y[i] = GGML_V3_FP16_TO_FP32(x[i]);
}
}
void ggml_v3_fp32_to_fp16_row(const float * x, ggml_v3_fp16_t * y, int n) {
int i = 0;
#if defined(__F16C__)
for (; i + 7 < n; i += 8) {
__m256 x_vec = _mm256_loadu_ps(x + i);
__m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
_mm_storeu_si128((__m128i *)(y + i), y_vec);
}
for(; i + 3 < n; i += 4) {
__m128 x_vec = _mm_loadu_ps(x + i);
__m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
_mm_storel_epi64((__m128i *)(y + i), y_vec);
}
#endif
for (; i < n; i++) {
y[i] = GGML_V3_FP32_TO_FP16(x[i]);
}
}
//
// timing
//
#if defined(_MSC_VER) || defined(__MINGW32__)
static int64_t timer_freq, timer_start;
void ggml_v3_time_init(void) {
LARGE_INTEGER t;
QueryPerformanceFrequency(&t);
timer_freq = t.QuadPart;
// The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq
// and the uptime are high enough.
// We subtract the program start time to reduce the likelihood of that happening.
QueryPerformanceCounter(&t);
timer_start = t.QuadPart;
}
int64_t ggml_v3_time_ms(void) {
LARGE_INTEGER t;
QueryPerformanceCounter(&t);
return ((t.QuadPart-timer_start) * 1000) / timer_freq;
}
int64_t ggml_v3_time_us(void) {
LARGE_INTEGER t;
QueryPerformanceCounter(&t);
return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
}
#else
void ggml_v3_time_init(void) {}
int64_t ggml_v3_time_ms(void) {
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000;
}
int64_t ggml_v3_time_us(void) {
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;
}
#endif
int64_t ggml_v3_cycles(void) {
return clock();
}
int64_t ggml_v3_cycles_per_ms(void) {
return CLOCKS_PER_SEC/1000;
}
#ifdef GGML_V3_PERF
#define ggml_v3_perf_time_ms() ggml_v3_time_ms()
#define ggml_v3_perf_time_us() ggml_v3_time_us()
#define ggml_v3_perf_cycles() ggml_v3_cycles()
#define ggml_v3_perf_cycles_per_ms() ggml_v3_cycles_per_ms()
#else
#define ggml_v3_perf_time_ms() 0
#define ggml_v3_perf_time_us() 0
#define ggml_v3_perf_cycles() 0
#define ggml_v3_perf_cycles_per_ms() 0
#endif
//
// cache line
//
#if defined(__cpp_lib_hardware_interference_size)
#define CACHE_LINE_SIZE hardware_destructive_interference_size
#else
#if defined(__POWER9_VECTOR__)
#define CACHE_LINE_SIZE 128
#else
#define CACHE_LINE_SIZE 64
#endif
#endif
static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
static void ggml_v3_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
static void ggml_v3_vec_dot_f16(const int n, float * restrict s, ggml_v3_fp16_t * restrict x, ggml_v3_fp16_t * restrict y);
ggml_v3_collect_imatrix_t g_imatrix_collect_v3 = NULL;
void ggml_v3_set_imatrix_collection(ggml_v3_collect_imatrix_t imatrix_collect) {
g_imatrix_collect_v3 = imatrix_collect;
}
static const ggml_v3_type_traits_t type_traits[GGML_V3_TYPE_COUNT] = {
[GGML_V3_TYPE_I8] = {
.type_name = "i8",
.blck_size = 1,
.type_size = sizeof(int8_t),
.is_quantized = false,
},
[GGML_V3_TYPE_I16] = {
.type_name = "i16",
.blck_size = 1,
.type_size = sizeof(int16_t),
.is_quantized = false,
},
[GGML_V3_TYPE_I32] = {
.type_name = "i32",
.blck_size = 1,
.type_size = sizeof(int32_t),
.is_quantized = false,
},
[GGML_V3_TYPE_F32] = {
.type_name = "f32",
.blck_size = 1,
.type_size = sizeof(float),
.is_quantized = false,
.vec_dot = (ggml_v3_vec_dot_t) ggml_v3_vec_dot_f32,
.vec_dot_type = GGML_V3_TYPE_F32,
},
[GGML_V3_TYPE_F16] = {
.type_name = "f16",
.blck_size = 1,
.type_size = sizeof(ggml_v3_fp16_t),
.is_quantized = false,
.to_float = (ggml_v3_to_float_t) ggml_v3_fp16_to_fp32_row,
.from_float = (ggml_v3_from_float_t) ggml_v3_fp32_to_fp16_row,
.from_float_reference = (ggml_v3_from_float_t) ggml_v3_fp32_to_fp16_row,
.vec_dot = (ggml_v3_vec_dot_t) ggml_v3_vec_dot_f16,
.vec_dot_type = GGML_V3_TYPE_F16,
},
[GGML_V3_TYPE_Q4_0] = {
.type_name = "q4_0",
.blck_size = QK4_0,
.type_size = sizeof(block_q4_0),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_q4_0,
.from_float = quantize_row_q4_0,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_q4_0_reference,
.vec_dot = ggml_v3_vec_dot_q4_0_q8_0,
.vec_dot_type = GGML_V3_TYPE_Q8_0,
},
[GGML_V3_TYPE_Q4_1] = {
.type_name = "q4_1",
.blck_size = QK4_1,
.type_size = sizeof(block_q4_1),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_q4_1,
.from_float = quantize_row_q4_1,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_q4_1_reference,
.vec_dot = ggml_v3_vec_dot_q4_1_q8_1,
.vec_dot_type = GGML_V3_TYPE_Q8_1,
},
[4] = { // GGML_V3_TYPE_Q4_2
.type_name = "DEPRECATED",
.blck_size = 0,
.type_size = 0,
.is_quantized = false,
.to_float = NULL,
.from_float = NULL,
.from_float_reference = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_V3_TYPE_COUNT,
},
[5] = { // GGML_V3_TYPE_Q4_3
.type_name = "DEPRECATED",
.blck_size = 0,
.type_size = 0,
.is_quantized = false,
.to_float = NULL,
.from_float = NULL,
.from_float_reference = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_V3_TYPE_COUNT,
},
[GGML_V3_TYPE_Q5_0] = {
.type_name = "q5_0",
.blck_size = QK5_0,
.type_size = sizeof(block_q5_0),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_q5_0,
.from_float = quantize_row_q5_0,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_q5_0_reference,
.vec_dot = ggml_v3_vec_dot_q5_0_q8_0,
.vec_dot_type = GGML_V3_TYPE_Q8_0,
},
[GGML_V3_TYPE_Q5_1] = {
.type_name = "q5_1",
.blck_size = QK5_1,
.type_size = sizeof(block_q5_1),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_q5_1,
.from_float = quantize_row_q5_1,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_q5_1_reference,
.vec_dot = ggml_v3_vec_dot_q5_1_q8_1,
.vec_dot_type = GGML_V3_TYPE_Q8_1,
},
[GGML_V3_TYPE_Q8_0] = {
.type_name = "q8_0",
.blck_size = QK8_0,
.type_size = sizeof(block_q8_0),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_q8_0,
.from_float = quantize_row_q8_0,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_q8_0_reference,
.vec_dot = ggml_v3_vec_dot_q8_0_q8_0,
.vec_dot_type = GGML_V3_TYPE_Q8_0,
},
[GGML_V3_TYPE_Q8_1] = {
.type_name = "q8_1",
.blck_size = QK8_1,
.type_size = sizeof(block_q8_1),
.is_quantized = true,
.from_float = quantize_row_q8_1,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_q8_1_reference,
.vec_dot_type = GGML_V3_TYPE_Q8_1,
},
[GGML_V3_TYPE_Q2_K] = {
.type_name = "q2_K",
.blck_size = QK_K,
.type_size = sizeof(block_q2_K),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_q2_K,
.from_float = quantize_row_q2_K,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_q2_K_reference,
.vec_dot = ggml_v3_vec_dot_q2_K_q8_K,
.vec_dot_type = GGML_V3_TYPE_Q8_K,
},
[GGML_V3_TYPE_Q3_K] = {
.type_name = "q3_K",
.blck_size = QK_K,
.type_size = sizeof(block_q3_K),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_q3_K,
.from_float = quantize_row_q3_K,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_q3_K_reference,
.vec_dot = ggml_v3_vec_dot_q3_K_q8_K,
.vec_dot_type = GGML_V3_TYPE_Q8_K,
},
[GGML_V3_TYPE_Q4_K] = {
.type_name = "q4_K",
.blck_size = QK_K,
.type_size = sizeof(block_q4_K),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_q4_K,
.from_float = quantize_row_q4_K,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_q4_K_reference,
.vec_dot = ggml_v3_vec_dot_q4_K_q8_K,
.vec_dot_type = GGML_V3_TYPE_Q8_K,
},
[GGML_V3_TYPE_Q5_K] = {
.type_name = "q5_K",
.blck_size = QK_K,
.type_size = sizeof(block_q5_K),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_q5_K,
.from_float = quantize_row_q5_K,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_q5_K_reference,
.vec_dot = ggml_v3_vec_dot_q5_K_q8_K,
.vec_dot_type = GGML_V3_TYPE_Q8_K,
},
[GGML_V3_TYPE_Q6_K] = {
.type_name = "q6_K",
.blck_size = QK_K,
.type_size = sizeof(block_q6_K),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_q6_K,
.from_float = quantize_row_q6_K,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_q6_K_reference,
.vec_dot = ggml_v3_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_V3_TYPE_Q8_K,
},
[GGML_V3_TYPE_IQ2_XXS] = {
.type_name = "iq2_xxs",
.blck_size = QK_K,
.type_size = sizeof(block_iq2_xxs),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_iq2_xxs,
.from_float = quantize_row_iq2_xxs,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_iq2_xxs_reference,
.vec_dot = ggml_v3_vec_dot_iq2_xxs_q8_K,
.vec_dot_type = GGML_V3_TYPE_Q8_K,
},
[GGML_V3_TYPE_IQ2_XS] = {
.type_name = "iq2_xs",
.blck_size = QK_K,
.type_size = sizeof(block_iq2_xs),
.is_quantized = true,
.to_float = (ggml_v3_to_float_t) dequantize_row_iq2_xs,
.from_float = quantize_row_iq2_xs,
.from_float_reference = (ggml_v3_from_float_t) quantize_row_iq2_xs_reference,
.vec_dot = ggml_v3_vec_dot_iq2_xs_q8_K,
.vec_dot_type = GGML_V3_TYPE_Q8_K,
},
[GGML_V3_TYPE_Q8_K] = {
.type_name = "q8_K",
.blck_size = QK_K,
.type_size = sizeof(block_q8_K),
.is_quantized = true,
.from_float = quantize_row_q8_K,
}
};
// For internal test use
ggml_v3_type_traits_t ggml_v3_internal_get_type_traits(enum ggml_v3_type type) {
GGML_V3_ASSERT(type < GGML_V3_TYPE_COUNT);
return type_traits[type];
}
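// Illustrative round-trip through the traits table (the helper name and flow are ours, not part
// of the original file): quantize one q4_0 block via from_float, then dequantize it via to_float.
static inline void ggml_v3_example_q4_0_roundtrip(const float src[QK4_0], float dst[QK4_0]) {
    const ggml_v3_type_traits_t tt = ggml_v3_internal_get_type_traits(GGML_V3_TYPE_Q4_0);
    block_q4_0 tmp;
    tt.from_float(src, &tmp, QK4_0);  // quantize_row_q4_0
    tt.to_float(&tmp, dst, QK4_0);    // dequantize_row_q4_0
}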
//
// simd mappings
//
#if defined(__ARM_NEON)
#if !defined(__aarch64__)
// 64-bit compatibility
inline static float vaddvq_f32(float32x4_t v) {
return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
}
#endif
#endif
// we define a common set of C macros which map to specific intrinsics based on the current architecture
// we then implement the fundamental computation operations below using only these macros
// adding support for new architectures requires to define the corresponding SIMD macros
//
// GGML_V3_F32_STEP / GGML_V3_F16_STEP
// number of elements to process in a single step
//
// GGML_V3_F32_EPR / GGML_V3_F16_EPR
// number of elements to fit in a single register
//
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
#define GGML_V3_SIMD
// F32 NEON
#define GGML_V3_F32_STEP 16
#define GGML_V3_F32_EPR 4
#define GGML_V3_F32x4 float32x4_t
#define GGML_V3_F32x4_ZERO vdupq_n_f32(0.0f)
#define GGML_V3_F32x4_SET1(x) vdupq_n_f32(x)
#define GGML_V3_F32x4_LOAD vld1q_f32
#define GGML_V3_F32x4_STORE vst1q_f32
#define GGML_V3_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
#define GGML_V3_F32x4_ADD vaddq_f32
#define GGML_V3_F32x4_MUL vmulq_f32
#define GGML_V3_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
#define GGML_V3_F32x4_REDUCE(res, x) \
{ \
int offset = GGML_V3_F32_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vaddq_f32(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vaddq_f32(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vaddq_f32(x[i], x[offset+i]); \
} \
res = GGML_V3_F32x4_REDUCE_ONE(x[0]); \
}
#define GGML_V3_F32_VEC GGML_V3_F32x4
#define GGML_V3_F32_VEC_ZERO GGML_V3_F32x4_ZERO
#define GGML_V3_F32_VEC_SET1 GGML_V3_F32x4_SET1
#define GGML_V3_F32_VEC_LOAD GGML_V3_F32x4_LOAD
#define GGML_V3_F32_VEC_STORE GGML_V3_F32x4_STORE
#define GGML_V3_F32_VEC_FMA GGML_V3_F32x4_FMA
#define GGML_V3_F32_VEC_ADD GGML_V3_F32x4_ADD
#define GGML_V3_F32_VEC_MUL GGML_V3_F32x4_MUL
#define GGML_V3_F32_VEC_REDUCE GGML_V3_F32x4_REDUCE
// F16 NEON
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
#define GGML_V3_F16_STEP 32
#define GGML_V3_F16_EPR 8
#define GGML_V3_F16x8 float16x8_t
#define GGML_V3_F16x8_ZERO vdupq_n_f16(0.0f)
#define GGML_V3_F16x8_SET1(x) vdupq_n_f16(x)
#define GGML_V3_F16x8_LOAD vld1q_f16
#define GGML_V3_F16x8_STORE vst1q_f16
#define GGML_V3_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
#define GGML_V3_F16x8_ADD vaddq_f16
#define GGML_V3_F16x8_MUL vmulq_f16
#define GGML_V3_F16x8_REDUCE(res, x) \
do { \
int offset = GGML_V3_F16_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vaddq_f16(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vaddq_f16(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vaddq_f16(x[i], x[offset+i]); \
} \
const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
res = (ggml_v3_float) vaddvq_f32(vaddq_f32(t0, t1)); \
} while (0)
#define GGML_V3_F16_VEC GGML_V3_F16x8
#define GGML_V3_F16_VEC_ZERO GGML_V3_F16x8_ZERO
#define GGML_V3_F16_VEC_SET1 GGML_V3_F16x8_SET1
#define GGML_V3_F16_VEC_LOAD(p, i) GGML_V3_F16x8_LOAD(p)
#define GGML_V3_F16_VEC_STORE(p, r, i) GGML_V3_F16x8_STORE(p, r[i])
#define GGML_V3_F16_VEC_FMA GGML_V3_F16x8_FMA
#define GGML_V3_F16_VEC_ADD GGML_V3_F16x8_ADD
#define GGML_V3_F16_VEC_MUL GGML_V3_F16x8_MUL
#define GGML_V3_F16_VEC_REDUCE GGML_V3_F16x8_REDUCE
#else
// if FP16 vector arithmetic is not supported, we use FP32 instead
// and take advantage of the vcvt_ functions to convert to/from FP16
#define GGML_V3_F16_STEP 16
#define GGML_V3_F16_EPR 4
#define GGML_V3_F32Cx4 float32x4_t
#define GGML_V3_F32Cx4_ZERO vdupq_n_f32(0.0f)
#define GGML_V3_F32Cx4_SET1(x) vdupq_n_f32(x)
#define GGML_V3_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x))
#define GGML_V3_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
#define GGML_V3_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
#define GGML_V3_F32Cx4_ADD vaddq_f32
#define GGML_V3_F32Cx4_MUL vmulq_f32
#define GGML_V3_F32Cx4_REDUCE GGML_V3_F32x4_REDUCE
#define GGML_V3_F16_VEC GGML_V3_F32Cx4
#define GGML_V3_F16_VEC_ZERO GGML_V3_F32Cx4_ZERO
#define GGML_V3_F16_VEC_SET1 GGML_V3_F32Cx4_SET1
#define GGML_V3_F16_VEC_LOAD(p, i) GGML_V3_F32Cx4_LOAD(p)
#define GGML_V3_F16_VEC_STORE(p, r, i) GGML_V3_F32Cx4_STORE(p, r[i])
#define GGML_V3_F16_VEC_FMA GGML_V3_F32Cx4_FMA
#define GGML_V3_F16_VEC_ADD GGML_V3_F32Cx4_ADD
#define GGML_V3_F16_VEC_MUL GGML_V3_F32Cx4_MUL
#define GGML_V3_F16_VEC_REDUCE GGML_V3_F32Cx4_REDUCE
#endif
#elif defined(__AVX__)
#define GGML_V3_SIMD
// F32 AVX
#define GGML_V3_F32_STEP 32
#define GGML_V3_F32_EPR 8
#define GGML_V3_F32x8 __m256
#define GGML_V3_F32x8_ZERO _mm256_setzero_ps()
#define GGML_V3_F32x8_SET1(x) _mm256_set1_ps(x)
#define GGML_V3_F32x8_LOAD _mm256_loadu_ps
#define GGML_V3_F32x8_STORE _mm256_storeu_ps
#if defined(__FMA__)
#define GGML_V3_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
#else
#define GGML_V3_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
#endif
#define GGML_V3_F32x8_ADD _mm256_add_ps
#define GGML_V3_F32x8_MUL _mm256_mul_ps
#define GGML_V3_F32x8_REDUCE(res, x) \
do { \
int offset = GGML_V3_F32_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm256_add_ps(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm256_add_ps(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm256_add_ps(x[i], x[offset+i]); \
} \
const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
_mm256_extractf128_ps(x[0], 1)); \
const __m128 t1 = _mm_hadd_ps(t0, t0); \
res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
} while (0)
// TODO: is this optimal?
#define GGML_V3_F32_VEC GGML_V3_F32x8
#define GGML_V3_F32_VEC_ZERO GGML_V3_F32x8_ZERO
#define GGML_V3_F32_VEC_SET1 GGML_V3_F32x8_SET1
#define GGML_V3_F32_VEC_LOAD GGML_V3_F32x8_LOAD
#define GGML_V3_F32_VEC_STORE GGML_V3_F32x8_STORE
#define GGML_V3_F32_VEC_FMA GGML_V3_F32x8_FMA
#define GGML_V3_F32_VEC_ADD GGML_V3_F32x8_ADD
#define GGML_V3_F32_VEC_MUL GGML_V3_F32x8_MUL
#define GGML_V3_F32_VEC_REDUCE GGML_V3_F32x8_REDUCE
// F16 AVX
#define GGML_V3_F16_STEP 32
#define GGML_V3_F16_EPR 8
// F16 arithmetic is not supported by AVX, so we use F32 instead
#define GGML_V3_F32Cx8 __m256
#define GGML_V3_F32Cx8_ZERO _mm256_setzero_ps()
#define GGML_V3_F32Cx8_SET1(x) _mm256_set1_ps(x)
#if defined(__F16C__)
// the _mm256_cvt intrinsics require F16C
#define GGML_V3_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
#define GGML_V3_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
#else
static inline __m256 __avx_f32cx8_load(ggml_v3_fp16_t *x) {
float tmp[8];
for (int i = 0; i < 8; i++) {
tmp[i] = GGML_V3_FP16_TO_FP32(x[i]);
}
return _mm256_loadu_ps(tmp);
}
static inline void __avx_f32cx8_store(ggml_v3_fp16_t *x, __m256 y) {
float arr[8];
_mm256_storeu_ps(arr, y);
for (int i = 0; i < 8; i++)
x[i] = GGML_V3_FP32_TO_FP16(arr[i]);
}
#define GGML_V3_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
#define GGML_V3_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
#endif
#define GGML_V3_F32Cx8_FMA GGML_V3_F32x8_FMA
#define GGML_V3_F32Cx8_ADD _mm256_add_ps
#define GGML_V3_F32Cx8_MUL _mm256_mul_ps
#define GGML_V3_F32Cx8_REDUCE GGML_V3_F32x8_REDUCE
#define GGML_V3_F16_VEC GGML_V3_F32Cx8
#define GGML_V3_F16_VEC_ZERO GGML_V3_F32Cx8_ZERO
#define GGML_V3_F16_VEC_SET1 GGML_V3_F32Cx8_SET1
#define GGML_V3_F16_VEC_LOAD(p, i) GGML_V3_F32Cx8_LOAD(p)
#define GGML_V3_F16_VEC_STORE(p, r, i) GGML_V3_F32Cx8_STORE(p, r[i])
#define GGML_V3_F16_VEC_FMA GGML_V3_F32Cx8_FMA
#define GGML_V3_F16_VEC_ADD GGML_V3_F32Cx8_ADD
#define GGML_V3_F16_VEC_MUL GGML_V3_F32Cx8_MUL
#define GGML_V3_F16_VEC_REDUCE GGML_V3_F32Cx8_REDUCE
#elif defined(__POWER9_VECTOR__)
#define GGML_V3_SIMD
// F32 POWER9
#define GGML_V3_F32_STEP 32
#define GGML_V3_F32_EPR 4
#define GGML_V3_F32x4 vector float
#define GGML_V3_F32x4_ZERO 0.0f
#define GGML_V3_F32x4_SET1 vec_splats
#define GGML_V3_F32x4_LOAD(p) vec_xl(0, p)
#define GGML_V3_F32x4_STORE(p, r) vec_xst(r, 0, p)
#define GGML_V3_F32x4_FMA(a, b, c) vec_madd(b, c, a)
#define GGML_V3_F32x4_ADD vec_add
#define GGML_V3_F32x4_MUL vec_mul
#define GGML_V3_F32x4_REDUCE(res, x) \
{ \
int offset = GGML_V3_F32_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vec_add(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vec_add(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vec_add(x[i], x[offset+i]); \
} \
res = vec_extract(x[0], 0) + \
vec_extract(x[0], 1) + \
vec_extract(x[0], 2) + \
vec_extract(x[0], 3); \
}
#define GGML_V3_F32_VEC GGML_V3_F32x4
#define GGML_V3_F32_VEC_ZERO GGML_V3_F32x4_ZERO
#define GGML_V3_F32_VEC_SET1 GGML_V3_F32x4_SET1
#define GGML_V3_F32_VEC_LOAD GGML_V3_F32x4_LOAD
#define GGML_V3_F32_VEC_STORE GGML_V3_F32x4_STORE
#define GGML_V3_F32_VEC_FMA GGML_V3_F32x4_FMA
#define GGML_V3_F32_VEC_ADD GGML_V3_F32x4_ADD
#define GGML_V3_F32_VEC_MUL GGML_V3_F32x4_MUL
#define GGML_V3_F32_VEC_REDUCE GGML_V3_F32x4_REDUCE
// F16 POWER9
#define GGML_V3_F16_STEP GGML_V3_F32_STEP
#define GGML_V3_F16_EPR GGML_V3_F32_EPR
#define GGML_V3_F16_VEC GGML_V3_F32x4
#define GGML_V3_F16_VEC_ZERO GGML_V3_F32x4_ZERO
#define GGML_V3_F16_VEC_SET1 GGML_V3_F32x4_SET1
#define GGML_V3_F16_VEC_FMA GGML_V3_F32x4_FMA
#define GGML_V3_F16_VEC_REDUCE GGML_V3_F32x4_REDUCE
// Use vec_xl, not vec_ld, in case the load address is not aligned.
#define GGML_V3_F16_VEC_LOAD(p, i) (i & 0x1) ? \
vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_V3_F16_EPR)) : \
vec_extract_fp32_from_shortl(vec_xl(0, p))
#define GGML_V3_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
#define GGML_V3_F16_VEC_STORE(p, r, i) \
if (i & 0x1) \
vec_xst(vec_pack_to_short_fp32(r[i - GGML_V3_ENDIAN_BYTE(1)], \
r[i - GGML_V3_ENDIAN_BYTE(0)]), \
0, p - GGML_V3_F16_EPR)
#elif defined(__wasm_simd128__)
#define GGML_V3_SIMD
// F32 WASM
#define GGML_V3_F32_STEP 16
#define GGML_V3_F32_EPR 4
#define GGML_V3_F32x4 v128_t
#define GGML_V3_F32x4_ZERO wasm_f32x4_splat(0.0f)
#define GGML_V3_F32x4_SET1(x) wasm_f32x4_splat(x)
#define GGML_V3_F32x4_LOAD wasm_v128_load
#define GGML_V3_F32x4_STORE wasm_v128_store
#define GGML_V3_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
#define GGML_V3_F32x4_ADD wasm_f32x4_add
#define GGML_V3_F32x4_MUL wasm_f32x4_mul
#define GGML_V3_F32x4_REDUCE(res, x) \
{ \
int offset = GGML_V3_F32_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
} \
res = wasm_f32x4_extract_lane(x[0], 0) + \
wasm_f32x4_extract_lane(x[0], 1) + \
wasm_f32x4_extract_lane(x[0], 2) + \
wasm_f32x4_extract_lane(x[0], 3); \
}
#define GGML_V3_F32_VEC GGML_V3_F32x4
#define GGML_V3_F32_VEC_ZERO GGML_V3_F32x4_ZERO
#define GGML_V3_F32_VEC_SET1 GGML_V3_F32x4_SET1
#define GGML_V3_F32_VEC_LOAD GGML_V3_F32x4_LOAD
#define GGML_V3_F32_VEC_STORE GGML_V3_F32x4_STORE
#define GGML_V3_F32_VEC_FMA GGML_V3_F32x4_FMA
#define GGML_V3_F32_VEC_ADD GGML_V3_F32x4_ADD
#define GGML_V3_F32_VEC_MUL GGML_V3_F32x4_MUL
#define GGML_V3_F32_VEC_REDUCE GGML_V3_F32x4_REDUCE
// F16 WASM
#define GGML_V3_F16_STEP 16
#define GGML_V3_F16_EPR 4
inline static v128_t __wasm_f16x4_load(const ggml_v3_fp16_t * p) {
float tmp[4];
tmp[0] = GGML_V3_FP16_TO_FP32(p[0]);
tmp[1] = GGML_V3_FP16_TO_FP32(p[1]);
tmp[2] = GGML_V3_FP16_TO_FP32(p[2]);
tmp[3] = GGML_V3_FP16_TO_FP32(p[3]);
return wasm_v128_load(tmp);
}
inline static void __wasm_f16x4_store(ggml_v3_fp16_t * p, v128_t x) {
float tmp[4];
wasm_v128_store(tmp, x);
p[0] = GGML_V3_FP32_TO_FP16(tmp[0]);
p[1] = GGML_V3_FP32_TO_FP16(tmp[1]);
p[2] = GGML_V3_FP32_TO_FP16(tmp[2]);
p[3] = GGML_V3_FP32_TO_FP16(tmp[3]);
}
#define GGML_V3_F16x4 v128_t
#define GGML_V3_F16x4_ZERO wasm_f32x4_splat(0.0f)
#define GGML_V3_F16x4_SET1(x) wasm_f32x4_splat(x)
#define GGML_V3_F16x4_LOAD(x) __wasm_f16x4_load(x)
#define GGML_V3_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
#define GGML_V3_F16x4_FMA GGML_V3_F32x4_FMA
#define GGML_V3_F16x4_ADD wasm_f32x4_add
#define GGML_V3_F16x4_MUL wasm_f32x4_mul
#define GGML_V3_F16x4_REDUCE(res, x) \
{ \
int offset = GGML_V3_F16_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
} \
res = wasm_f32x4_extract_lane(x[0], 0) + \
wasm_f32x4_extract_lane(x[0], 1) + \
wasm_f32x4_extract_lane(x[0], 2) + \
wasm_f32x4_extract_lane(x[0], 3); \
}
#define GGML_V3_F16_VEC GGML_V3_F16x4
#define GGML_V3_F16_VEC_ZERO GGML_V3_F16x4_ZERO
#define GGML_V3_F16_VEC_SET1 GGML_V3_F16x4_SET1
#define GGML_V3_F16_VEC_LOAD(p, i) GGML_V3_F16x4_LOAD(p)
#define GGML_V3_F16_VEC_STORE(p, r, i) GGML_V3_F16x4_STORE(p, r[i])
#define GGML_V3_F16_VEC_FMA GGML_V3_F16x4_FMA
#define GGML_V3_F16_VEC_ADD GGML_V3_F16x4_ADD
#define GGML_V3_F16_VEC_MUL GGML_V3_F16x4_MUL
#define GGML_V3_F16_VEC_REDUCE GGML_V3_F16x4_REDUCE
#elif defined(__SSE3__)
#define GGML_V3_SIMD
// F32 SSE
#define GGML_V3_F32_STEP 32
#define GGML_V3_F32_EPR 4
#define GGML_V3_F32x4 __m128
#define GGML_V3_F32x4_ZERO _mm_setzero_ps()
#define GGML_V3_F32x4_SET1(x) _mm_set1_ps(x)
#define GGML_V3_F32x4_LOAD _mm_loadu_ps
#define GGML_V3_F32x4_STORE _mm_storeu_ps
#if defined(__FMA__)
// TODO: Does this work?
#define GGML_V3_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
#else
#define GGML_V3_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
#endif
#define GGML_V3_F32x4_ADD _mm_add_ps
#define GGML_V3_F32x4_MUL _mm_mul_ps
#define GGML_V3_F32x4_REDUCE(res, x) \
{ \
int offset = GGML_V3_F32_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm_add_ps(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm_add_ps(x[i], x[offset+i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = _mm_add_ps(x[i], x[offset+i]); \
} \
const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
}
// TODO: is this optimal?
#define GGML_V3_F32_VEC GGML_V3_F32x4
#define GGML_V3_F32_VEC_ZERO GGML_V3_F32x4_ZERO
#define GGML_V3_F32_VEC_SET1 GGML_V3_F32x4_SET1
#define GGML_V3_F32_VEC_LOAD GGML_V3_F32x4_LOAD
#define GGML_V3_F32_VEC_STORE GGML_V3_F32x4_STORE
#define GGML_V3_F32_VEC_FMA GGML_V3_F32x4_FMA
#define GGML_V3_F32_VEC_ADD GGML_V3_F32x4_ADD
#define GGML_V3_F32_VEC_MUL GGML_V3_F32x4_MUL
#define GGML_V3_F32_VEC_REDUCE GGML_V3_F32x4_REDUCE
// F16 SSE
#define GGML_V3_F16_STEP 32
#define GGML_V3_F16_EPR 4
static inline __m128 __sse_f16x4_load(ggml_v3_fp16_t *x) {
float tmp[4];
tmp[0] = GGML_V3_FP16_TO_FP32(x[0]);
tmp[1] = GGML_V3_FP16_TO_FP32(x[1]);
tmp[2] = GGML_V3_FP16_TO_FP32(x[2]);
tmp[3] = GGML_V3_FP16_TO_FP32(x[3]);
return _mm_loadu_ps(tmp);
}
static inline void __sse_f16x4_store(ggml_v3_fp16_t *x, __m128 y) {
float arr[4];
_mm_storeu_ps(arr, y);
x[0] = GGML_V3_FP32_TO_FP16(arr[0]);
x[1] = GGML_V3_FP32_TO_FP16(arr[1]);
x[2] = GGML_V3_FP32_TO_FP16(arr[2]);
x[3] = GGML_V3_FP32_TO_FP16(arr[3]);
}
#define GGML_V3_F32Cx4 __m128
#define GGML_V3_F32Cx4_ZERO _mm_setzero_ps()
#define GGML_V3_F32Cx4_SET1(x) _mm_set1_ps(x)
#define GGML_V3_F32Cx4_LOAD(x) __sse_f16x4_load(x)
#define GGML_V3_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
#define GGML_V3_F32Cx4_FMA GGML_V3_F32x4_FMA
#define GGML_V3_F32Cx4_ADD _mm_add_ps
#define GGML_V3_F32Cx4_MUL _mm_mul_ps
#define GGML_V3_F32Cx4_REDUCE GGML_V3_F32x4_REDUCE
#define GGML_V3_F16_VEC GGML_V3_F32Cx4
#define GGML_V3_F16_VEC_ZERO GGML_V3_F32Cx4_ZERO
#define GGML_V3_F16_VEC_SET1 GGML_V3_F32Cx4_SET1
#define GGML_V3_F16_VEC_LOAD(p, i) GGML_V3_F32Cx4_LOAD(p)
#define GGML_V3_F16_VEC_STORE(p, r, i) GGML_V3_F32Cx4_STORE(p, r[i])
#define GGML_V3_F16_VEC_FMA GGML_V3_F32Cx4_FMA
#define GGML_V3_F16_VEC_ADD GGML_V3_F32Cx4_ADD
#define GGML_V3_F16_VEC_MUL GGML_V3_F32Cx4_MUL
#define GGML_V3_F16_VEC_REDUCE GGML_V3_F32Cx4_REDUCE
#endif
// GGML_V3_F32_ARR / GGML_V3_F16_ARR
// number of registers to use per step
#ifdef GGML_V3_SIMD
#define GGML_V3_F32_ARR (GGML_V3_F32_STEP/GGML_V3_F32_EPR)
#define GGML_V3_F16_ARR (GGML_V3_F16_STEP/GGML_V3_F16_EPR)
#endif
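// worked example: with the AVX block above GGML_V3_F32_STEP = 32 and GGML_V3_F32_EPR = 8, so
// GGML_V3_F32_ARR = 32/8 = 4 accumulators are kept in flight per step; with SSE it is 32/4 = 8.
// the REDUCE macros are written for up to 8 registers (three halving passes); with fewer
// registers the later passes simply iterate zero times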
//
// fundamental operations
//
inline static void ggml_v3_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_v3_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_v3_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_v3_vec_set_f16(const int n, ggml_v3_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_v3_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
inline static void ggml_v3_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
inline static void ggml_v3_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
inline static void ggml_v3_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
inline static void ggml_v3_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
inline static void ggml_v3_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_v3_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
inline static void ggml_v3_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
inline static void ggml_v3_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
inline static void ggml_v3_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
static void ggml_v3_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
#ifdef GGML_V3_SIMD
float sumf = 0.0f;
const int np = (n & ~(GGML_V3_F32_STEP - 1));
GGML_V3_F32_VEC sum[GGML_V3_F32_ARR] = { GGML_V3_F32_VEC_ZERO };
GGML_V3_F32_VEC ax[GGML_V3_F32_ARR];
GGML_V3_F32_VEC ay[GGML_V3_F32_ARR];
for (int i = 0; i < np; i += GGML_V3_F32_STEP) {
for (int j = 0; j < GGML_V3_F32_ARR; j++) {
ax[j] = GGML_V3_F32_VEC_LOAD(x + i + j*GGML_V3_F32_EPR);
ay[j] = GGML_V3_F32_VEC_LOAD(y + i + j*GGML_V3_F32_EPR);
sum[j] = GGML_V3_F32_VEC_FMA(sum[j], ax[j], ay[j]);
}
}
// reduce sum0..sum3 to sum0
GGML_V3_F32_VEC_REDUCE(sumf, sum);
// leftovers
for (int i = np; i < n; ++i) {
sumf += x[i]*y[i];
}
#else
// scalar
ggml_v3_float sumf = 0.0;
for (int i = 0; i < n; ++i) {
sumf += (ggml_v3_float)(x[i]*y[i]);
}
#endif
*s = sumf;
}
static void ggml_v3_vec_dot_f16(const int n, float * restrict s, ggml_v3_fp16_t * restrict x, ggml_v3_fp16_t * restrict y) {
ggml_v3_float sumf = 0.0;
#if defined(GGML_V3_SIMD)
const int np = (n & ~(GGML_V3_F16_STEP - 1));
GGML_V3_F16_VEC sum[GGML_V3_F16_ARR] = { GGML_V3_F16_VEC_ZERO };
GGML_V3_F16_VEC ax[GGML_V3_F16_ARR];
GGML_V3_F16_VEC ay[GGML_V3_F16_ARR];
for (int i = 0; i < np; i += GGML_V3_F16_STEP) {
for (int j = 0; j < GGML_V3_F16_ARR; j++) {
ax[j] = GGML_V3_F16_VEC_LOAD(x + i + j*GGML_V3_F16_EPR, j);
ay[j] = GGML_V3_F16_VEC_LOAD(y + i + j*GGML_V3_F16_EPR, j);
sum[j] = GGML_V3_F16_VEC_FMA(sum[j], ax[j], ay[j]);
}
}
// reduce sum0..sum3 to sum0
GGML_V3_F16_VEC_REDUCE(sumf, sum);
// leftovers
for (int i = np; i < n; ++i) {
sumf += (ggml_v3_float)(GGML_V3_FP16_TO_FP32(x[i])*GGML_V3_FP16_TO_FP32(y[i]));
}
#else
for (int i = 0; i < n; ++i) {
sumf += (ggml_v3_float)(GGML_V3_FP16_TO_FP32(x[i])*GGML_V3_FP16_TO_FP32(y[i]));
}
#endif
*s = sumf;
}
// compute GGML_V3_VEC_DOT_UNROLL dot products at once
// xs - x row stride in bytes
inline static void ggml_v3_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_v3_fp16_t * restrict y) {
ggml_v3_float sumf[GGML_V3_VEC_DOT_UNROLL] = { 0.0 };
ggml_v3_fp16_t * restrict x[GGML_V3_VEC_DOT_UNROLL];
for (int i = 0; i < GGML_V3_VEC_DOT_UNROLL; ++i) {
x[i] = (ggml_v3_fp16_t *) ((char *) xv + i*xs);
}
#if defined(GGML_V3_SIMD)
const int np = (n & ~(GGML_V3_F16_STEP - 1));
GGML_V3_F16_VEC sum[GGML_V3_VEC_DOT_UNROLL][GGML_V3_F16_ARR] = { { GGML_V3_F16_VEC_ZERO } };
GGML_V3_F16_VEC ax[GGML_V3_F16_ARR];
GGML_V3_F16_VEC ay[GGML_V3_F16_ARR];
for (int i = 0; i < np; i += GGML_V3_F16_STEP) {
for (int j = 0; j < GGML_V3_F16_ARR; j++) {
ay[j] = GGML_V3_F16_VEC_LOAD(y + i + j*GGML_V3_F16_EPR, j);
for (int k = 0; k < GGML_V3_VEC_DOT_UNROLL; ++k) {
ax[j] = GGML_V3_F16_VEC_LOAD(x[k] + i + j*GGML_V3_F16_EPR, j);
sum[k][j] = GGML_V3_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
}
}
}
// reduce sum0..sum3 to sum0
for (int k = 0; k < GGML_V3_VEC_DOT_UNROLL; ++k) {
GGML_V3_F16_VEC_REDUCE(sumf[k], sum[k]);
}
// leftovers
for (int i = np; i < n; ++i) {
for (int j = 0; j < GGML_V3_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_v3_float)(GGML_V3_FP16_TO_FP32(x[j][i])*GGML_V3_FP16_TO_FP32(y[i]));
}
}
#else
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_V3_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_v3_float)(GGML_V3_FP16_TO_FP32(x[j][i])*GGML_V3_FP16_TO_FP32(y[i]));
}
}
#endif
for (int i = 0; i < GGML_V3_VEC_DOT_UNROLL; ++i) {
s[i] = sumf[i];
}
}
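// usage sketch (illustrative, hypothetical variable names): for GGML_V3_VEC_DOT_UNROLL f16 rows
// of length n stored back to back, the byte stride xs is simply the row size in bytes:
//
//   float s[GGML_V3_VEC_DOT_UNROLL];
//   ggml_v3_vec_dot_f16_unroll(n, n*sizeof(ggml_v3_fp16_t), s, rows, y);
//
// afterwards s[k] holds the dot product of row k with y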
inline static void ggml_v3_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
#if defined(GGML_V3_SIMD)
const int np = (n & ~(GGML_V3_F32_STEP - 1));
GGML_V3_F32_VEC vx = GGML_V3_F32_VEC_SET1(v);
GGML_V3_F32_VEC ax[GGML_V3_F32_ARR];
GGML_V3_F32_VEC ay[GGML_V3_F32_ARR];
for (int i = 0; i < np; i += GGML_V3_F32_STEP) {
for (int j = 0; j < GGML_V3_F32_ARR; j++) {
ax[j] = GGML_V3_F32_VEC_LOAD(x + i + j*GGML_V3_F32_EPR);
ay[j] = GGML_V3_F32_VEC_LOAD(y + i + j*GGML_V3_F32_EPR);
ay[j] = GGML_V3_F32_VEC_FMA(ay[j], ax[j], vx);
GGML_V3_F32_VEC_STORE(y + i + j*GGML_V3_F32_EPR, ay[j]);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] += x[i]*v;
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] += x[i]*v;
}
#endif
}
// xs and vs are byte strides of x and v
inline static void ggml_v3_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
const float * restrict x[GGML_V3_VEC_MAD_UNROLL];
const float * restrict v[GGML_V3_VEC_MAD_UNROLL];
for (int i = 0; i < GGML_V3_VEC_MAD_UNROLL; ++i) {
x[i] = (const float *) ((const char *) xv + i*xs);
v[i] = (const float *) ((const char *) vv + i*vs);
}
#if defined(GGML_V3_SIMD)
const int np = (n & ~(GGML_V3_F32_STEP - 1));
GGML_V3_F32_VEC vx[GGML_V3_VEC_MAD_UNROLL];
for (int k = 0; k < GGML_V3_VEC_MAD_UNROLL; ++k) {
vx[k] = GGML_V3_F32_VEC_SET1(v[k][0]);
}
GGML_V3_F32_VEC ax[GGML_V3_VEC_MAD_UNROLL][GGML_V3_F32_ARR];
GGML_V3_F32_VEC ay[GGML_V3_F32_ARR];
for (int i = 0; i < np; i += GGML_V3_F32_STEP) {
for (int j = 0; j < GGML_V3_F32_ARR; j++) {
ay[j] = GGML_V3_F32_VEC_LOAD(y + i + j*GGML_V3_F32_EPR);
for (int k = 0; k < GGML_V3_VEC_MAD_UNROLL; ++k) {
ax[k][j] = GGML_V3_F32_VEC_LOAD(x[k] + i + j*GGML_V3_F32_EPR);
ay[j] = GGML_V3_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
}
GGML_V3_F32_VEC_STORE(y + i + j*GGML_V3_F32_EPR, ay[j]);
}
}
// leftovers
for (int k = 0; k < GGML_V3_VEC_MAD_UNROLL; ++k) {
for (int i = np; i < n; ++i) {
y[i] += x[k][i]*v[k][0];
}
}
#else
// scalar
for (int k = 0; k < GGML_V3_VEC_MAD_UNROLL; ++k) {
for (int i = 0; i < n; ++i) {
y[i] += x[k][i]*v[k][0];
}
}
#endif
}
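// usage sketch (illustrative, hypothetical variable names): with GGML_V3_VEC_MAD_UNROLL rows of
// x stored back to back and their scale factors stored contiguously as floats, one call performs
// y += x[k]*v[k][0] for every k:
//
//   ggml_v3_vec_mad_f32_unroll(n, n*sizeof(float), sizeof(float), y, x_rows, v_scales);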
//inline static void ggml_v3_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
inline static void ggml_v3_vec_scale_f32(const int n, float * y, const float v) {
#if defined(GGML_USE_ACCELERATE)
vDSP_vsmul(y, 1, &v, y, 1, n);
#elif defined(GGML_V3_SIMD)
const int np = (n & ~(GGML_V3_F32_STEP - 1));
GGML_V3_F32_VEC vx = GGML_V3_F32_VEC_SET1(v);
GGML_V3_F32_VEC ay[GGML_V3_F32_ARR];
for (int i = 0; i < np; i += GGML_V3_F32_STEP) {
for (int j = 0; j < GGML_V3_F32_ARR; j++) {
ay[j] = GGML_V3_F32_VEC_LOAD(y + i + j*GGML_V3_F32_EPR);
ay[j] = GGML_V3_F32_VEC_MUL(ay[j], vx);
GGML_V3_F32_VEC_STORE(y + i + j*GGML_V3_F32_EPR, ay[j]);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] *= v;
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] *= v;
}
#endif
}
inline static void ggml_v3_vec_norm_f32 (const int n, float * s, const float * x) { ggml_v3_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); }
inline static void ggml_v3_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_v3_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
inline static void ggml_v3_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
inline static void ggml_v3_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
inline static void ggml_v3_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
inline static void ggml_v3_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
inline static void ggml_v3_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
inline static void ggml_v3_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
inline static void ggml_v3_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
inline static void ggml_v3_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
inline static float ggml_v3_gelu_f32(float x) {
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
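// this is the usual tanh approximation of GELU:
//   gelu(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
// with SQRT_2_OVER_PI ~= sqrt(2/pi) and GELU_COEF_A = 0.044715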
inline static void ggml_v3_vec_gelu_f16(const int n, ggml_v3_fp16_t * y, const ggml_v3_fp16_t * x) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
y[i] = ggml_v3_table_gelu_f16[i16[i]];
}
}
#ifdef GGML_V3_GELU_FP16
inline static void ggml_v3_vec_gelu_f32(const int n, float * y, const float * x) {
uint16_t t;
for (int i = 0; i < n; ++i) {
ggml_v3_fp16_t fp16 = GGML_V3_FP32_TO_FP16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t));
y[i] = GGML_V3_FP16_TO_FP32(ggml_v3_table_gelu_f16[t]);
}
}
#else
inline static void ggml_v3_vec_gelu_f32(const int n, float * y, const float * x) {
for (int i = 0; i < n; ++i) {
y[i] = ggml_v3_gelu_f32(x[i]);
}
}
#endif
inline static float ggml_v3_gelu_quick_f32(float x) {
return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
}
//inline static void ggml_v3_vec_gelu_quick_f16(const int n, ggml_v3_fp16_t * y, const ggml_v3_fp16_t * x) {
// const uint16_t * i16 = (const uint16_t *) x;
// for (int i = 0; i < n; ++i) {
// y[i] = ggml_v3_table_gelu_quick_f16[i16[i]];
// }
//}
#ifdef GGML_V3_GELU_QUICK_FP16
inline static void ggml_v3_vec_gelu_quick_f32(const int n, float * y, const float * x) {
uint16_t t;
for (int i = 0; i < n; ++i) {
ggml_v3_fp16_t fp16 = GGML_V3_FP32_TO_FP16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t));
y[i] = GGML_V3_FP16_TO_FP32(ggml_v3_table_gelu_quick_f16[t]);
}
}
#else
inline static void ggml_v3_vec_gelu_quick_f32(const int n, float * y, const float * x) {
for (int i = 0; i < n; ++i) {
y[i] = ggml_v3_gelu_quick_f32(x[i]);
}
}
#endif
// Sigmoid Linear Unit (SiLU) function
inline static float ggml_v3_silu_f32(float x) {
return x/(1.0f + expf(-x));
}
//inline static void ggml_v3_vec_silu_f16(const int n, ggml_v3_fp16_t * y, const ggml_v3_fp16_t * x) {
// const uint16_t * i16 = (const uint16_t *) x;
// for (int i = 0; i < n; ++i) {
// y[i] = ggml_v3_table_silu_f16[i16[i]];
// }
//}
#ifdef GGML_V3_SILU_FP16
inline static void ggml_v3_vec_silu_f32(const int n, float * y, const float * x) {
uint16_t t;
for (int i = 0; i < n; ++i) {
ggml_v3_fp16_t fp16 = GGML_V3_FP32_TO_FP16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t));
y[i] = GGML_V3_FP16_TO_FP32(ggml_v3_table_silu_f16[t]);
}
}
#else
inline static void ggml_v3_vec_silu_f32(const int n, float * y, const float * x) {
for (int i = 0; i < n; ++i) {
y[i] = ggml_v3_silu_f32(x[i]);
}
}
#endif
inline static float ggml_v3_silu_backward_f32(float x, float dy) {
const float s = 1.0f/(1.0f + expf(-x));
return dy*s*(1.0f + x*(1.0f - s));
}
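// derivation: with s = sigmoid(x), d/dx [x*s] = s + x*s*(1 - s) = s*(1 + x*(1 - s)),
// so the backward value is just the incoming gradient dy scaled by that factor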
#ifdef GGML_V3_SILU_FP16
inline static void ggml_v3_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
for (int i = 0; i < n; ++i) {
        // the forward silu was computed from the f16-rounded value of x[i], not from x[i] itself,
        // so take the derivative at that same f16 value:
ggml_v3_fp16_t fp16 = GGML_V3_FP32_TO_FP16(x[i]);
float usedx = GGML_V3_FP16_TO_FP32(fp16);
dx[i] = ggml_v3_silu_backward_f32(usedx, dy[i]);
}
}
#else
inline static void ggml_v3_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
for (int i = 0; i < n; ++i) {
dx[i] = ggml_v3_silu_backward_f32(x[i], dy[i]);
}
}
#endif
inline static void ggml_v3_vec_sum_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
ggml_v3_float sum = 0.0;
for (int i = 0; i < n; ++i) {
sum += (ggml_v3_float)x[i];
}
*s = sum;
#else
vDSP_sve(x, 1, s, n);
#endif
}
inline static void ggml_v3_vec_sum_f32_ggf(const int n, ggml_v3_float * s, const float * x) {
ggml_v3_float sum = 0.0;
for (int i = 0; i < n; ++i) {
sum += (ggml_v3_float)x[i];
}
*s = sum;
}
inline static void ggml_v3_vec_sum_f16_ggf(const int n, float * s, const ggml_v3_fp16_t * x) {
float sum = 0.0f;
for (int i = 0; i < n; ++i) {
sum += GGML_V3_FP16_TO_FP32(x[i]);
}
*s = sum;
}
inline static void ggml_v3_vec_max_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
float max = -INFINITY;
for (int i = 0; i < n; ++i) {
max = MAX(max, x[i]);
}
*s = max;
#else
vDSP_maxv(x, 1, s, n);
#endif
}
inline static void ggml_v3_vec_norm_inv_f32(const int n, float * s, const float * x) {
ggml_v3_vec_norm_f32(n, s, x);
*s = 1.f/(*s);
}
inline static void ggml_v3_vec_argmax_f32(const int n, int * s, const float * x) {
float max = -INFINITY;
int idx = 0;
for (int i = 0; i < n; ++i) {
max = MAX(max, x[i]);
if (max == x[i]) { idx = i; }
}
*s = idx;
}
//
// data types
//
static const char * GGML_V3_OP_NAME[GGML_V3_OP_COUNT] = {
"NONE",
"DUP",
"ADD",
"ADD1",
"ACC",
"SUB",
"MUL",
"DIV",
"SQR",
"SQRT",
"LOG",
"SUM",
"SUM_ROWS",
"MEAN",
"ARGMAX",
"REPEAT",
"REPEAT_BACK",
"CONCAT",
"SILU_BACK",
"NORM",
"RMS_NORM",
"RMS_NORM_BACK",
"GROUP_NORM",
"MUL_MAT",
"MUL_MAT_ID",
"OUT_PROD",
"SCALE",
"SET",
"CPY",
"CONT",
"RESHAPE",
"VIEW",
"PERMUTE",
"TRANSPOSE",
"GET_ROWS",
"GET_ROWS_BACK",
"DIAG",
"DIAG_MASK_INF",
"DIAG_MASK_ZERO",
"SOFT_MAX",
"SOFT_MAX_BACK",
"ROPE",
"ROPE_BACK",
"ALIBI",
"CLAMP",
"CONV_TRANSPOSE_1D",
"IM2COL",
"CONV_TRANSPOSE_2D",
"POOL_1D",
"POOL_2D",
"UPSCALE",
"PAD",
"ARGSORT",
"LEAKY_RELU",
"FLASH_ATTN",
"FLASH_FF",
"FLASH_ATTN_BACK",
"WIN_PART",
"WIN_UNPART",
"GET_REL_POS",
"ADD_REL_POS",
"UNARY",
"MAP_UNARY",
"MAP_BINARY",
"MAP_CUSTOM1_F32",
"MAP_CUSTOM2_F32",
"MAP_CUSTOM3_F32",
"MAP_CUSTOM1",
"MAP_CUSTOM2",
"MAP_CUSTOM3",
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
};
static_assert(GGML_V3_OP_COUNT == 72, "GGML_V3_OP_COUNT != 72");
static const char * GGML_V3_OP_SYMBOL[GGML_V3_OP_COUNT] = {
"none",
"x",
"x+y",
"x+y",
"view(x,nb,offset)+=y->x",
"x-y",
"x*y",
"x/y",
"x^2",
"√x",
"log(x)",
"Σx",
"Σx_k",
"Σx/n",
"argmax(x)",
"repeat(x)",
"repeat_back(x)",
"concat(x, y)",
"silu_back(x)",
"norm(x)",
"rms_norm(x)",
"rms_norm_back(x)",
"group_norm(x)",
"X*Y",
"X[i]*Y",
"X*Y",
"x*v",
"y-\\>view(x)",
"x-\\>y",
"cont(x)",
"reshape(x)",
"view(x)",
"permute(x)",
"transpose(x)",
"get_rows(x)",
"get_rows_back(x)",
"diag(x)",
"diag_mask_inf(x)",
"diag_mask_zero(x)",
"soft_max(x)",
"soft_max_back(x)",
"rope(x)",
"rope_back(x)",
"alibi(x)",
"clamp(x)",
"conv_transpose_1d(x)",
"im2col(x)",
"conv_transpose_2d(x)",
"pool_1d(x)",
"pool_2d(x)",
"upscale(x)",
"pad(x)",
"argsort(x)",
"leaky_relu(x)",
"flash_attn(x)",
"flash_ff(x)",
"flash_attn_back(x)",
"win_part(x)",
"win_unpart(x)",
"get_rel_pos(x)",
"add_rel_pos(x)",
"unary(x)",
"f(x)",
"f(x,y)",
"custom_f32(x)",
"custom_f32(x,y)",
"custom_f32(x,y,z)",
"custom(x)",
"custom(x,y)",
"custom(x,y,z)",
"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
};
static_assert(GGML_V3_OP_COUNT == 72, "GGML_V3_OP_COUNT != 72");
static_assert(GGML_V3_OP_POOL_COUNT == 2, "GGML_V3_OP_POOL_COUNT != 2");
static const char * GGML_V3_UNARY_OP_NAME[GGML_V3_UNARY_OP_COUNT] = {
"ABS",
"SGN",
"NEG",
"STEP",
"TANH",
"ELU",
"RELU",
"GELU",
"GELU_QUICK",
"SILU",
};
static_assert(GGML_V3_UNARY_OP_COUNT == 10, "GGML_V3_UNARY_OP_COUNT != 10");
static_assert(sizeof(struct ggml_v3_object)%GGML_V3_MEM_ALIGN == 0, "ggml_v3_object size must be a multiple of GGML_V3_MEM_ALIGN");
static_assert(sizeof(struct ggml_v3_tensor)%GGML_V3_MEM_ALIGN == 0, "ggml_v3_tensor size must be a multiple of GGML_V3_MEM_ALIGN");
// WARN:
// Mis-configuration can lead to problems that are hard to reason about:
// * At best it crashes or produces nonsense.
// * At worst the output is only subtly wrong and hard to notice.
//
// An op has to enable INIT or FINALIZE when any of its branches needs that pass.
// Take care with compile options (e.g., GGML_USE_xxx).
static bool GGML_V3_OP_HAS_INIT [GGML_V3_OP_COUNT] = { 0 };
static bool GGML_V3_OP_HAS_FINALIZE[GGML_V3_OP_COUNT] = { 0 };
static void ggml_v3_setup_op_has_task_pass(void) {
{ // INIT
bool * p = GGML_V3_OP_HAS_INIT;
p[GGML_V3_OP_ACC ] = true;
p[GGML_V3_OP_MUL_MAT ] = true;
p[GGML_V3_OP_MUL_MAT_ID ] = true;
p[GGML_V3_OP_OUT_PROD ] = true;
p[GGML_V3_OP_SET ] = true;
p[GGML_V3_OP_GET_ROWS_BACK ] = true;
p[GGML_V3_OP_DIAG_MASK_INF ] = true;
p[GGML_V3_OP_DIAG_MASK_ZERO ] = true;
p[GGML_V3_OP_CONV_TRANSPOSE_1D ] = true;
p[GGML_V3_OP_CONV_TRANSPOSE_2D ] = true;
p[GGML_V3_OP_FLASH_ATTN_BACK ] = true;
p[GGML_V3_OP_CROSS_ENTROPY_LOSS ] = true;
p[GGML_V3_OP_ADD_REL_POS ] = true;
}
{ // FINALIZE
bool * p = GGML_V3_OP_HAS_FINALIZE;
p[GGML_V3_OP_CROSS_ENTROPY_LOSS ] = true;
}
}
//
// ggml context
//
struct ggml_v3_context {
size_t mem_size;
void * mem_buffer;
bool mem_buffer_owned;
bool no_alloc;
bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
int n_objects;
struct ggml_v3_object * objects_begin;
struct ggml_v3_object * objects_end;
struct ggml_v3_scratch scratch;
struct ggml_v3_scratch scratch_save;
};
struct ggml_v3_context_container {
bool used;
struct ggml_v3_context context;
};
//
// NUMA support
//
#define GGML_V3_NUMA_MAX_NODES 8
#define GGML_V3_NUMA_MAX_CPUS 512
struct ggml_v3_numa_node {
uint32_t cpus[GGML_V3_NUMA_MAX_CPUS]; // hardware threads on this node
uint32_t n_cpus;
};
struct ggml_v3_numa_nodes {
struct ggml_v3_numa_node nodes[GGML_V3_NUMA_MAX_NODES];
uint32_t n_nodes;
uint32_t total_cpus; // hardware threads on system
};
//
// ggml state
//
struct ggml_v3_state {
struct ggml_v3_context_container contexts[GGML_V3_MAX_CONTEXTS];
struct ggml_v3_numa_nodes numa;
};
// global state
static struct ggml_v3_state g_state;
static atomic_int g_state_barrier = 0;
// barrier via spin lock
inline static void ggml_v3_critical_section_start(void) {
int processing = atomic_fetch_add(&g_state_barrier, 1);
while (processing > 0) {
// wait for other threads to finish
atomic_fetch_sub(&g_state_barrier, 1);
sched_yield(); // TODO: reconsider this
processing = atomic_fetch_add(&g_state_barrier, 1);
}
}
// TODO: make this somehow automatically executed
// some sort of "sentry" mechanism
inline static void ggml_v3_critical_section_end(void) {
atomic_fetch_sub(&g_state_barrier, 1);
}
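// usage sketch: every access to the shared g_state is bracketed like this
// (see ggml_v3_init / ggml_v3_free below):
//
//   ggml_v3_critical_section_start();
//   // ... read or modify g_state ...
//   ggml_v3_critical_section_end();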
void ggml_v3_numa_init(void) {
if (g_state.numa.n_nodes > 0) {
fprintf(stderr, "ggml_v3_numa_init: NUMA already initialized\n");
return;
}
#ifdef __linux__
struct stat st;
char path[256];
int rv;
// enumerate nodes
while (g_state.numa.n_nodes < GGML_V3_NUMA_MAX_NODES) {
rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes);
GGML_V3_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
if (stat(path, &st) != 0) { break; }
++g_state.numa.n_nodes;
}
// enumerate CPUs
while (g_state.numa.total_cpus < GGML_V3_NUMA_MAX_CPUS) {
rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus);
GGML_V3_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
if (stat(path, &st) != 0) { break; }
++g_state.numa.total_cpus;
}
GGML_V3_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus);
if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1) {
g_state.numa.n_nodes = 0;
return;
}
for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {
struct ggml_v3_numa_node * node = &g_state.numa.nodes[n];
GGML_V3_PRINT_DEBUG("CPUs on node %u:", n);
node->n_cpus = 0;
for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) {
rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c);
GGML_V3_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
if (stat(path, &st) == 0) {
node->cpus[node->n_cpus++] = c;
GGML_V3_PRINT_DEBUG(" %u", c);
}
}
GGML_V3_PRINT_DEBUG("\n");
}
if (ggml_v3_is_numa()) {
FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r");
if (fptr != NULL) {
char buf[42];
if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) {
GGML_V3_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
}
fclose(fptr);
}
}
#else
// TODO
#endif
}
bool ggml_v3_is_numa(void) {
return g_state.numa.n_nodes > 1;
}
////////////////////////////////////////////////////////////////////////////////
void ggml_v3_print_object(const struct ggml_v3_object * obj) {
GGML_V3_PRINT(" - ggml_v3_object: type = %d, offset = %zu, size = %zu, next = %p\n",
obj->type, obj->offs, obj->size, (const void *) obj->next);
}
void ggml_v3_print_objects(const struct ggml_v3_context * ctx) {
struct ggml_v3_object * obj = ctx->objects_begin;
GGML_V3_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx);
while (obj != NULL) {
ggml_v3_print_object(obj);
obj = obj->next;
}
GGML_V3_PRINT("%s: --- end ---\n", __func__);
}
int64_t ggml_v3_nelements(const struct ggml_v3_tensor * tensor) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
}
int64_t ggml_v3_nrows(const struct ggml_v3_tensor * tensor) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
}
size_t ggml_v3_nbytes(const struct ggml_v3_tensor * tensor) {
size_t nbytes;
size_t blck_size = ggml_v3_blck_size(tensor->type);
if (blck_size == 1) {
nbytes = ggml_v3_type_size(tensor->type);
for (int i = 0; i < GGML_V3_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
}
else {
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
for (int i = 1; i < GGML_V3_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
}
return nbytes;
}
size_t ggml_v3_nbytes_pad(const struct ggml_v3_tensor * tensor) {
return GGML_V3_PAD(ggml_v3_nbytes(tensor), GGML_V3_MEM_ALIGN);
}
int ggml_v3_blck_size(enum ggml_v3_type type) {
return type_traits[type].blck_size;
}
size_t ggml_v3_type_size(enum ggml_v3_type type) {
return type_traits[type].type_size;
}
size_t ggml_v3_row_size(enum ggml_v3_type type, int64_t ne) {
assert(ne % ggml_v3_blck_size(type) == 0);
return ggml_v3_type_size(type)*ne/ggml_v3_blck_size(type);
}
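// worked example: for GGML_V3_TYPE_F32 (blck_size 1, type_size 4) a row of ne elements takes
// 4*ne bytes; for a block-quantized type it takes sizeof(block)*ne/blck_size bytes, which is
// why ne must be a multiple of the block size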
double ggml_v3_type_sizef(enum ggml_v3_type type) {
return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
}
const char * ggml_v3_type_name(enum ggml_v3_type type) {
return type_traits[type].type_name;
}
bool ggml_v3_is_quantized(enum ggml_v3_type type) {
return type_traits[type].is_quantized;
}
const char * ggml_v3_op_name(enum ggml_v3_op op) {
return GGML_V3_OP_NAME[op];
}
const char * ggml_v3_op_symbol(enum ggml_v3_op op) {
return GGML_V3_OP_SYMBOL[op];
}
const char * ggml_v3_unary_op_name(enum ggml_v3_unary_op op) {
return GGML_V3_UNARY_OP_NAME[op];
}
const char * ggml_v3_op_desc(const struct ggml_v3_tensor * t) {
if (t->op == GGML_V3_OP_UNARY) {
enum ggml_v3_unary_op uop = ggml_v3_get_unary_op(t);
return ggml_v3_unary_op_name(uop);
}
else {
return ggml_v3_op_name(t->op);
}
}
size_t ggml_v3_element_size(const struct ggml_v3_tensor * tensor) {
return ggml_v3_type_size(tensor->type);
}
bool ggml_v3_is_scalar(const struct ggml_v3_tensor * tensor) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
}
bool ggml_v3_is_vector(const struct ggml_v3_tensor * tensor) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
}
bool ggml_v3_is_matrix(const struct ggml_v3_tensor * tensor) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return tensor->ne[2] == 1 && tensor->ne[3] == 1;
}
bool ggml_v3_is_3d(const struct ggml_v3_tensor * tensor) {
return tensor->ne[3] == 1;
}
int ggml_v3_n_dims(const struct ggml_v3_tensor * tensor) {
for (int i = GGML_V3_MAX_DIMS - 1; i >= 1; --i) {
if (tensor->ne[i] > 1) {
return i + 1;
}
}
return 1;
}
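// example: ne = {4, 3, 1, 1} reports 2 dimensions and ne = {4, 1, 1, 1} reports 1;
// trailing dimensions of size 1 never count, so a scalar {1, 1, 1, 1} also reports 1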
static inline bool ggml_v3_can_mul_mat(const struct ggml_v3_tensor * t0, const struct ggml_v3_tensor * t1) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return (t0->ne[0] == t1->ne[0]) &&
(t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
(t1->ne[3]%t0->ne[3] == 0);
}
static inline bool ggml_v3_can_out_prod(const struct ggml_v3_tensor * t0, const struct ggml_v3_tensor * t1) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return (t0->ne[1] == t1->ne[1]) &&
(t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
(t1->ne[3]%t0->ne[3] == 0);
}
enum ggml_v3_type ggml_v3_ftype_to_ggml_v3_type(enum ggml_v3_ftype ftype) {
enum ggml_v3_type wtype = GGML_V3_TYPE_COUNT;
switch (ftype) {
case GGML_V3_FTYPE_ALL_F32: wtype = GGML_V3_TYPE_F32; break;
case GGML_V3_FTYPE_MOSTLY_F16: wtype = GGML_V3_TYPE_F16; break;
case GGML_V3_FTYPE_MOSTLY_Q4_0: wtype = GGML_V3_TYPE_Q4_0; break;
case GGML_V3_FTYPE_MOSTLY_Q4_1: wtype = GGML_V3_TYPE_Q4_1; break;
case GGML_V3_FTYPE_MOSTLY_Q5_0: wtype = GGML_V3_TYPE_Q5_0; break;
case GGML_V3_FTYPE_MOSTLY_Q5_1: wtype = GGML_V3_TYPE_Q5_1; break;
case GGML_V3_FTYPE_MOSTLY_Q8_0: wtype = GGML_V3_TYPE_Q8_0; break;
case GGML_V3_FTYPE_MOSTLY_Q2_K: wtype = GGML_V3_TYPE_Q2_K; break;
case GGML_V3_FTYPE_MOSTLY_Q3_K: wtype = GGML_V3_TYPE_Q3_K; break;
case GGML_V3_FTYPE_MOSTLY_Q4_K: wtype = GGML_V3_TYPE_Q4_K; break;
case GGML_V3_FTYPE_MOSTLY_Q5_K: wtype = GGML_V3_TYPE_Q5_K; break;
case GGML_V3_FTYPE_MOSTLY_Q6_K: wtype = GGML_V3_TYPE_Q6_K; break;
case GGML_V3_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_V3_TYPE_IQ2_XXS; break;
case GGML_V3_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_V3_TYPE_IQ2_XS; break;
case GGML_V3_FTYPE_UNKNOWN: wtype = GGML_V3_TYPE_COUNT; break;
case GGML_V3_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_V3_TYPE_COUNT; break;
}
GGML_V3_ASSERT(wtype != GGML_V3_TYPE_COUNT);
return wtype;
}
size_t ggml_v3_tensor_overhead(void) {
return GGML_V3_OBJECT_SIZE + GGML_V3_TENSOR_SIZE;
}
bool ggml_v3_is_transposed(const struct ggml_v3_tensor * tensor) {
return tensor->nb[0] > tensor->nb[1];
}
bool ggml_v3_is_contiguous(const struct ggml_v3_tensor * tensor) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return
tensor->nb[0] == ggml_v3_type_size(tensor->type) &&
tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_v3_blck_size(tensor->type) &&
tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
static inline bool ggml_v3_is_contiguous_except_dim_1(const struct ggml_v3_tensor * tensor) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return
tensor->nb[0] == ggml_v3_type_size(tensor->type) &&
tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
bool ggml_v3_is_permuted(const struct ggml_v3_tensor * tensor) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
}
static inline bool ggml_v3_is_padded_1d(const struct ggml_v3_tensor * tensor) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return
tensor->nb[0] == ggml_v3_type_size(tensor->type) &&
tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
bool ggml_v3_are_same_shape(const struct ggml_v3_tensor * t0, const struct ggml_v3_tensor * t1) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return
(t0->ne[0] == t1->ne[0] ) &&
(t0->ne[1] == t1->ne[1] ) &&
(t0->ne[2] == t1->ne[2] ) &&
(t0->ne[3] == t1->ne[3] );
}
// check if t1 can be represented as a repetition of t0
static inline bool ggml_v3_can_repeat(const struct ggml_v3_tensor * t0, const struct ggml_v3_tensor * t1) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return
(t1->ne[0]%t0->ne[0] == 0) &&
(t1->ne[1]%t0->ne[1] == 0) &&
(t1->ne[2]%t0->ne[2] == 0) &&
(t1->ne[3]%t0->ne[3] == 0);
}
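// example: t0 with ne = {2, 3, 1, 1} can repeat into t1 with ne = {4, 9, 1, 1},
// because every dimension of t1 is a whole multiple of the corresponding dimension of t0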
static inline bool ggml_v3_can_repeat_rows(const struct ggml_v3_tensor * t0, const struct ggml_v3_tensor * t1) {
static_assert(GGML_V3_MAX_DIMS == 4, "GGML_V3_MAX_DIMS is not 4 - update this function");
return (t0->ne[0] == t1->ne[0]) && ggml_v3_can_repeat(t0, t1);
}
static inline int ggml_v3_up32(int n) {
return (n + 31) & ~31;
}
//static inline int ggml_v3_up64(int n) {
// return (n + 63) & ~63;
//}
static inline int ggml_v3_up(int n, int m) {
// assert m is a power of 2
GGML_V3_ASSERT((m & (m - 1)) == 0);
return (n + m - 1) & ~(m - 1);
}
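// worked examples: ggml_v3_up32(33) = (33 + 31) & ~31 = 64, ggml_v3_up(10, 8) = 16,
// and already-aligned values are returned unchanged, e.g. ggml_v3_up(16, 8) = 16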
// assert that pointer is aligned to GGML_V3_MEM_ALIGN
#define ggml_v3_assert_aligned(ptr) \
GGML_V3_ASSERT(((uintptr_t) (ptr))%GGML_V3_MEM_ALIGN == 0)
////////////////////////////////////////////////////////////////////////////////
struct ggml_v3_context * ggml_v3_init(struct ggml_v3_init_params params) {
    // make this function thread-safe
ggml_v3_critical_section_start();
static bool is_first_call = true;
if (is_first_call) {
// initialize time system (required on Windows)
ggml_v3_time_init();
// initialize GELU, Quick GELU, SILU and EXP F32 tables
{
const uint64_t t_start = ggml_v3_time_us(); UNUSED(t_start);
ggml_v3_fp16_t ii;
for (int i = 0; i < (1 << 16); ++i) {
uint16_t ui = i;
memcpy(&ii, &ui, sizeof(ii));
const float f = ggml_v3_table_f32_f16[i] = GGML_V3_COMPUTE_FP16_TO_FP32(ii);
ggml_v3_table_gelu_f16[i] = GGML_V3_FP32_TO_FP16(ggml_v3_gelu_f32(f));
ggml_v3_table_gelu_quick_f16[i] = GGML_V3_FP32_TO_FP16(ggml_v3_gelu_quick_f32(f));
ggml_v3_table_silu_f16[i] = GGML_V3_FP32_TO_FP16(ggml_v3_silu_f32(f));
ggml_v3_table_exp_f16[i] = GGML_V3_FP32_TO_FP16(expf(f));
}
const uint64_t t_end = ggml_v3_time_us(); UNUSED(t_end);
GGML_V3_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
}
// initialize g_state
{
const uint64_t t_start = ggml_v3_time_us(); UNUSED(t_start);
g_state = (struct ggml_v3_state) {
/*.contexts =*/ { { 0 } },
/*.numa =*/ {
.n_nodes = 0,
.total_cpus = 0,
},
};
for (int i = 0; i < GGML_V3_MAX_CONTEXTS; ++i) {
g_state.contexts[i].used = false;
}
const uint64_t t_end = ggml_v3_time_us(); UNUSED(t_end);
GGML_V3_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
}
#if defined(GGML_USE_CUDA)
ggml_v3_init_cublas();
#elif defined(GGML_USE_CLBLAST)
ggml_v3_cl_init();
#endif
ggml_v3_setup_op_has_task_pass();
is_first_call = false;
}
// find non-used context in g_state
struct ggml_v3_context * ctx = NULL;
for (int i = 0; i < GGML_V3_MAX_CONTEXTS; i++) {
if (!g_state.contexts[i].used) {
g_state.contexts[i].used = true;
ctx = &g_state.contexts[i].context;
GGML_V3_PRINT_DEBUG("%s: found unused context %d\n", __func__, i);
break;
}
}
if (ctx == NULL) {
GGML_V3_PRINT_DEBUG("%s: no unused context found\n", __func__);
ggml_v3_critical_section_end();
return NULL;
}
    // allow calling ggml_v3_init with 0 size
if (params.mem_size == 0) {
params.mem_size = GGML_V3_MEM_ALIGN;
}
const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_V3_PAD(params.mem_size, GGML_V3_MEM_ALIGN);
*ctx = (struct ggml_v3_context) {
/*.mem_size =*/ mem_size,
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_V3_ALIGNED_MALLOC(mem_size),
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
/*.no_alloc =*/ params.no_alloc,
/*.no_alloc_save =*/ params.no_alloc,
/*.n_objects =*/ 0,
/*.objects_begin =*/ NULL,
/*.objects_end =*/ NULL,
/*.scratch =*/ { 0, 0, NULL, },
/*.scratch_save =*/ { 0, 0, NULL, },
};
GGML_V3_ASSERT(ctx->mem_buffer != NULL);
ggml_v3_assert_aligned(ctx->mem_buffer);
GGML_V3_PRINT_DEBUG("%s: context initialized\n", __func__);
ggml_v3_critical_section_end();
return ctx;
}
void ggml_v3_free(struct ggml_v3_context * ctx) {
    // make this function thread-safe
ggml_v3_critical_section_start();
bool found = false;
for (int i = 0; i < GGML_V3_MAX_CONTEXTS; i++) {
if (&g_state.contexts[i].context == ctx) {
g_state.contexts[i].used = false;
GGML_V3_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n",
__func__, i, ggml_v3_used_mem(ctx));
if (ctx->mem_buffer_owned) {
GGML_V3_ALIGNED_FREE(ctx->mem_buffer);
}
found = true;
break;
}
}
if (!found) {
GGML_V3_PRINT_DEBUG("%s: context not found\n", __func__);
}
ggml_v3_critical_section_end();
}
size_t ggml_v3_used_mem(const struct ggml_v3_context * ctx) {
return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
}
size_t ggml_v3_set_scratch(struct ggml_v3_context * ctx, struct ggml_v3_scratch scratch) {
const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0;
ctx->scratch = scratch;
return result;
}
bool ggml_v3_get_no_alloc(struct ggml_v3_context * ctx) {
return ctx->no_alloc;
}
void ggml_v3_set_no_alloc(struct ggml_v3_context * ctx, bool no_alloc) {
ctx->no_alloc = no_alloc;
}
void * ggml_v3_get_mem_buffer(const struct ggml_v3_context * ctx) {
return ctx->mem_buffer;
}
size_t ggml_v3_get_mem_size(const struct ggml_v3_context * ctx) {
return ctx->mem_size;
}
size_t ggml_v3_get_max_tensor_size(const struct ggml_v3_context * ctx) {
size_t max_size = 0;
for (struct ggml_v3_tensor * tensor = ggml_v3_get_first_tensor(ctx); tensor != NULL; tensor = ggml_v3_get_next_tensor(ctx, tensor)) {
max_size = MAX(max_size, ggml_v3_nbytes(tensor));
}
return max_size;
}
// IMPORTANT:
// when creating "opt" tensors, always save and load the scratch buffer
// this is an error-prone process, but it is necessary to support inplace
// operators when using scratch buffers
// TODO: implement a better way
static void ggml_v3_scratch_save(struct ggml_v3_context * ctx) {
// this is needed to allow opt tensors to store their data
// TODO: again, need to find a better way
ctx->no_alloc_save = ctx->no_alloc;
ctx->no_alloc = false;
ctx->scratch_save = ctx->scratch;
ctx->scratch.data = NULL;
}
static void ggml_v3_scratch_load(struct ggml_v3_context * ctx) {
ctx->no_alloc = ctx->no_alloc_save;
ctx->scratch = ctx->scratch_save;
}
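// usage sketch (this is the pattern used by ggml_v3_new_i32 / ggml_v3_new_f32 below):
//
//   ggml_v3_scratch_save(ctx);   // park the scratch buffer and allocate from the main pool
//   struct ggml_v3_tensor * t = ggml_v3_new_tensor_1d(ctx, GGML_V3_TYPE_I32, 1);
//   ggml_v3_scratch_load(ctx);   // restore the previous scratch / no_alloc state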
////////////////////////////////////////////////////////////////////////////////
static struct ggml_v3_object * ggml_v3_new_object(struct ggml_v3_context * ctx, enum ggml_v3_object_type type, size_t size) {
// always insert objects at the end of the context's memory pool
struct ggml_v3_object * obj_cur = ctx->objects_end;
const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;
const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
const size_t cur_end = cur_offs + cur_size;
// align to GGML_V3_MEM_ALIGN
size_t size_needed = GGML_V3_PAD(size, GGML_V3_MEM_ALIGN);
char * const mem_buffer = ctx->mem_buffer;
struct ggml_v3_object * const obj_new = (struct ggml_v3_object *)(mem_buffer + cur_end);
if (cur_end + size_needed + GGML_V3_OBJECT_SIZE > ctx->mem_size) {
GGML_V3_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
__func__, cur_end + size_needed, ctx->mem_size);
assert(false);
return NULL;
}
*obj_new = (struct ggml_v3_object) {
.offs = cur_end + GGML_V3_OBJECT_SIZE,
.size = size_needed,
.next = NULL,
.type = type,
};
ggml_v3_assert_aligned(mem_buffer + obj_new->offs);
if (obj_cur != NULL) {
obj_cur->next = obj_new;
} else {
// this is the first object in this context
ctx->objects_begin = obj_new;
}
ctx->objects_end = obj_new;
//printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
return obj_new;
}
static struct ggml_v3_tensor * ggml_v3_new_tensor_impl(
struct ggml_v3_context * ctx,
enum ggml_v3_type type,
int n_dims,
const int64_t * ne,
struct ggml_v3_tensor * view_src,
size_t view_offs) {
assert(n_dims >= 1 && n_dims <= GGML_V3_MAX_DIMS);
// find the base tensor and absolute offset
if (view_src != NULL && view_src->view_src != NULL) {
view_offs += view_src->view_offs;
view_src = view_src->view_src;
}
size_t data_size = ggml_v3_row_size(type, ne[0]);
for (int i = 1; i < n_dims; i++) {
data_size *= ne[i];
}
GGML_V3_ASSERT(view_src == NULL || data_size + view_offs <= ggml_v3_nbytes(view_src));
void * data = view_src != NULL ? view_src->data : NULL;
if (data != NULL) {
data = (char *) data + view_offs;
}
size_t obj_alloc_size = 0;
if (view_src == NULL && !ctx->no_alloc) {
if (ctx->scratch.data != NULL) {
// allocate tensor data in the scratch buffer
if (ctx->scratch.offs + data_size > ctx->scratch.size) {
GGML_V3_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
__func__, ctx->scratch.offs + data_size, ctx->scratch.size);
assert(false);
return NULL;
}
data = (char * const) ctx->scratch.data + ctx->scratch.offs;
ctx->scratch.offs += data_size;
} else {
// allocate tensor data in the context's memory pool
obj_alloc_size = data_size;
}
}
struct ggml_v3_object * const obj_new = ggml_v3_new_object(ctx, GGML_V3_OBJECT_TENSOR, GGML_V3_TENSOR_SIZE + obj_alloc_size);
// TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here
struct ggml_v3_tensor * const result = (struct ggml_v3_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
*result = (struct ggml_v3_tensor) {
/*.type =*/ type,
/*.backend =*/ GGML_V3_BACKEND_CPU,
/*.buffer =*/ NULL,
/*.ne =*/ { 1, 1, 1, 1 },
/*.nb =*/ { 0, 0, 0, 0 },
/*.op =*/ GGML_V3_OP_NONE,
/*.op_params =*/ { 0 },
/*.is_param =*/ false,
/*.grad =*/ NULL,
/*.src =*/ { NULL },
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
/*.view_src =*/ view_src,
/*.view_offs =*/ view_offs,
/*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
/*.name =*/ { 0 },
/*.extra =*/ NULL,
/*.padding =*/ { 0 },
};
// TODO: this should not be needed as long as we don't rely on aligned SIMD loads
//ggml_v3_assert_aligned(result->data);
for (int i = 0; i < n_dims; i++) {
result->ne[i] = ne[i];
}
result->nb[0] = ggml_v3_type_size(type);
result->nb[1] = result->nb[0]*(result->ne[0]/ggml_v3_blck_size(type));
for (int i = 2; i < GGML_V3_MAX_DIMS; i++) {
result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
}
ctx->n_objects++;
return result;
}
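// stride layout produced above, worked example: a contiguous F32 tensor with ne = {4, 3, 1, 1}
// gets nb = {4, 16, 48, 48} bytes: nb[0] = sizeof(float), nb[1] = nb[0]*ne[0] (blck_size is 1),
// and every higher stride is the previous stride times the previous dimension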
struct ggml_v3_tensor * ggml_v3_new_tensor(
struct ggml_v3_context * ctx,
enum ggml_v3_type type,
int n_dims,
const int64_t * ne) {
return ggml_v3_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0);
}
struct ggml_v3_tensor * ggml_v3_new_tensor_1d(
struct ggml_v3_context * ctx,
enum ggml_v3_type type,
int64_t ne0) {
return ggml_v3_new_tensor(ctx, type, 1, &ne0);
}
struct ggml_v3_tensor * ggml_v3_new_tensor_2d(
struct ggml_v3_context * ctx,
enum ggml_v3_type type,
int64_t ne0,
int64_t ne1) {
const int64_t ne[2] = { ne0, ne1 };
return ggml_v3_new_tensor(ctx, type, 2, ne);
}
struct ggml_v3_tensor * ggml_v3_new_tensor_3d(
struct ggml_v3_context * ctx,
enum ggml_v3_type type,
int64_t ne0,
int64_t ne1,
int64_t ne2) {
const int64_t ne[3] = { ne0, ne1, ne2 };
return ggml_v3_new_tensor(ctx, type, 3, ne);
}
struct ggml_v3_tensor * ggml_v3_new_tensor_4d(
struct ggml_v3_context * ctx,
enum ggml_v3_type type,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3) {
const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
return ggml_v3_new_tensor(ctx, type, 4, ne);
}
struct ggml_v3_tensor * ggml_v3_new_i32(struct ggml_v3_context * ctx, int32_t value) {
ggml_v3_scratch_save(ctx);
struct ggml_v3_tensor * result = ggml_v3_new_tensor_1d(ctx, GGML_V3_TYPE_I32, 1);
ggml_v3_scratch_load(ctx);
ggml_v3_set_i32(result, value);
return result;
}
struct ggml_v3_tensor * ggml_v3_new_f32(struct ggml_v3_context * ctx, float value) {
ggml_v3_scratch_save(ctx);
struct ggml_v3_tensor * result = ggml_v3_new_tensor_1d(ctx, GGML_V3_TYPE_F32, 1);
ggml_v3_scratch_load(ctx);
ggml_v3_set_f32(result, value);
return result;
}
struct ggml_v3_tensor * ggml_v3_dup_tensor(struct ggml_v3_context * ctx, const struct ggml_v3_tensor * src) {
return ggml_v3_new_tensor(ctx, src->type, GGML_V3_MAX_DIMS, src->ne);
}
static void ggml_v3_set_op_params(struct ggml_v3_tensor * tensor, const void * params, size_t params_size) {
GGML_V3_ASSERT(tensor != NULL); // silence -Warray-bounds warnings
assert(params_size <= GGML_V3_MAX_OP_PARAMS);
memcpy(tensor->op_params, params, params_size);
}
static int32_t ggml_v3_get_op_params_i32(const struct ggml_v3_tensor * tensor, uint32_t i) {
assert(i < GGML_V3_MAX_OP_PARAMS / sizeof(int32_t));
return ((const int32_t *)(tensor->op_params))[i];
}
static void ggml_v3_set_op_params_i32(struct ggml_v3_tensor * tensor, uint32_t i, int32_t value) {
assert(i < GGML_V3_MAX_OP_PARAMS / sizeof(int32_t));
((int32_t *)(tensor->op_params))[i] = value;
}
struct ggml_v3_tensor * ggml_v3_set_zero(struct ggml_v3_tensor * tensor) {
memset(tensor->data, 0, ggml_v3_nbytes(tensor));
return tensor;
}
struct ggml_v3_tensor * ggml_v3_set_i32 (struct ggml_v3_tensor * tensor, int32_t value) {
const int n = ggml_v3_nrows(tensor);
const int nc = tensor->ne[0];
const size_t n1 = tensor->nb[1];
char * const data = tensor->data;
switch (tensor->type) {
case GGML_V3_TYPE_I8:
{
assert(tensor->nb[0] == sizeof(int8_t));
for (int i = 0; i < n; i++) {
ggml_v3_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
}
} break;
case GGML_V3_TYPE_I16:
{
assert(tensor->nb[0] == sizeof(int16_t));
for (int i = 0; i < n; i++) {
ggml_v3_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
}
} break;
case GGML_V3_TYPE_I32:
{
assert(tensor->nb[0] == sizeof(int32_t));
for (int i = 0; i < n; i++) {
ggml_v3_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
}
} break;
case GGML_V3_TYPE_F16:
{
assert(tensor->nb[0] == sizeof(ggml_v3_fp16_t));
for (int i = 0; i < n; i++) {
ggml_v3_vec_set_f16(nc, (ggml_v3_fp16_t *)(data + i*n1), GGML_V3_FP32_TO_FP16(value));
}
} break;
case GGML_V3_TYPE_F32:
{
assert(tensor->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_set_f32(nc, (float *)(data + i*n1), value);
}
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
return tensor;
}
struct ggml_v3_tensor * ggml_v3_set_f32(struct ggml_v3_tensor * tensor, float value) {
const int n = ggml_v3_nrows(tensor);
const int nc = tensor->ne[0];
const size_t n1 = tensor->nb[1];
char * const data = tensor->data;
switch (tensor->type) {
case GGML_V3_TYPE_I8:
{
assert(tensor->nb[0] == sizeof(int8_t));
for (int i = 0; i < n; i++) {
ggml_v3_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
}
} break;
case GGML_V3_TYPE_I16:
{
assert(tensor->nb[0] == sizeof(int16_t));
for (int i = 0; i < n; i++) {
ggml_v3_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
}
} break;
case GGML_V3_TYPE_I32:
{
assert(tensor->nb[0] == sizeof(int32_t));
for (int i = 0; i < n; i++) {
ggml_v3_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
}
} break;
case GGML_V3_TYPE_F16:
{
assert(tensor->nb[0] == sizeof(ggml_v3_fp16_t));
for (int i = 0; i < n; i++) {
ggml_v3_vec_set_f16(nc, (ggml_v3_fp16_t *)(data + i*n1), GGML_V3_FP32_TO_FP16(value));
}
} break;
case GGML_V3_TYPE_F32:
{
assert(tensor->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_set_f32(nc, (float *)(data + i*n1), value);
}
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
return tensor;
}
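// Both fill helpers above walk the tensor row by row: `n1 = tensor->nb[1]` is the
// row stride in bytes, so padded rows are still filled correctly as long as each
// individual row is contiguous (which the per-type asserts on nb[0] require).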
void ggml_v3_unravel_index(const struct ggml_v3_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
const int64_t ne2 = tensor->ne[2];
const int64_t ne1 = tensor->ne[1];
const int64_t ne0 = tensor->ne[0];
const int64_t i3_ = (i/(ne2*ne1*ne0));
const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
if (i0) {
* i0 = i0_;
}
if (i1) {
* i1 = i1_;
}
if (i2) {
* i2 = i2_;
}
if (i3) {
* i3 = i3_;
}
}
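// ggml_v3_unravel_index() inverts the row-major flattening
//   i = ((i3*ne2 + i2)*ne1 + i1)*ne0 + i0
// Worked example: for ne = {4, 3, 2, 1} and i = 17 this yields i3 = 0, i2 = 1,
// i1 = 1, i0 = 1, since ((0*2 + 1)*3 + 1)*4 + 1 == 17.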
int32_t ggml_v3_get_i32_1d(const struct ggml_v3_tensor * tensor, int i) {
if (!ggml_v3_is_contiguous(tensor)) {
int64_t id[4] = { 0, 0, 0, 0 };
ggml_v3_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
return ggml_v3_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
}
switch (tensor->type) {
case GGML_V3_TYPE_I8:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int8_t));
return ((int8_t *)(tensor->data))[i];
}
case GGML_V3_TYPE_I16:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int16_t));
return ((int16_t *)(tensor->data))[i];
}
case GGML_V3_TYPE_I32:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int32_t));
return ((int32_t *)(tensor->data))[i];
}
case GGML_V3_TYPE_F16:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(ggml_v3_fp16_t));
return GGML_V3_FP16_TO_FP32(((ggml_v3_fp16_t *)(tensor->data))[i]);
}
case GGML_V3_TYPE_F32:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(float));
return ((float *)(tensor->data))[i];
}
default:
{
GGML_V3_ASSERT(false);
}
}
    return 0;
}
void ggml_v3_set_i32_1d(const struct ggml_v3_tensor * tensor, int i, int32_t value) {
if (!ggml_v3_is_contiguous(tensor)) {
int64_t id[4] = { 0, 0, 0, 0 };
ggml_v3_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
ggml_v3_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
return;
}
switch (tensor->type) {
case GGML_V3_TYPE_I8:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int8_t));
((int8_t *)(tensor->data))[i] = value;
} break;
case GGML_V3_TYPE_I16:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int16_t));
((int16_t *)(tensor->data))[i] = value;
} break;
case GGML_V3_TYPE_I32:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int32_t));
((int32_t *)(tensor->data))[i] = value;
} break;
case GGML_V3_TYPE_F16:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(ggml_v3_fp16_t));
((ggml_v3_fp16_t *)(tensor->data))[i] = GGML_V3_FP32_TO_FP16(value);
} break;
case GGML_V3_TYPE_F32:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(float));
((float *)(tensor->data))[i] = value;
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
int32_t ggml_v3_get_i32_nd(const struct ggml_v3_tensor * tensor, int i0, int i1, int i2, int i3) {
void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
switch (tensor->type) {
case GGML_V3_TYPE_I8:
return ((int8_t *) data)[0];
case GGML_V3_TYPE_I16:
return ((int16_t *) data)[0];
case GGML_V3_TYPE_I32:
return ((int32_t *) data)[0];
case GGML_V3_TYPE_F16:
return GGML_V3_FP16_TO_FP32(((ggml_v3_fp16_t *) data)[0]);
case GGML_V3_TYPE_F32:
return ((float *) data)[0];
default:
GGML_V3_ASSERT(false);
}
    return 0;
}
void ggml_v3_set_i32_nd(const struct ggml_v3_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
switch (tensor->type) {
case GGML_V3_TYPE_I8:
{
((int8_t *)(data))[0] = value;
} break;
case GGML_V3_TYPE_I16:
{
((int16_t *)(data))[0] = value;
} break;
case GGML_V3_TYPE_I32:
{
((int32_t *)(data))[0] = value;
} break;
case GGML_V3_TYPE_F16:
{
((ggml_v3_fp16_t *)(data))[0] = GGML_V3_FP32_TO_FP16(value);
} break;
case GGML_V3_TYPE_F32:
{
((float *)(data))[0] = value;
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
float ggml_v3_get_f32_1d(const struct ggml_v3_tensor * tensor, int i) {
if (!ggml_v3_is_contiguous(tensor)) {
int64_t id[4] = { 0, 0, 0, 0 };
ggml_v3_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
return ggml_v3_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
}
switch (tensor->type) {
case GGML_V3_TYPE_I8:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int8_t));
return ((int8_t *)(tensor->data))[i];
}
case GGML_V3_TYPE_I16:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int16_t));
return ((int16_t *)(tensor->data))[i];
}
case GGML_V3_TYPE_I32:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int32_t));
return ((int32_t *)(tensor->data))[i];
}
case GGML_V3_TYPE_F16:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(ggml_v3_fp16_t));
return GGML_V3_FP16_TO_FP32(((ggml_v3_fp16_t *)(tensor->data))[i]);
}
case GGML_V3_TYPE_F32:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(float));
return ((float *)(tensor->data))[i];
}
default:
{
GGML_V3_ASSERT(false);
}
}
return 0.0f;
}
void ggml_v3_set_f32_1d(const struct ggml_v3_tensor * tensor, int i, float value) {
if (!ggml_v3_is_contiguous(tensor)) {
int64_t id[4] = { 0, 0, 0, 0 };
ggml_v3_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
ggml_v3_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
return;
}
switch (tensor->type) {
case GGML_V3_TYPE_I8:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int8_t));
((int8_t *)(tensor->data))[i] = value;
} break;
case GGML_V3_TYPE_I16:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int16_t));
((int16_t *)(tensor->data))[i] = value;
} break;
case GGML_V3_TYPE_I32:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(int32_t));
((int32_t *)(tensor->data))[i] = value;
} break;
case GGML_V3_TYPE_F16:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(ggml_v3_fp16_t));
((ggml_v3_fp16_t *)(tensor->data))[i] = GGML_V3_FP32_TO_FP16(value);
} break;
case GGML_V3_TYPE_F32:
{
GGML_V3_ASSERT(tensor->nb[0] == sizeof(float));
((float *)(tensor->data))[i] = value;
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
float ggml_v3_get_f32_nd(const struct ggml_v3_tensor * tensor, int i0, int i1, int i2, int i3) {
void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
switch (tensor->type) {
case GGML_V3_TYPE_I8:
return ((int8_t *) data)[0];
case GGML_V3_TYPE_I16:
return ((int16_t *) data)[0];
case GGML_V3_TYPE_I32:
return ((int32_t *) data)[0];
case GGML_V3_TYPE_F16:
return GGML_V3_FP16_TO_FP32(((ggml_v3_fp16_t *) data)[0]);
case GGML_V3_TYPE_F32:
return ((float *) data)[0];
default:
GGML_V3_ASSERT(false);
}
return 0.0f;
}
void ggml_v3_set_f32_nd(const struct ggml_v3_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
switch (tensor->type) {
case GGML_V3_TYPE_I8:
{
((int8_t *)(data))[0] = value;
} break;
case GGML_V3_TYPE_I16:
{
((int16_t *)(data))[0] = value;
} break;
case GGML_V3_TYPE_I32:
{
((int32_t *)(data))[0] = value;
} break;
case GGML_V3_TYPE_F16:
{
((ggml_v3_fp16_t *)(data))[0] = GGML_V3_FP32_TO_FP16(value);
} break;
case GGML_V3_TYPE_F32:
{
((float *)(data))[0] = value;
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
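// The *_nd accessors above index with explicit byte strides
// (i0*nb[0] + i1*nb[1] + i2*nb[2] + i3*nb[3]), so they also work on permuted or
// otherwise non-contiguous tensors. The *_1d variants assume a contiguous layout
// and fall back to ggml_v3_unravel_index() plus the _nd path when it is not.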
void * ggml_v3_get_data(const struct ggml_v3_tensor * tensor) {
return tensor->data;
}
float * ggml_v3_get_data_f32(const struct ggml_v3_tensor * tensor) {
assert(tensor->type == GGML_V3_TYPE_F32);
return (float *)(tensor->data);
}
enum ggml_v3_unary_op ggml_v3_get_unary_op(const struct ggml_v3_tensor * tensor) {
GGML_V3_ASSERT(tensor->op == GGML_V3_OP_UNARY);
return (enum ggml_v3_unary_op) ggml_v3_get_op_params_i32(tensor, 0);
}
const char * ggml_v3_get_name(const struct ggml_v3_tensor * tensor) {
return tensor->name;
}
struct ggml_v3_tensor * ggml_v3_set_name(struct ggml_v3_tensor * tensor, const char * name) {
strncpy(tensor->name, name, sizeof(tensor->name));
tensor->name[sizeof(tensor->name) - 1] = '\0';
return tensor;
}
struct ggml_v3_tensor * ggml_v3_format_name(struct ggml_v3_tensor * tensor, const char * fmt, ...) {
va_list args;
va_start(args, fmt);
vsnprintf(tensor->name, sizeof(tensor->name), fmt, args);
va_end(args);
return tensor;
}
struct ggml_v3_tensor * ggml_v3_view_tensor(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * src) {
struct ggml_v3_tensor * result = ggml_v3_new_tensor_impl(ctx, src->type, GGML_V3_MAX_DIMS, src->ne, src, 0);
ggml_v3_format_name(result, "%s (view)", src->name);
for (int i = 0; i < GGML_V3_MAX_DIMS; i++) {
result->nb[i] = src->nb[i];
}
return result;
}
struct ggml_v3_tensor * ggml_v3_get_first_tensor(const struct ggml_v3_context * ctx) {
struct ggml_v3_object * obj = ctx->objects_begin;
char * const mem_buffer = ctx->mem_buffer;
while (obj != NULL) {
if (obj->type == GGML_V3_OBJECT_TENSOR) {
return (struct ggml_v3_tensor *)(mem_buffer + obj->offs);
}
obj = obj->next;
}
return NULL;
}
struct ggml_v3_tensor * ggml_v3_get_next_tensor(const struct ggml_v3_context * ctx, struct ggml_v3_tensor * tensor) {
struct ggml_v3_object * obj = (struct ggml_v3_object *) ((char *)tensor - GGML_V3_OBJECT_SIZE);
obj = obj->next;
char * const mem_buffer = ctx->mem_buffer;
while (obj != NULL) {
if (obj->type == GGML_V3_OBJECT_TENSOR) {
return (struct ggml_v3_tensor *)(mem_buffer + obj->offs);
}
obj = obj->next;
}
return NULL;
}
struct ggml_v3_tensor * ggml_v3_get_tensor(struct ggml_v3_context * ctx, const char * name) {
struct ggml_v3_object * obj = ctx->objects_begin;
char * const mem_buffer = ctx->mem_buffer;
while (obj != NULL) {
if (obj->type == GGML_V3_OBJECT_TENSOR) {
struct ggml_v3_tensor * cur = (struct ggml_v3_tensor *)(mem_buffer + obj->offs);
if (strcmp(cur->name, name) == 0) {
return cur;
}
}
obj = obj->next;
}
return NULL;
}
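// Iterating over all tensors in a context chains the two helpers above;
// a minimal sketch (loop body is illustrative):
//
//   for (struct ggml_v3_tensor * t = ggml_v3_get_first_tensor(ctx); t != NULL;
//        t = ggml_v3_get_next_tensor(ctx, t)) {
//       // ... inspect t, e.g. via ggml_v3_get_name(t) ...
//   }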
////////////////////////////////////////////////////////////////////////////////
// ggml_v3_dup
static struct ggml_v3_tensor * ggml_v3_dup_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_DUP;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_dup(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_dup_impl(ctx, a, false);
}
struct ggml_v3_tensor * ggml_v3_dup_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_dup_impl(ctx, a, true);
}
// ggml_v3_add
static struct ggml_v3_tensor * ggml_v3_add_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_can_repeat(b, a));
bool is_node = false;
if (!inplace && (a->grad || b->grad)) {
// TODO: support backward pass for broadcasting
GGML_V3_ASSERT(ggml_v3_are_same_shape(a, b));
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_ADD;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_add(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_add_impl(ctx, a, b, false);
}
struct ggml_v3_tensor * ggml_v3_add_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_add_impl(ctx, a, b, true);
}
// ggml_v3_add_cast
static struct ggml_v3_tensor * ggml_v3_add_cast_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
enum ggml_v3_type type) {
// TODO: support less-strict constraint
// GGML_V3_ASSERT(ggml_v3_can_repeat(b, a));
GGML_V3_ASSERT(ggml_v3_can_repeat_rows(b, a));
GGML_V3_ASSERT(ggml_v3_is_quantized(a->type) || a->type == GGML_V3_TYPE_F16); // currently only supported for quantized input and f16
bool is_node = false;
if (a->grad || b->grad) {
// TODO: support backward pass for broadcasting
GGML_V3_ASSERT(ggml_v3_are_same_shape(a, b));
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, type, GGML_V3_MAX_DIMS, a->ne);
result->op = GGML_V3_OP_ADD;
result->grad = is_node ? ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, GGML_V3_MAX_DIMS, a->ne) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_add_cast(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
enum ggml_v3_type type) {
return ggml_v3_add_cast_impl(ctx, a, b, type);
}
// ggml_v3_add1
static struct ggml_v3_tensor * ggml_v3_add1_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_is_scalar(b));
GGML_V3_ASSERT(ggml_v3_is_padded_1d(a));
bool is_node = false;
if (a->grad || b->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_ADD1;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_add1(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_add1_impl(ctx, a, b, false);
}
struct ggml_v3_tensor * ggml_v3_add1_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_add1_impl(ctx, a, b, true);
}
// ggml_v3_acc
static struct ggml_v3_tensor * ggml_v3_acc_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
size_t nb1,
size_t nb2,
size_t nb3,
size_t offset,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_nelements(b) <= ggml_v3_nelements(a));
GGML_V3_ASSERT(ggml_v3_is_contiguous(a));
GGML_V3_ASSERT(a->type == GGML_V3_TYPE_F32);
GGML_V3_ASSERT(b->type == GGML_V3_TYPE_F32);
bool is_node = false;
if (!inplace && (a->grad || b->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_ACC;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_acc(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
size_t nb1,
size_t nb2,
size_t nb3,
size_t offset) {
return ggml_v3_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
}
struct ggml_v3_tensor * ggml_v3_acc_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
size_t nb1,
size_t nb2,
size_t nb3,
size_t offset) {
return ggml_v3_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
}
// ggml_v3_sub
static struct ggml_v3_tensor * ggml_v3_sub_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(a, b));
bool is_node = false;
if (!inplace && (a->grad || b->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_SUB;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_sub(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_sub_impl(ctx, a, b, false);
}
struct ggml_v3_tensor * ggml_v3_sub_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_sub_impl(ctx, a, b, true);
}
// ggml_v3_mul
static struct ggml_v3_tensor * ggml_v3_mul_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_can_repeat(b, a));
bool is_node = false;
if (!inplace && (a->grad || b->grad)) {
// TODO: support backward pass for broadcasting
GGML_V3_ASSERT(ggml_v3_are_same_shape(a, b));
is_node = true;
}
if (inplace) {
GGML_V3_ASSERT(!is_node);
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_MUL;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_mul(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_mul_impl(ctx, a, b, false);
}
struct ggml_v3_tensor * ggml_v3_mul_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_mul_impl(ctx, a, b, true);
}
// ggml_v3_div
static struct ggml_v3_tensor * ggml_v3_div_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_can_repeat(b, a));
bool is_node = false;
if (!inplace && (a->grad || b->grad)) {
is_node = true;
}
if (inplace) {
GGML_V3_ASSERT(!is_node);
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_DIV;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_div(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_div_impl(ctx, a, b, false);
}
struct ggml_v3_tensor * ggml_v3_div_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_div_impl(ctx, a, b, true);
}
// ggml_v3_sqr
static struct ggml_v3_tensor * ggml_v3_sqr_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_SQR;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_sqr(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_sqr_impl(ctx, a, false);
}
struct ggml_v3_tensor * ggml_v3_sqr_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_sqr_impl(ctx, a, true);
}
// ggml_v3_sqrt
static struct ggml_v3_tensor * ggml_v3_sqrt_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_SQRT;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_sqrt(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_sqrt_impl(ctx, a, false);
}
struct ggml_v3_tensor * ggml_v3_sqrt_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_sqrt_impl(ctx, a, true);
}
// ggml_v3_log
static struct ggml_v3_tensor * ggml_v3_log_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_LOG;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_log(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_log_impl(ctx, a, false);
}
struct ggml_v3_tensor * ggml_v3_log_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_log_impl(ctx, a, true);
}
// ggml_v3_sum
struct ggml_v3_tensor * ggml_v3_sum(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor_1d(ctx, a->type, 1);
result->op = GGML_V3_OP_SUM;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_sum_rows
struct ggml_v3_tensor * ggml_v3_sum_rows(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
bool is_node = false;
if (a->grad) {
is_node = true;
}
int64_t ne[GGML_V3_MAX_DIMS] = { 1 };
for (int i = 1; i < GGML_V3_MAX_DIMS; ++i) {
ne[i] = a->ne[i];
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, a->type, GGML_V3_MAX_DIMS, ne);
result->op = GGML_V3_OP_SUM_ROWS;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_mean
struct ggml_v3_tensor * ggml_v3_mean(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
bool is_node = false;
if (a->grad) {
GGML_V3_ASSERT(false); // TODO: implement
is_node = true;
}
int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] };
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, 4, ne);
result->op = GGML_V3_OP_MEAN;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_argmax
struct ggml_v3_tensor * ggml_v3_argmax(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
GGML_V3_ASSERT(ggml_v3_is_matrix(a));
bool is_node = false;
if (a->grad) {
GGML_V3_ASSERT(false);
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor_1d(ctx, GGML_V3_TYPE_I32, a->ne[1]);
result->op = GGML_V3_OP_ARGMAX;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_repeat
struct ggml_v3_tensor * ggml_v3_repeat(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
GGML_V3_ASSERT(ggml_v3_can_repeat(a, b));
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, a->type, GGML_V3_MAX_DIMS, b->ne);
result->op = GGML_V3_OP_REPEAT;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_repeat_back
struct ggml_v3_tensor * ggml_v3_repeat_back(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
GGML_V3_ASSERT(ggml_v3_can_repeat(b, a));
bool is_node = false;
if (a->grad) {
is_node = true;
}
if (ggml_v3_are_same_shape(a, b) && !is_node) {
return a;
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, a->type, GGML_V3_MAX_DIMS, b->ne);
result->op = GGML_V3_OP_REPEAT_BACK;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_concat
struct ggml_v3_tensor * ggml_v3_concat(
struct ggml_v3_context* ctx,
struct ggml_v3_tensor* a,
struct ggml_v3_tensor* b) {
GGML_V3_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
bool is_node = false;
if (a->grad || b->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
result->op = GGML_V3_OP_CONCAT;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
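// NOTE: ggml_v3_concat() only concatenates along dimension 2 (ne[2]); ne[0],
// ne[1] and ne[3] of a and b must match, as enforced by the assert above.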
// ggml_v3_abs
struct ggml_v3_tensor * ggml_v3_abs(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary(ctx, a, GGML_V3_UNARY_OP_ABS);
}
struct ggml_v3_tensor * ggml_v3_abs_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary_inplace(ctx, a, GGML_V3_UNARY_OP_ABS);
}
// ggml_v3_sgn
struct ggml_v3_tensor * ggml_v3_sgn(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary(ctx, a, GGML_V3_UNARY_OP_SGN);
}
struct ggml_v3_tensor * ggml_v3_sgn_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary_inplace(ctx, a, GGML_V3_UNARY_OP_SGN);
}
// ggml_v3_neg
struct ggml_v3_tensor * ggml_v3_neg(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary(ctx, a, GGML_V3_UNARY_OP_NEG);
}
struct ggml_v3_tensor * ggml_v3_neg_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary_inplace(ctx, a, GGML_V3_UNARY_OP_NEG);
}
// ggml_v3_step
struct ggml_v3_tensor * ggml_v3_step(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary(ctx, a, GGML_V3_UNARY_OP_STEP);
}
struct ggml_v3_tensor * ggml_v3_step_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary_inplace(ctx, a, GGML_V3_UNARY_OP_STEP);
}
// ggml_v3_tanh
struct ggml_v3_tensor * ggml_v3_tanh(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary(ctx, a, GGML_V3_UNARY_OP_TANH);
}
struct ggml_v3_tensor * ggml_v3_tanh_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary_inplace(ctx, a, GGML_V3_UNARY_OP_TANH);
}
// ggml_v3_elu
struct ggml_v3_tensor * ggml_v3_elu(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary(ctx, a, GGML_V3_UNARY_OP_ELU);
}
struct ggml_v3_tensor * ggml_v3_elu_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary_inplace(ctx, a, GGML_V3_UNARY_OP_ELU);
}
// ggml_v3_relu
struct ggml_v3_tensor * ggml_v3_relu(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary(ctx, a, GGML_V3_UNARY_OP_RELU);
}
struct ggml_v3_tensor * ggml_v3_relu_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary_inplace(ctx, a, GGML_V3_UNARY_OP_RELU);
}
// ggml_v3_leaky_relu
struct ggml_v3_tensor * ggml_v3_leaky_relu(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a, float negative_slope, bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params(result, &negative_slope, sizeof(negative_slope));
result->op = GGML_V3_OP_LEAKY_RELU;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_gelu
struct ggml_v3_tensor * ggml_v3_gelu(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary(ctx, a, GGML_V3_UNARY_OP_GELU);
}
struct ggml_v3_tensor * ggml_v3_gelu_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary_inplace(ctx, a, GGML_V3_UNARY_OP_GELU);
}
// ggml_v3_gelu_quick
struct ggml_v3_tensor * ggml_v3_gelu_quick(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary(ctx, a, GGML_V3_UNARY_OP_GELU_QUICK);
}
struct ggml_v3_tensor * ggml_v3_gelu_quick_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary_inplace(ctx, a, GGML_V3_UNARY_OP_GELU_QUICK);
}
// ggml_v3_silu
struct ggml_v3_tensor * ggml_v3_silu(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary(ctx, a, GGML_V3_UNARY_OP_SILU);
}
struct ggml_v3_tensor * ggml_v3_silu_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_unary_inplace(ctx, a, GGML_V3_UNARY_OP_SILU);
}
// ggml_v3_silu_back
struct ggml_v3_tensor * ggml_v3_silu_back(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
bool is_node = false;
if (a->grad || b->grad) {
// TODO: implement backward
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_SILU_BACK;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_v3_norm
static struct ggml_v3_tensor * ggml_v3_norm_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
float eps,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params(result, &eps, sizeof(eps));
result->op = GGML_V3_OP_NORM;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_norm(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
float eps) {
return ggml_v3_norm_impl(ctx, a, eps, false);
}
struct ggml_v3_tensor * ggml_v3_norm_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
float eps) {
return ggml_v3_norm_impl(ctx, a, eps, true);
}
// ggml_v3_rms_norm
static struct ggml_v3_tensor * ggml_v3_rms_norm_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
float eps,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params(result, &eps, sizeof(eps));
result->op = GGML_V3_OP_RMS_NORM;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_rms_norm(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
float eps) {
return ggml_v3_rms_norm_impl(ctx, a, eps, false);
}
struct ggml_v3_tensor * ggml_v3_rms_norm_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
float eps) {
return ggml_v3_rms_norm_impl(ctx, a, eps, true);
}
// ggml_v3_rms_norm_back
struct ggml_v3_tensor * ggml_v3_rms_norm_back(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
float eps) {
bool is_node = false;
if (a->grad) {
// TODO: implement backward
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params(result, &eps, sizeof(eps));
result->op = GGML_V3_OP_RMS_NORM_BACK;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_v3_group_norm
static struct ggml_v3_tensor * ggml_v3_group_norm_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int n_groups,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
result->op_params[0] = n_groups;
result->op = GGML_V3_OP_GROUP_NORM;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_group_norm(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int n_groups) {
return ggml_v3_group_norm_impl(ctx, a, n_groups, false);
}
struct ggml_v3_tensor * ggml_v3_group_norm_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int n_groups) {
return ggml_v3_group_norm_impl(ctx, a, n_groups, true);
}
// ggml_v3_mul_mat
struct ggml_v3_tensor * ggml_v3_mul_mat(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
GGML_V3_ASSERT(ggml_v3_can_mul_mat(a, b));
GGML_V3_ASSERT(!ggml_v3_is_transposed(a));
bool is_node = false;
if (a->grad || b->grad) {
is_node = true;
}
const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, 4, ne);
result->op = GGML_V3_OP_MUL_MAT;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
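// Result shape of ggml_v3_mul_mat(a, b) follows directly from the code above:
//   ne = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }, always F32.
// In other words, with a of shape [K, M, ...] and b of shape [K, N, ...]
// (ne[0] listed first) the shared dimension is ne[0] == K and the result is
// [M, N, ...]; ggml_v3_can_mul_mat(), defined earlier in this file, is assumed
// to enforce the matching K and the broadcastability of the higher dimensions.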
void ggml_v3_mul_mat_set_prec(
struct ggml_v3_tensor * a,
enum ggml_v3_prec prec) {
const int32_t prec_i32 = (int32_t) prec;
ggml_v3_set_op_params_i32(a, 0, prec_i32);
}
// ggml_v3_mul_mat_id
struct ggml_v3_tensor * ggml_v3_mul_mat_id(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * const as[],
int n_as,
struct ggml_v3_tensor * ids,
int id,
struct ggml_v3_tensor * b) {
GGML_V3_ASSERT(ids->type == GGML_V3_TYPE_I32);
GGML_V3_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
GGML_V3_ASSERT(ids->ne[1] == b->ne[1]);
GGML_V3_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
GGML_V3_ASSERT(n_as > 0 && n_as <= GGML_V3_MAX_SRC - 2);
GGML_V3_ASSERT(id >= 0 && id < ids->ne[0]);
bool is_node = false;
if (as[0]->grad || b->grad) {
is_node = true;
}
const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, 4, ne);
ggml_v3_set_op_params_i32(result, 0, id);
ggml_v3_set_op_params_i32(result, 1, n_as);
result->op = GGML_V3_OP_MUL_MAT_ID;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = ids;
result->src[1] = b;
for (int i = 0; i < n_as; i++) {
struct ggml_v3_tensor * a = as[i];
GGML_V3_ASSERT(ggml_v3_are_same_shape(as[0], a));
GGML_V3_ASSERT(ggml_v3_can_mul_mat(a, b));
GGML_V3_ASSERT(!ggml_v3_is_transposed(a));
result->src[i + 2] = a;
}
return result;
}
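// ggml_v3_mul_mat_id() packs its routing metadata into the op itself:
//   op_params[0] = id   (position along ids->ne[0] used to select an expert),
//   op_params[1] = n_as (number of expert matrices),
//   src[0] = ids, src[1] = b, src[2 .. 2 + n_as - 1] = the expert matrices.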
// ggml_v3_out_prod
struct ggml_v3_tensor * ggml_v3_out_prod(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
GGML_V3_ASSERT(ggml_v3_can_out_prod(a, b));
GGML_V3_ASSERT(!ggml_v3_is_transposed(a));
bool is_node = false;
if (a->grad || b->grad) {
is_node = true;
}
// a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, 4, ne);
result->op = GGML_V3_OP_OUT_PROD;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_v3_scale
static struct ggml_v3_tensor * ggml_v3_scale_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
float s,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_is_padded_1d(a));
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params(result, &s, sizeof(s));
result->op = GGML_V3_OP_SCALE;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_scale(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
float s) {
return ggml_v3_scale_impl(ctx, a, s, false);
}
struct ggml_v3_tensor * ggml_v3_scale_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
float s) {
return ggml_v3_scale_impl(ctx, a, s, true);
}
// ggml_v3_set
static struct ggml_v3_tensor * ggml_v3_set_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
size_t nb1,
size_t nb2,
size_t nb3,
size_t offset,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_nelements(a) >= ggml_v3_nelements(b));
bool is_node = false;
if (a->grad || b->grad) {
is_node = true;
}
// make a view of the destination
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_SET;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_set(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
size_t nb1,
size_t nb2,
size_t nb3,
size_t offset) {
return ggml_v3_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
}
struct ggml_v3_tensor * ggml_v3_set_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
size_t nb1,
size_t nb2,
size_t nb3,
size_t offset) {
return ggml_v3_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
}
struct ggml_v3_tensor * ggml_v3_set_1d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
size_t offset) {
return ggml_v3_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false);
}
struct ggml_v3_tensor * ggml_v3_set_1d_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
size_t offset) {
return ggml_v3_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true);
}
struct ggml_v3_tensor * ggml_v3_set_2d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
size_t nb1,
size_t offset) {
return ggml_v3_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
}
struct ggml_v3_tensor * ggml_v3_set_2d_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
size_t nb1,
size_t offset) {
return ggml_v3_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
}
// ggml_v3_cpy
static struct ggml_v3_tensor * ggml_v3_cpy_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
GGML_V3_ASSERT(ggml_v3_nelements(a) == ggml_v3_nelements(b));
bool is_node = false;
if (a->grad || b->grad) {
        // inplace is false and either one has a grad
is_node = true;
}
// make a view of the destination
struct ggml_v3_tensor * result = ggml_v3_view_tensor(ctx, b);
if (strlen(b->name) > 0) {
ggml_v3_format_name(result, "%s (copy of %s)", b->name, a->name);
} else {
ggml_v3_format_name(result, "%s (copy)", a->name);
}
result->op = GGML_V3_OP_CPY;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_cpy(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_cpy_impl(ctx, a, b);
}
// ggml_v3_cont
static struct ggml_v3_tensor * ggml_v3_cont_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_dup_tensor(ctx, a);
ggml_v3_format_name(result, "%s (cont)", a->name);
result->op = GGML_V3_OP_CONT;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_cont(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_cont_impl(ctx, a);
}
// make contiguous, with new shape
GGML_V3_API struct ggml_v3_tensor * ggml_v3_cont_1d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0) {
return ggml_v3_cont_4d(ctx, a, ne0, 1, 1, 1);
}
GGML_V3_API struct ggml_v3_tensor * ggml_v3_cont_2d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0,
int64_t ne1) {
return ggml_v3_cont_4d(ctx, a, ne0, ne1, 1, 1);
}
GGML_V3_API struct ggml_v3_tensor * ggml_v3_cont_3d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2) {
return ggml_v3_cont_4d(ctx, a, ne0, ne1, ne2, 1);
}
struct ggml_v3_tensor * ggml_v3_cont_4d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3) {
GGML_V3_ASSERT(ggml_v3_nelements(a) == (ne0*ne1*ne2*ne3));
bool is_node = false;
struct ggml_v3_tensor * result = ggml_v3_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
ggml_v3_format_name(result, "%s (cont)", a->name);
result->op = GGML_V3_OP_CONT;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_reshape
struct ggml_v3_tensor * ggml_v3_reshape(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(a));
    // only the shape of b is relevant here, not its memory layout, so b is allowed to be non-contiguous.
GGML_V3_ASSERT(ggml_v3_nelements(a) == ggml_v3_nelements(b));
bool is_node = false;
if (a->grad) {
is_node = true;
}
if (b->grad) {
// gradient propagation is not supported
//GGML_V3_ASSERT(false);
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor_impl(ctx, a->type, GGML_V3_MAX_DIMS, b->ne, a, 0);
ggml_v3_format_name(result, "%s (reshaped)", a->name);
result->op = GGML_V3_OP_RESHAPE;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_reshape_1d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(a));
GGML_V3_ASSERT(ggml_v3_nelements(a) == ne0);
bool is_node = false;
if (a->grad) {
is_node = true;
}
const int64_t ne[1] = { ne0 };
struct ggml_v3_tensor * result = ggml_v3_new_tensor_impl(ctx, a->type, 1, ne, a, 0);
ggml_v3_format_name(result, "%s (reshaped)", a->name);
result->op = GGML_V3_OP_RESHAPE;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_reshape_2d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0,
int64_t ne1) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(a));
GGML_V3_ASSERT(ggml_v3_nelements(a) == ne0*ne1);
bool is_node = false;
if (a->grad) {
is_node = true;
}
const int64_t ne[2] = { ne0, ne1 };
struct ggml_v3_tensor * result = ggml_v3_new_tensor_impl(ctx, a->type, 2, ne, a, 0);
ggml_v3_format_name(result, "%s (reshaped)", a->name);
result->op = GGML_V3_OP_RESHAPE;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_reshape_3d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(a));
GGML_V3_ASSERT(ggml_v3_nelements(a) == ne0*ne1*ne2);
bool is_node = false;
if (a->grad) {
is_node = true;
}
const int64_t ne[3] = { ne0, ne1, ne2 };
struct ggml_v3_tensor * result = ggml_v3_new_tensor_impl(ctx, a->type, 3, ne, a, 0);
ggml_v3_format_name(result, "%s (reshaped)", a->name);
result->op = GGML_V3_OP_RESHAPE;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_reshape_4d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(a));
GGML_V3_ASSERT(ggml_v3_nelements(a) == ne0*ne1*ne2*ne3);
bool is_node = false;
if (a->grad) {
is_node = true;
}
const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
struct ggml_v3_tensor * result = ggml_v3_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
ggml_v3_format_name(result, "%s (reshaped)", a->name);
result->op = GGML_V3_OP_RESHAPE;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
static struct ggml_v3_tensor * ggml_v3_view_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int n_dims,
const int64_t * ne,
size_t offset) {
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset);
ggml_v3_format_name(result, "%s (view)", a->name);
ggml_v3_set_op_params(result, &offset, sizeof(offset));
result->op = GGML_V3_OP_VIEW;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_view_1d
struct ggml_v3_tensor * ggml_v3_view_1d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0,
size_t offset) {
struct ggml_v3_tensor * result = ggml_v3_view_impl(ctx, a, 1, &ne0, offset);
return result;
}
// ggml_v3_view_2d
struct ggml_v3_tensor * ggml_v3_view_2d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0,
int64_t ne1,
size_t nb1,
size_t offset) {
const int64_t ne[2] = { ne0, ne1 };
struct ggml_v3_tensor * result = ggml_v3_view_impl(ctx, a, 2, ne, offset);
result->nb[1] = nb1;
result->nb[2] = result->nb[1]*ne1;
result->nb[3] = result->nb[2];
return result;
}
// ggml_v3_view_3d
struct ggml_v3_tensor * ggml_v3_view_3d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
size_t nb1,
size_t nb2,
size_t offset) {
const int64_t ne[3] = { ne0, ne1, ne2 };
struct ggml_v3_tensor * result = ggml_v3_view_impl(ctx, a, 3, ne, offset);
result->nb[1] = nb1;
result->nb[2] = nb2;
result->nb[3] = result->nb[2]*ne2;
return result;
}
// ggml_v3_view_4d
struct ggml_v3_tensor * ggml_v3_view_4d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3,
size_t nb1,
size_t nb2,
size_t nb3,
size_t offset) {
const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
struct ggml_v3_tensor * result = ggml_v3_view_impl(ctx, a, 4, ne, offset);
result->nb[1] = nb1;
result->nb[2] = nb2;
result->nb[3] = nb3;
return result;
}
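// Views never copy data: ggml_v3_view_impl() records the byte `offset` in
// op_params and the *_2d/_3d/_4d wrappers then overwrite nb[1..3] with the
// caller-supplied strides. A typical (illustrative) use is slicing row `r`
// out of a 2-D tensor `t`:
//
//   struct ggml_v3_tensor * row = ggml_v3_view_1d(ctx, t, t->ne[0], r*t->nb[1]);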
// ggml_v3_permute
struct ggml_v3_tensor * ggml_v3_permute(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int axis0,
int axis1,
int axis2,
int axis3) {
GGML_V3_ASSERT(axis0 >= 0 && axis0 < GGML_V3_MAX_DIMS);
GGML_V3_ASSERT(axis1 >= 0 && axis1 < GGML_V3_MAX_DIMS);
GGML_V3_ASSERT(axis2 >= 0 && axis2 < GGML_V3_MAX_DIMS);
GGML_V3_ASSERT(axis3 >= 0 && axis3 < GGML_V3_MAX_DIMS);
GGML_V3_ASSERT(axis0 != axis1);
GGML_V3_ASSERT(axis0 != axis2);
GGML_V3_ASSERT(axis0 != axis3);
GGML_V3_ASSERT(axis1 != axis2);
GGML_V3_ASSERT(axis1 != axis3);
GGML_V3_ASSERT(axis2 != axis3);
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_view_tensor(ctx, a);
ggml_v3_format_name(result, "%s (permuted)", a->name);
    int64_t ne[GGML_V3_MAX_DIMS];
    size_t  nb[GGML_V3_MAX_DIMS];
ne[axis0] = a->ne[0];
ne[axis1] = a->ne[1];
ne[axis2] = a->ne[2];
ne[axis3] = a->ne[3];
nb[axis0] = a->nb[0];
nb[axis1] = a->nb[1];
nb[axis2] = a->nb[2];
nb[axis3] = a->nb[3];
result->ne[0] = ne[0];
result->ne[1] = ne[1];
result->ne[2] = ne[2];
result->ne[3] = ne[3];
result->nb[0] = nb[0];
result->nb[1] = nb[1];
result->nb[2] = nb[2];
result->nb[3] = nb[3];
result->op = GGML_V3_OP_PERMUTE;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
int32_t params[] = { axis0, axis1, axis2, axis3 };
ggml_v3_set_op_params(result, params, sizeof(params));
return result;
}
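// ggml_v3_permute() only relabels ne[]/nb[] on a view; no data is moved.
// ggml_v3_transpose() below is effectively the (1, 0, 2, 3) special case,
// again implemented purely by swapping the extents and strides of dims 0 and 1.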
// ggml_v3_transpose
struct ggml_v3_tensor * ggml_v3_transpose(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_view_tensor(ctx, a);
ggml_v3_format_name(result, "%s (transposed)", a->name);
result->ne[0] = a->ne[1];
result->ne[1] = a->ne[0];
result->nb[0] = a->nb[1];
result->nb[1] = a->nb[0];
result->op = GGML_V3_OP_TRANSPOSE;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_get_rows
struct ggml_v3_tensor * ggml_v3_get_rows(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
GGML_V3_ASSERT(a->ne[2] == b->ne[1]);
GGML_V3_ASSERT(b->ne[3] == 1);
GGML_V3_ASSERT(b->type == GGML_V3_TYPE_I32);
bool is_node = false;
if (a->grad || b->grad) {
is_node = true;
}
// TODO: implement non F32 return
enum ggml_v3_type type = GGML_V3_TYPE_F32;
if (a->type == GGML_V3_TYPE_I32) {
type = a->type;
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
result->op = GGML_V3_OP_GET_ROWS;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_v3_get_rows_back
struct ggml_v3_tensor * ggml_v3_get_rows_back(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
struct ggml_v3_tensor * c) {
GGML_V3_ASSERT(ggml_v3_is_matrix(a) && ggml_v3_is_vector(b) && b->type == GGML_V3_TYPE_I32);
GGML_V3_ASSERT(ggml_v3_is_matrix(c) && (a->ne[0] == c->ne[0]));
bool is_node = false;
if (a->grad || b->grad) {
is_node = true;
}
// TODO: implement non F32 return
//struct ggml_v3_tensor * result = ggml_v3_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
struct ggml_v3_tensor * result = ggml_v3_new_tensor_2d(ctx, GGML_V3_TYPE_F32, c->ne[0], c->ne[1]);
result->op = GGML_V3_OP_GET_ROWS_BACK;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_v3_diag
struct ggml_v3_tensor * ggml_v3_diag(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
GGML_V3_ASSERT(a->ne[1] == 1);
bool is_node = false;
if (a->grad) {
is_node = true;
}
const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, a->type, 4, ne);
result->op = GGML_V3_OP_DIAG;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_diag_mask_inf
static struct ggml_v3_tensor * ggml_v3_diag_mask_inf_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int n_past,
bool inplace) {
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
int32_t params[] = { n_past };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_DIAG_MASK_INF;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_diag_mask_inf(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int n_past) {
return ggml_v3_diag_mask_inf_impl(ctx, a, n_past, false);
}
struct ggml_v3_tensor * ggml_v3_diag_mask_inf_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int n_past) {
return ggml_v3_diag_mask_inf_impl(ctx, a, n_past, true);
}
// ggml_v3_diag_mask_zero
static struct ggml_v3_tensor * ggml_v3_diag_mask_zero_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int n_past,
bool inplace) {
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
int32_t params[] = { n_past };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_DIAG_MASK_ZERO;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_diag_mask_zero(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int n_past) {
return ggml_v3_diag_mask_zero_impl(ctx, a, n_past, false);
}
struct ggml_v3_tensor * ggml_v3_diag_mask_zero_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int n_past) {
return ggml_v3_diag_mask_zero_impl(ctx, a, n_past, true);
}
// ggml_v3_soft_max
static struct ggml_v3_tensor * ggml_v3_soft_max_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * mask,
float scale,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(a));
if (mask) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(mask));
GGML_V3_ASSERT(mask->ne[2] == 1);
GGML_V3_ASSERT(mask->ne[3] == 1);
GGML_V3_ASSERT(ggml_v3_can_repeat_rows(mask, a));
}
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
float params[] = { scale };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_SOFT_MAX;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = mask;
return result;
}
struct ggml_v3_tensor * ggml_v3_soft_max(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_soft_max_impl(ctx, a, NULL, 1.0f, false);
}
struct ggml_v3_tensor * ggml_v3_soft_max_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a) {
return ggml_v3_soft_max_impl(ctx, a, NULL, 1.0f, true);
}
struct ggml_v3_tensor * ggml_v3_soft_max_ext(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * mask,
float scale) {
return ggml_v3_soft_max_impl(ctx, a, mask, scale, false);
}
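// ggml_v3_soft_max(a) is equivalent to ggml_v3_soft_max_ext(a, NULL, 1.0f):
// the optional mask is carried in src[1] and the scale is stored (as a float
// bit pattern) in op_params; how they are applied is left to the compute kernel.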
// ggml_v3_soft_max_back
static struct ggml_v3_tensor * ggml_v3_soft_max_back_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
bool inplace) {
bool is_node = false;
if (a->grad || b->grad) {
        is_node = true; // TODO: implement backward pass
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_SOFT_MAX_BACK;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_soft_max_back(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_soft_max_back_impl(ctx, a, b, false);
}
struct ggml_v3_tensor * ggml_v3_soft_max_back_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_soft_max_back_impl(ctx, a, b, true);
}
// ggml_v3_rope
static struct ggml_v3_tensor * ggml_v3_rope_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
float xpos_base,
bool xpos_down,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_is_vector(b));
GGML_V3_ASSERT(b->type == GGML_V3_TYPE_I32);
GGML_V3_ASSERT(a->ne[2] == b->ne[0]);
bool is_node = false;
if (a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(params + 11, &xpos_base, sizeof(float));
memcpy(params + 12, &xpos_down, sizeof(bool));
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_ROPE;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
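// RoPE op_params layout used by ggml_v3_rope_impl() (13 int32 slots):
//   [0] n_past (always 0 here), [1] n_dims, [2] mode, [3] n_ctx, [4] n_orig_ctx,
//   [5] freq_base, [6] freq_scale, [7] ext_factor, [8] attn_factor,
//   [9] beta_fast, [10] beta_slow, [11] xpos_base, [12] xpos_down,
// where slots 5..12 hold the raw bit patterns memcpy'd from float/bool values.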
struct ggml_v3_tensor * ggml_v3_rope(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int n_dims,
int mode,
int n_ctx) {
return ggml_v3_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
);
}
struct ggml_v3_tensor * ggml_v3_rope_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int n_dims,
int mode,
int n_ctx) {
return ggml_v3_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
);
}
struct ggml_v3_tensor * ggml_v3_rope_custom(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
return ggml_v3_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
);
}
struct ggml_v3_tensor * ggml_v3_rope_custom_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
return ggml_v3_rope_impl(
ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
);
}
struct ggml_v3_tensor * ggml_v3_rope_xpos_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int n_dims,
float base,
bool down) {
return ggml_v3_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
}
// ggml_v3_rope_back
struct ggml_v3_tensor * ggml_v3_rope_back(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
float xpos_base,
bool xpos_down) {
GGML_V3_ASSERT(ggml_v3_is_vector(b));
GGML_V3_ASSERT(b->type == GGML_V3_TYPE_I32);
GGML_V3_ASSERT(a->ne[2] == b->ne[0]);
GGML_V3_ASSERT((mode & 4) == 0 && "ggml_v3_rope_back() for ChatGLM not implemented yet");
bool is_node = false;
if (a->grad) {
is_node = false; // TODO: implement backward
}
struct ggml_v3_tensor * result = ggml_v3_dup_tensor(ctx, a);
int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(params + 11, &xpos_base, sizeof(float));
memcpy(params + 12, &xpos_down, sizeof(bool));
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_ROPE_BACK;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_v3_alibi
struct ggml_v3_tensor * ggml_v3_alibi(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int n_past,
int n_head,
float bias_max) {
GGML_V3_ASSERT(n_past >= 0);
bool is_node = false;
if (a->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
    // TODO: when implementing backward, fix this:
//struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
struct ggml_v3_tensor * result = ggml_v3_view_tensor(ctx, a);
int32_t op_params[3] = { n_past, n_head };
memcpy(op_params + 2, &bias_max, sizeof(float));
ggml_v3_set_op_params(result, op_params, sizeof(op_params));
result->op = GGML_V3_OP_ALIBI;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_clamp
struct ggml_v3_tensor * ggml_v3_clamp(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
float min,
float max) {
bool is_node = false;
if (a->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
    // TODO: when implementing backward, fix this:
struct ggml_v3_tensor * result = ggml_v3_view_tensor(ctx, a);
float params[] = { min, max };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_CLAMP;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_conv_1d
static int64_t ggml_v3_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
}
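// Worked example of the formula above: for ins = 8, ks = 3, s = 1, p = 1, d = 1
// the output length is (8 + 2*1 - 1*(3 - 1) - 1)/1 + 1 = 8, i.e. "same" padding
// with a 3-wide kernel preserves the input length.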
GGML_V3_API struct ggml_v3_tensor * ggml_v3_conv_1d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int s0,
int p0,
int d0) {
struct ggml_v3_tensor * im2col = ggml_v3_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
struct ggml_v3_tensor * result =
ggml_v3_mul_mat(ctx,
ggml_v3_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
                ggml_v3_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                            // [OC, IC, K] => [OC, IC * K]
result = ggml_v3_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
return result;
}
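// Note: ggml_v3_conv_1d is built entirely from existing ops: ggml_v3_im2col unfolds
// the input into rows of IC*K elements, a single ggml_v3_mul_mat against the kernel
// flattened to [OC, IC * K] performs all the dot products, and the final reshape
// restores the [N, OC, OL] layout described in the comments above.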
// ggml_v3_conv_1d_ph
struct ggml_v3_tensor* ggml_v3_conv_1d_ph(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int s,
int d) {
return ggml_v3_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
}
// ggml_v3_conv_transpose_1d
static int64_t ggml_v3_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
}
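// Worked example of the transposed-convolution formula above: for ins = 4, ks = 3,
// s = 2, p = 0, d = 1 the output length is (4 - 1)*2 - 2*0 + 1*(3 - 1) + 1 = 9.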
GGML_V3_API struct ggml_v3_tensor * ggml_v3_conv_transpose_1d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int s0,
int p0,
int d0) {
GGML_V3_ASSERT(ggml_v3_is_matrix(b));
GGML_V3_ASSERT(a->ne[2] == b->ne[1]);
GGML_V3_ASSERT(a->ne[3] == 1);
GGML_V3_ASSERT(p0 == 0);
GGML_V3_ASSERT(d0 == 1);
bool is_node = false;
if (a->grad || b->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
const int64_t ne[4] = {
ggml_v3_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
a->ne[1], b->ne[2], 1,
};
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, 4, ne);
int32_t params[] = { s0, p0, d0 };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_CONV_TRANSPOSE_1D;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_v3_conv_2d
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
// a: [OC, IC, KH, KW]
// b: [N, IC, IH, IW]
// result: [N, OH, OW, IC*KH*KW]
struct ggml_v3_tensor * ggml_v3_im2col(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1,
bool is_2D) {
if(is_2D) {
GGML_V3_ASSERT(a->ne[2] == b->ne[2]);
} else {
GGML_V3_ASSERT(a->ne[1] == b->ne[1]);
}
bool is_node = false;
if (a->grad || b->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
const int64_t OH = is_2D ? ggml_v3_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
const int64_t OW = ggml_v3_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
const int64_t ne[4] = {
is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
OW,
is_2D ? OH : b->ne[2],
is_2D ? b->ne[3] : 1,
};
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F16, 4, ne);
int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_IM2COL;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
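// Note on the layout produced above: the result is always F16, and (in ggml's
// reversed-dimension notation) its ne[] is
//   is_2D == true : { IC*KH*KW, OW, OH, N }
//   is_2D == false: { IC*K,     OW, N,  1 }
// i.e. each "row" of length IC*KH*KW (or IC*K) is one unfolded receptive field.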
// a: [OC, IC, KH, KW]
// b: [N, IC, IH, IW]
// result: [N, OC, OH, OW]
struct ggml_v3_tensor * ggml_v3_conv_2d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1) {
struct ggml_v3_tensor * im2col = ggml_v3_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
struct ggml_v3_tensor * result =
ggml_v3_mul_mat(ctx,
ggml_v3_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
                ggml_v3_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3]));                       // [OC, IC, KH, KW] => [OC, IC * KH * KW]
result = ggml_v3_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
return result;
}
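// Illustrative usage sketch (hypothetical tensors, not part of this file): a 3x3
// convolution with stride 1, dilation 1 and "same" padding of 1 on each side, where
// `w` is an [OC, IC, 3, 3] kernel and `x` is an [N, IC, IH, IW] input:
//
//   struct ggml_v3_tensor * y = ggml_v3_conv_2d(ctx, w, x,
//       /*s0*/ 1, /*s1*/ 1, /*p0*/ 1, /*p1*/ 1, /*d0*/ 1, /*d1*/ 1);
//   // y has shape [N, OC, IH, IW]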
// ggml_v3_conv_2d_sk_p0
struct ggml_v3_tensor * ggml_v3_conv_2d_sk_p0(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);
}
// ggml_v3_conv_2d_s1_ph
struct ggml_v3_tensor * ggml_v3_conv_2d_s1_ph(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
return ggml_v3_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
}
// ggml_v3_conv_transpose_2d_p0
static int64_t ggml_v3_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
return (ins - 1) * s - 2 * p + ks;
}
struct ggml_v3_tensor * ggml_v3_conv_transpose_2d_p0(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
int stride) {
GGML_V3_ASSERT(a->ne[3] == b->ne[2]);
bool is_node = false;
if (a->grad || b->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
const int64_t ne[4] = {
ggml_v3_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/),
ggml_v3_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/),
a->ne[2], b->ne[3],
};
struct ggml_v3_tensor* result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, 4, ne);
ggml_v3_set_op_params_i32(result, 0, stride);
result->op = GGML_V3_OP_CONV_TRANSPOSE_2D;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_v3_pool_*
static int64_t ggml_v3_calc_pool_output_size(int64_t ins, int ks, int s, float p) {
return (ins + 2 * p - ks) / s + 1;
}
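// Worked example of the pooling formula above: for ins = 8, ks = 2, s = 2, p = 0
// the output length is (8 + 0 - 2)/2 + 1 = 4 (the usual 2x downsampling).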
// ggml_v3_pool_1d
struct ggml_v3_tensor * ggml_v3_pool_1d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
enum ggml_v3_op_pool op,
int k0,
int s0,
int p0) {
bool is_node = false;
if (a->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
const int64_t ne[2] = {
ggml_v3_calc_pool_output_size(a->ne[0], k0, s0, p0),
a->ne[1],
};
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, 2, ne);
int32_t params[] = { op, k0, s0, p0 };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_POOL_1D;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_pool_2d
struct ggml_v3_tensor * ggml_v3_pool_2d(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
enum ggml_v3_op_pool op,
int k0,
int k1,
int s0,
int s1,
float p0,
float p1) {
bool is_node = false;
if (a->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
const int64_t ne[3] = {
ggml_v3_calc_pool_output_size(a->ne[0], k0, s0, p0),
ggml_v3_calc_pool_output_size(a->ne[1], k1, s1, p1),
a->ne[2],
};
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, 3, ne);
int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_POOL_2D;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_upscale
static struct ggml_v3_tensor * ggml_v3_upscale_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int scale_factor) {
bool is_node = false;
if (a->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor_4d(ctx, a->type,
a->ne[0] * scale_factor,
a->ne[1] * scale_factor,
a->ne[2], a->ne[3]);
result->op = GGML_V3_OP_UPSCALE;
result->op_params[0] = scale_factor;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_pad(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int p0, int p1, int p2, int p3) {
bool is_node = false;
if (a->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor_4d(ctx, a->type,
a->ne[0] + p0,
a->ne[1] + p1,
a->ne[2] + p2,
a->ne[3] + p3);
result->op = GGML_V3_OP_PAD;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_upscale(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int scale_factor) {
return ggml_v3_upscale_impl(ctx, a, scale_factor);
}
// ggml_v3_argsort
struct ggml_v3_tensor * ggml_v3_argsort(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
enum ggml_v3_sort_order order) {
bool is_node = false;
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_I32, GGML_V3_MAX_DIMS, a->ne);
ggml_v3_set_op_params_i32(result, 0, (int32_t) order);
result->op = GGML_V3_OP_ARGSORT;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_top_k
struct ggml_v3_tensor * ggml_v3_top_k(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int k) {
GGML_V3_ASSERT(a->ne[0] >= k);
struct ggml_v3_tensor * result = ggml_v3_argsort(ctx, a, GGML_V3_SORT_DESC);
result = ggml_v3_view_4d(ctx, result,
k, result->ne[1], result->ne[2], result->ne[3],
result->nb[1], result->nb[2], result->nb[3],
0);
return result;
}
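// Note: ggml_v3_top_k returns *indices*, not values: it is a descending
// ggml_v3_argsort over dim 0 followed by a view that keeps only the first k
// (i.e. largest-k) index columns of each row.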
// ggml_v3_flash_attn
struct ggml_v3_tensor * ggml_v3_flash_attn(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * q,
struct ggml_v3_tensor * k,
struct ggml_v3_tensor * v,
bool masked) {
GGML_V3_ASSERT(ggml_v3_can_mul_mat(k, q));
// TODO: check if vT can be multiplied by (k*qT)
bool is_node = false;
if (q->grad || k->grad || v->grad) {
is_node = true;
}
//struct ggml_v3_tensor * result = ggml_v3_dup_tensor(ctx, q);
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, GGML_V3_MAX_DIMS, q->ne);
int32_t t = masked ? 1 : 0;
ggml_v3_set_op_params(result, &t, sizeof(t));
result->op = GGML_V3_OP_FLASH_ATTN;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = q;
result->src[1] = k;
result->src[2] = v;
return result;
}
// ggml_v3_flash_ff
struct ggml_v3_tensor * ggml_v3_flash_ff(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b0,
struct ggml_v3_tensor * b1,
struct ggml_v3_tensor * c0,
struct ggml_v3_tensor * c1) {
GGML_V3_ASSERT(ggml_v3_can_mul_mat(b0, a));
// TODO: more checks
bool is_node = false;
if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
is_node = true;
}
//struct ggml_v3_tensor * result = ggml_v3_dup_tensor(ctx, a);
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, GGML_V3_MAX_DIMS, a->ne);
result->op = GGML_V3_OP_FLASH_FF;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b0;
result->src[2] = b1;
result->src[3] = c0;
result->src[4] = c1;
return result;
}
// ggml_v3_flash_attn_back
struct ggml_v3_tensor * ggml_v3_flash_attn_back(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * q,
struct ggml_v3_tensor * k,
struct ggml_v3_tensor * v,
struct ggml_v3_tensor * d,
bool masked) {
GGML_V3_ASSERT(ggml_v3_can_mul_mat(k, q));
// TODO: check if vT can be multiplied by (k*qT)
// d shape [D,N,ne2,ne3]
// q shape [D,N,ne2,ne3]
// k shape [D,M,kvne2,ne3]
// v shape [M,D,kvne2,ne3]
const int64_t D = q->ne[0];
const int64_t N = q->ne[1];
const int64_t M = k->ne[1];
const int64_t ne2 = q->ne[2];
const int64_t ne3 = q->ne[3];
const int64_t kvne2 = k->ne[2];
GGML_V3_ASSERT(k->ne[0] == D);
GGML_V3_ASSERT(v->ne[0] == M);
GGML_V3_ASSERT(v->ne[1] == D);
GGML_V3_ASSERT(d->ne[0] == D);
GGML_V3_ASSERT(d->ne[1] == N);
GGML_V3_ASSERT(k->ne[2] == kvne2);
GGML_V3_ASSERT(k->ne[3] == ne3);
GGML_V3_ASSERT(v->ne[2] == kvne2);
GGML_V3_ASSERT(v->ne[3] == ne3);
GGML_V3_ASSERT(d->ne[2] == ne2);
GGML_V3_ASSERT(d->ne[3] == ne3);
GGML_V3_ASSERT(ne2 % kvne2 == 0);
bool is_node = false;
if (q->grad || k->grad || v->grad) {
        // when using this operation (in the backward pass) these grads are set.
// we don't want to create (big) grad of our result, so is_node is false.
is_node = false;
}
    // store gradients of q, k and v as contiguous tensors concatenated in result.
// note: v and gradv are actually transposed, i.e. v->ne[0] != D.
const int64_t elem_q = ggml_v3_nelements(q);
const int64_t elem_k = ggml_v3_nelements(k);
const int64_t elem_v = ggml_v3_nelements(v);
enum ggml_v3_type result_type = GGML_V3_TYPE_F32;
GGML_V3_ASSERT(ggml_v3_blck_size(result_type) == 1);
const size_t tsize = ggml_v3_type_size(result_type);
const size_t offs_q = 0;
const size_t offs_k = offs_q + GGML_V3_PAD(elem_q * tsize, GGML_V3_MEM_ALIGN);
const size_t offs_v = offs_k + GGML_V3_PAD(elem_k * tsize, GGML_V3_MEM_ALIGN);
const size_t end = offs_v + GGML_V3_PAD(elem_v * tsize, GGML_V3_MEM_ALIGN);
const size_t nelements = (end + tsize - 1)/tsize;
struct ggml_v3_tensor * result = ggml_v3_new_tensor_1d(ctx, GGML_V3_TYPE_F32, nelements);
int32_t masked_i = masked ? 1 : 0;
ggml_v3_set_op_params(result, &masked_i, sizeof(masked_i));
result->op = GGML_V3_OP_FLASH_ATTN_BACK;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = q;
result->src[1] = k;
result->src[2] = v;
result->src[3] = d;
return result;
}
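// Note on the result layout above: the gradients of q, k and v are packed into a
// single 1-D F32 tensor. offs_q starts at 0 and offs_k/offs_v follow, with each
// section padded up to GGML_V3_MEM_ALIGN bytes; `nelements` converts the padded end
// offset back to an element count, so the backward kernel can slice the three
// gradient blocks out of `result` using the same offsets.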
// ggml_v3_win_part
struct ggml_v3_tensor * ggml_v3_win_part(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int w) {
GGML_V3_ASSERT(a->ne[3] == 1);
GGML_V3_ASSERT(a->type == GGML_V3_TYPE_F32);
bool is_node = false;
if (a->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
// padding
const int px = (w - a->ne[1]%w)%w;
const int py = (w - a->ne[2]%w)%w;
const int npx = (px + a->ne[1])/w;
const int npy = (py + a->ne[2])/w;
const int np = npx*npy;
const int64_t ne[4] = { a->ne[0], w, w, np, };
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, 4, ne);
int32_t params[] = { npx, npy, w };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_WIN_PART;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
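// Worked example of the window-partition math above: for a->ne[1] = a->ne[2] = 14
// and w = 8, the padding is px = py = (8 - 14 % 8) % 8 = 2, giving
// npx = npy = (2 + 14)/8 = 2 and np = 4 windows of shape [ne0, 8, 8].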
// ggml_v3_win_unpart
struct ggml_v3_tensor * ggml_v3_win_unpart(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int w0,
int h0,
int w) {
GGML_V3_ASSERT(a->type == GGML_V3_TYPE_F32);
bool is_node = false;
if (a->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F32, 3, ne);
int32_t params[] = { w };
ggml_v3_set_op_params(result, params, sizeof(params));
result->op = GGML_V3_OP_WIN_UNPART;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_get_rel_pos
struct ggml_v3_tensor * ggml_v3_get_rel_pos(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
int qh,
int kh) {
GGML_V3_ASSERT(qh == kh);
GGML_V3_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);
bool is_node = false;
if (a->grad) {
GGML_V3_ASSERT(false); // TODO: implement backward
is_node = true;
}
const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
struct ggml_v3_tensor * result = ggml_v3_new_tensor(ctx, GGML_V3_TYPE_F16, 3, ne);
result->op = GGML_V3_OP_GET_REL_POS;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_v3_add_rel_pos
static struct ggml_v3_tensor * ggml_v3_add_rel_pos_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * pw,
struct ggml_v3_tensor * ph,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(pw, ph));
GGML_V3_ASSERT(ggml_v3_is_contiguous(a));
GGML_V3_ASSERT(ggml_v3_is_contiguous(pw));
GGML_V3_ASSERT(ggml_v3_is_contiguous(ph));
GGML_V3_ASSERT(ph->type == GGML_V3_TYPE_F32);
GGML_V3_ASSERT(pw->type == GGML_V3_TYPE_F32);
GGML_V3_ASSERT(pw->ne[3] == a->ne[2]);
GGML_V3_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]);
GGML_V3_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);
bool is_node = false;
if (!inplace && (a->grad || pw->grad || ph->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params_i32(result, 0, inplace ? 1 : 0);
result->op = GGML_V3_OP_ADD_REL_POS;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = pw;
result->src[2] = ph;
return result;
}
struct ggml_v3_tensor * ggml_v3_add_rel_pos(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * pw,
struct ggml_v3_tensor * ph) {
return ggml_v3_add_rel_pos_impl(ctx, a, pw, ph, false);
}
struct ggml_v3_tensor * ggml_v3_add_rel_pos_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * pw,
struct ggml_v3_tensor * ph) {
return ggml_v3_add_rel_pos_impl(ctx, a, pw, ph, true);
}
// ggml_v3_unary
static struct ggml_v3_tensor * ggml_v3_unary_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
enum ggml_v3_unary_op op,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params_i32(result, 0, (int32_t) op);
result->op = GGML_V3_OP_UNARY;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_unary(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
enum ggml_v3_unary_op op) {
return ggml_v3_unary_impl(ctx, a, op, false);
}
struct ggml_v3_tensor * ggml_v3_unary_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
enum ggml_v3_unary_op op) {
return ggml_v3_unary_impl(ctx, a, op, true);
}
// ggml_v3_map_unary
static struct ggml_v3_tensor * ggml_v3_map_unary_impl_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
const ggml_v3_unary_op_f32_t fun,
bool inplace) {
bool is_node = false;
if (!inplace && a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params(result, (const void *) &fun, sizeof(fun));
result->op = GGML_V3_OP_MAP_UNARY;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_map_unary_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
const ggml_v3_unary_op_f32_t fun) {
return ggml_v3_map_unary_impl_f32(ctx, a, fun, false);
}
struct ggml_v3_tensor * ggml_v3_map_unary_inplace_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
const ggml_v3_unary_op_f32_t fun) {
return ggml_v3_map_unary_impl_f32(ctx, a, fun, true);
}
// ggml_v3_map_binary
static struct ggml_v3_tensor * ggml_v3_map_binary_impl_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
const ggml_v3_binary_op_f32_t fun,
bool inplace) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(a, b));
bool is_node = false;
if (!inplace && (a->grad || b->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params(result, (const void *) &fun, sizeof(fun));
result->op = GGML_V3_OP_MAP_BINARY;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_map_binary_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
const ggml_v3_binary_op_f32_t fun) {
return ggml_v3_map_binary_impl_f32(ctx, a, b, fun, false);
}
struct ggml_v3_tensor * ggml_v3_map_binary_inplace_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
const ggml_v3_binary_op_f32_t fun) {
return ggml_v3_map_binary_impl_f32(ctx, a, b, fun, true);
}
// ggml_v3_map_custom1_f32
static struct ggml_v3_tensor * ggml_v3_map_custom1_impl_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
const ggml_v3_custom1_op_f32_t fun,
bool inplace) {
bool is_node = false;
if (!inplace && a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params(result, (const void *) &fun, sizeof(fun));
result->op = GGML_V3_OP_MAP_CUSTOM1_F32;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_v3_tensor * ggml_v3_map_custom1_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
const ggml_v3_custom1_op_f32_t fun) {
return ggml_v3_map_custom1_impl_f32(ctx, a, fun, false);
}
struct ggml_v3_tensor * ggml_v3_map_custom1_inplace_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
const ggml_v3_custom1_op_f32_t fun) {
return ggml_v3_map_custom1_impl_f32(ctx, a, fun, true);
}
// ggml_v3_map_custom2_f32
static struct ggml_v3_tensor * ggml_v3_map_custom2_impl_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
const ggml_v3_custom2_op_f32_t fun,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad || b->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params(result, (const void *) &fun, sizeof(fun));
result->op = GGML_V3_OP_MAP_CUSTOM2_F32;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_map_custom2_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
const ggml_v3_custom2_op_f32_t fun) {
return ggml_v3_map_custom2_impl_f32(ctx, a, b, fun, false);
}
struct ggml_v3_tensor * ggml_v3_map_custom2_inplace_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
const ggml_v3_custom2_op_f32_t fun) {
return ggml_v3_map_custom2_impl_f32(ctx, a, b, fun, true);
}
// ggml_v3_map_custom3_f32
static struct ggml_v3_tensor * ggml_v3_map_custom3_impl_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
struct ggml_v3_tensor * c,
const ggml_v3_custom3_op_f32_t fun,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad || b->grad || c->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
ggml_v3_set_op_params(result, (const void *) &fun, sizeof(fun));
result->op = GGML_V3_OP_MAP_CUSTOM3_F32;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
result->src[2] = c;
return result;
}
struct ggml_v3_tensor * ggml_v3_map_custom3_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
struct ggml_v3_tensor * c,
const ggml_v3_custom3_op_f32_t fun) {
return ggml_v3_map_custom3_impl_f32(ctx, a, b, c, fun, false);
}
struct ggml_v3_tensor * ggml_v3_map_custom3_inplace_f32(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
struct ggml_v3_tensor * c,
const ggml_v3_custom3_op_f32_t fun) {
return ggml_v3_map_custom3_impl_f32(ctx, a, b, c, fun, true);
}
// ggml_v3_map_custom1
struct ggml_v3_map_custom1_op_params {
ggml_v3_custom1_op_t fun;
int n_tasks;
void * userdata;
};
static struct ggml_v3_tensor * ggml_v3_map_custom1_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
const ggml_v3_custom1_op_t fun,
int n_tasks,
void * userdata,
bool inplace) {
GGML_V3_ASSERT(n_tasks == GGML_V3_N_TASKS_MAX || n_tasks > 0);
bool is_node = false;
if (!inplace && a->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
struct ggml_v3_map_custom1_op_params params = {
/*.fun =*/ fun,
/*.n_tasks =*/ n_tasks,
/*.userdata =*/ userdata
};
ggml_v3_set_op_params(result, (const void *) &params, sizeof(params));
result->op = GGML_V3_OP_MAP_CUSTOM1;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
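// Note: the callback pointer, n_tasks and userdata are packed as a
// ggml_v3_map_custom1_op_params struct straight into the tensor's op_params via
// ggml_v3_set_op_params, so the compute kernel can read them back from the graph
// node. n_tasks must be > 0 or the special GGML_V3_N_TASKS_MAX value (see the
// assert above); the same scheme is used by map_custom2/map_custom3 below.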
struct ggml_v3_tensor * ggml_v3_map_custom1(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
const ggml_v3_custom1_op_t fun,
int n_tasks,
void * userdata) {
return ggml_v3_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false);
}
struct ggml_v3_tensor * ggml_v3_map_custom1_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
const ggml_v3_custom1_op_t fun,
int n_tasks,
void * userdata) {
return ggml_v3_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true);
}
// ggml_v3_map_custom2
struct ggml_v3_map_custom2_op_params {
ggml_v3_custom2_op_t fun;
int n_tasks;
void * userdata;
};
static struct ggml_v3_tensor * ggml_v3_map_custom2_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
const ggml_v3_custom2_op_t fun,
int n_tasks,
void * userdata,
bool inplace) {
GGML_V3_ASSERT(n_tasks == GGML_V3_N_TASKS_MAX || n_tasks > 0);
bool is_node = false;
if (!inplace && (a->grad || b->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
struct ggml_v3_map_custom2_op_params params = {
/*.fun =*/ fun,
/*.n_tasks =*/ n_tasks,
/*.userdata =*/ userdata
};
ggml_v3_set_op_params(result, (const void *) &params, sizeof(params));
result->op = GGML_V3_OP_MAP_CUSTOM2;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_v3_tensor * ggml_v3_map_custom2(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
const ggml_v3_custom2_op_t fun,
int n_tasks,
void * userdata) {
return ggml_v3_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false);
}
struct ggml_v3_tensor * ggml_v3_map_custom2_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
const ggml_v3_custom2_op_t fun,
int n_tasks,
void * userdata) {
return ggml_v3_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true);
}
// ggml_v3_map_custom3
struct ggml_v3_map_custom3_op_params {
ggml_v3_custom3_op_t fun;
int n_tasks;
void * userdata;
};
static struct ggml_v3_tensor * ggml_v3_map_custom3_impl(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
struct ggml_v3_tensor * c,
const ggml_v3_custom3_op_t fun,
int n_tasks,
void * userdata,
bool inplace) {
GGML_V3_ASSERT(n_tasks == GGML_V3_N_TASKS_MAX || n_tasks > 0);
bool is_node = false;
if (!inplace && (a->grad || b->grad || c->grad)) {
is_node = true;
}
struct ggml_v3_tensor * result = inplace ? ggml_v3_view_tensor(ctx, a) : ggml_v3_dup_tensor(ctx, a);
struct ggml_v3_map_custom3_op_params params = {
/*.fun =*/ fun,
/*.n_tasks =*/ n_tasks,
/*.userdata =*/ userdata
};
ggml_v3_set_op_params(result, (const void *) &params, sizeof(params));
result->op = GGML_V3_OP_MAP_CUSTOM3;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
result->src[2] = c;
return result;
}
struct ggml_v3_tensor * ggml_v3_map_custom3(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
struct ggml_v3_tensor * c,
const ggml_v3_custom3_op_t fun,
int n_tasks,
void * userdata) {
return ggml_v3_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false);
}
struct ggml_v3_tensor * ggml_v3_map_custom3_inplace(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
struct ggml_v3_tensor * c,
const ggml_v3_custom3_op_t fun,
int n_tasks,
void * userdata) {
return ggml_v3_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
}
// ggml_v3_cross_entropy_loss
struct ggml_v3_tensor * ggml_v3_cross_entropy_loss(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(a, b));
bool is_node = false;
if (a->grad || b->grad) {
is_node = true;
}
struct ggml_v3_tensor * result = ggml_v3_new_tensor_1d(ctx, a->type, 1);
result->op = GGML_V3_OP_CROSS_ENTROPY_LOSS;
result->grad = is_node ? ggml_v3_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_v3_cross_entropy_loss_back
struct ggml_v3_tensor * ggml_v3_cross_entropy_loss_back(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * a,
struct ggml_v3_tensor * b,
struct ggml_v3_tensor * c) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(a, b));
GGML_V3_ASSERT(ggml_v3_is_scalar(c));
struct ggml_v3_tensor * result = ggml_v3_dup_tensor(ctx, a);
result->op = GGML_V3_OP_CROSS_ENTROPY_LOSS_BACK;
result->grad = NULL;
result->src[0] = a;
result->src[1] = b;
result->src[2] = c;
return result;
}
////////////////////////////////////////////////////////////////////////////////
void ggml_v3_set_param(
struct ggml_v3_context * ctx,
struct ggml_v3_tensor * tensor) {
tensor->is_param = true;
GGML_V3_ASSERT(tensor->grad == NULL);
tensor->grad = ggml_v3_dup_tensor(ctx, tensor);
ggml_v3_format_name(tensor->grad, "%s (grad)", tensor->name);
}
// ggml_v3_compute_forward_dup
static void ggml_v3_compute_forward_dup_same_cont(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_nelements(dst) == ggml_v3_nelements(src0));
GGML_V3_ASSERT(ggml_v3_is_contiguous(dst) && ggml_v3_is_contiguous(src0));
GGML_V3_ASSERT(src0->type == dst->type);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const size_t nb00 = src0->nb[0];
const size_t nb0 = dst->nb[0];
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
// parallelize by elements
const int ne = ggml_v3_nelements(dst);
const int dr = (ne + nth - 1) / nth;
const int ie0 = dr * ith;
const int ie1 = MIN(ie0 + dr, ne);
if (ie0 < ie1) {
memcpy(
((char *) dst->data + ie0*nb0),
((char *) src0->data + ie0*nb00),
(ie1 - ie0) * ggml_v3_type_size(src0->type));
}
}
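// Worked example of the element partitioning above: for ne = 10 elements and
// nth = 4 threads, dr = (10 + 3)/4 = 3, so the threads copy the ranges
// [0,3), [3,6), [6,9) and [9,10); the last range is clamped by MIN(ie0 + dr, ne).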
static void ggml_v3_compute_forward_dup_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_nelements(dst) == ggml_v3_nelements(src0));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_TENSOR_UNARY_OP_LOCALS
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
if (ggml_v3_is_contiguous(src0) && ggml_v3_is_contiguous(dst) && src0->type == dst->type) {
ggml_v3_compute_forward_dup_same_cont(params, src0, dst);
return;
}
// parallelize by rows
const int nr = ne01;
// number of rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
if (src0->type == dst->type &&
ne00 == ne0 &&
nb00 == ggml_v3_type_size(src0->type) && nb0 == ggml_v3_type_size(dst->type)) {
// copy by rows
const size_t rs = ne00*nb00;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ir0; i01 < ir1; i01++) {
memcpy(
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
rs);
}
}
}
return;
}
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
if (ggml_v3_is_contiguous(dst)) {
if (nb00 == sizeof(ggml_v3_fp16_t)) {
if (dst->type == GGML_V3_TYPE_F16) {
size_t id = 0;
const size_t rs = ne00 * nb00;
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += rs * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
memcpy(dst_ptr + id, src0_ptr, rs);
id += rs;
}
id += rs * (ne01 - ir1);
}
}
} else if (dst->type == GGML_V3_TYPE_F32) {
size_t id = 0;
float * dst_ptr = (float *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const ggml_v3_fp16_t * src0_ptr = (ggml_v3_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
for (int i00 = 0; i00 < ne00; i00++) {
dst_ptr[id] = GGML_V3_FP16_TO_FP32(src0_ptr[i00]);
id++;
}
}
id += ne00 * (ne01 - ir1);
}
}
} else if (type_traits[dst->type].from_float) {
ggml_v3_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
size_t id = 0;
size_t rs = nb0 * (ne00 / ggml_v3_blck_size(dst->type));
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += rs * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const ggml_v3_fp16_t * src0_ptr = (ggml_v3_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
for (int i00 = 0; i00 < ne00; i00++) {
src0_f32[i00] = GGML_V3_FP16_TO_FP32(src0_ptr[i00]);
}
quantize_row_q(src0_f32, dst_ptr + id, ne00);
id += rs;
}
id += rs * (ne01 - ir1);
}
}
} else {
GGML_V3_ASSERT(false); // TODO: implement
}
} else {
//printf("%s: this is not optimal - fix me\n", __func__);
if (dst->type == GGML_V3_TYPE_F32) {
size_t id = 0;
float * dst_ptr = (float *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
const ggml_v3_fp16_t * src0_ptr = (ggml_v3_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
dst_ptr[id] = GGML_V3_FP16_TO_FP32(*src0_ptr);
id++;
}
}
id += ne00 * (ne01 - ir1);
}
}
} else if (dst->type == GGML_V3_TYPE_F16) {
size_t id = 0;
ggml_v3_fp16_t * dst_ptr = (ggml_v3_fp16_t *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
const ggml_v3_fp16_t * src0_ptr = (ggml_v3_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
dst_ptr[id] = *src0_ptr;
id++;
}
}
id += ne00 * (ne01 - ir1);
}
}
} else {
GGML_V3_ASSERT(false); // TODO: implement
}
}
return;
}
// dst counters
int64_t i10 = 0;
int64_t i11 = 0;
int64_t i12 = 0;
int64_t i13 = 0;
if (dst->type == GGML_V3_TYPE_F16) {
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
memcpy(dst_ptr, src0_ptr, sizeof(ggml_v3_fp16_t));
                        // note: i10..i13 index into dst, so they wrap at the dst dims
                        // (ne0..ne3), matching the F32 branch below
                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
} else if (dst->type == GGML_V3_TYPE_F32) {
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
*(float *) dst_ptr = GGML_V3_FP16_TO_FP32(*(const ggml_v3_fp16_t *) src0_ptr);
if (++i10 == ne0) {
i10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
} else {
GGML_V3_ASSERT(false); // TODO: implement
}
}
static void ggml_v3_compute_forward_dup_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_nelements(dst) == ggml_v3_nelements(src0));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_TENSOR_UNARY_OP_LOCALS
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
if (ggml_v3_is_contiguous(src0) && ggml_v3_is_contiguous(dst) && src0->type == dst->type) {
ggml_v3_compute_forward_dup_same_cont(params, src0, dst);
return;
}
// parallelize by rows
const int nr = ne01;
// number of rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
if (src0->type == dst->type &&
ne00 == ne0 &&
nb00 == ggml_v3_type_size(src0->type) && nb0 == ggml_v3_type_size(dst->type)) {
// copy by rows
const size_t rs = ne00*nb00;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ir0; i01 < ir1; i01++) {
memcpy(
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
rs);
}
}
}
return;
}
if (ggml_v3_is_contiguous(dst)) {
// TODO: simplify
if (nb00 == sizeof(float)) {
if (dst->type == GGML_V3_TYPE_F32) {
size_t id = 0;
const size_t rs = ne00 * nb00;
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += rs * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
memcpy(dst_ptr + id, src0_ptr, rs);
id += rs;
}
id += rs * (ne01 - ir1);
}
}
} else if (type_traits[dst->type].from_float) {
ggml_v3_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
size_t id = 0;
size_t rs = nb0 * (ne00 / ggml_v3_blck_size(dst->type));
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += rs * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
quantize_row_q(src0_ptr, dst_ptr + id, ne00);
id += rs;
}
id += rs * (ne01 - ir1);
}
}
} else {
GGML_V3_ASSERT(false); // TODO: implement
}
} else {
//printf("%s: this is not optimal - fix me\n", __func__);
if (dst->type == GGML_V3_TYPE_F32) {
size_t id = 0;
float * dst_ptr = (float *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
dst_ptr[id] = *src0_ptr;
id++;
}
}
id += ne00 * (ne01 - ir1);
}
}
} else if (dst->type == GGML_V3_TYPE_F16) {
size_t id = 0;
ggml_v3_fp16_t * dst_ptr = (ggml_v3_fp16_t *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
dst_ptr[id] = GGML_V3_FP32_TO_FP16(*src0_ptr);
id++;
}
}
id += ne00 * (ne01 - ir1);
}
}
} else {
GGML_V3_ASSERT(false); // TODO: implement
}
}
return;
}
// dst counters
int64_t i10 = 0;
int64_t i11 = 0;
int64_t i12 = 0;
int64_t i13 = 0;
if (dst->type == GGML_V3_TYPE_F32) {
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
memcpy(dst_ptr, src0_ptr, sizeof(float));
if (++i10 == ne0) {
i10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
} else if (dst->type == GGML_V3_TYPE_F16) {
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
*(ggml_v3_fp16_t *) dst_ptr = GGML_V3_FP32_TO_FP16(*(const float *) src0_ptr);
if (++i10 == ne0) {
i10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
} else {
GGML_V3_ASSERT(false); // TODO: implement
}
}
// A simplified version of ggml_v3_compute_forward_dup that doesn't do float upcasting and just uses plain memcpy.
static void ggml_v3_compute_forward_dup_bytes(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_nelements(dst) == ggml_v3_nelements(src0));
GGML_V3_ASSERT(src0->type == dst->type);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
if (ggml_v3_is_contiguous(src0) && ggml_v3_is_contiguous(dst)) {
ggml_v3_compute_forward_dup_same_cont(params, src0, dst);
return;
}
GGML_V3_TENSOR_UNARY_OP_LOCALS;
const size_t type_size = ggml_v3_type_size(src0->type);
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
// parallelize by rows
const int nr = ne01;
// number of rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
if (src0->type == dst->type &&
ne00 == ne0 &&
nb00 == type_size && nb0 == type_size) {
// copy by rows
const size_t rs = ne00 * type_size;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ir0; i01 < ir1; i01++) {
memcpy(
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
rs);
}
}
}
return;
}
if (ggml_v3_is_contiguous(dst)) {
size_t id = 0;
char * dst_ptr = (char *) dst->data;
const size_t rs = ne00 * type_size;
if (nb00 == type_size) {
            // src0 is contiguous in the first dimension, copy by rows
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
id += rs * ir0;
for (int64_t i01 = ir0; i01 < ir1; i01++) {
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
memcpy(dst_ptr + id, src0_ptr, rs);
id += rs;
}
id += rs * (ne01 - ir1);
}
}
} else {
//printf("%s: this is not optimal - fix me\n", __func__);
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
id += rs * ir0;
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
memcpy(dst_ptr + id, src0_ptr, type_size);
id += type_size;
}
}
id += rs * (ne01 - ir1);
}
}
}
return;
}
// dst counters
int64_t i10 = 0;
int64_t i11 = 0;
int64_t i12 = 0;
int64_t i13 = 0;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
memcpy(dst_ptr, src0_ptr, type_size);
if (++i10 == ne0) {
i10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
}
static void ggml_v3_compute_forward_dup(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
if (src0->type == dst->type) {
ggml_v3_compute_forward_dup_bytes(params, src0, dst);
return;
}
switch (src0->type) {
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_dup_f16(params, src0, dst);
} break;
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_dup_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
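// Dispatch summary: when src and dst share a type, the raw byte copier above is used
// regardless of type (including quantized types); otherwise only F16 and F32 sources
// are handled, converting element by element, and quantized destinations are only
// supported when dst is contiguous (via the destination type's from_float handler).
// Anything else asserts.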
// ggml_v3_compute_forward_add
static void ggml_v3_compute_forward_add_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_can_repeat(src1, src0) && ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_v3_nrows(src0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
GGML_V3_ASSERT( nb0 == sizeof(float));
GGML_V3_ASSERT(nb00 == sizeof(float));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
if (nb10 == sizeof(float)) {
for (int ir = ir0; ir < ir1; ++ir) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
const int64_t nr0 = ne00 / ne10;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
for (int64_t r = 0; r < nr0; ++r) {
#ifdef GGML_USE_ACCELERATE
vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
#else
ggml_v3_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
#endif
}
}
} else {
// src1 is not contiguous
for (int ir = ir0; ir < ir1; ++ir) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int64_t i10 = i0 % ne10;
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
}
}
}
}
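// Note on broadcasting in the F32 add above: src1 may be smaller than src0. The
// outer indices are reduced modulo src1's dims (i13 = i03 % ne13, etc.), and within
// a row a src1 row of length ne10 is added nr0 = ne00/ne10 times across the src0
// row, so e.g. adding a per-row bias vector to a larger tensor works in a single op.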
static void ggml_v3_compute_forward_add_f16_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, src1) && ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_v3_nrows(src0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
GGML_V3_ASSERT(src0->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F32);
if (dst->type == GGML_V3_TYPE_F32) {
GGML_V3_ASSERT( nb0 == sizeof(float));
}
else {
GGML_V3_ASSERT(dst->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT( nb0 == sizeof(ggml_v3_fp16_t));
}
GGML_V3_ASSERT(nb00 == sizeof(ggml_v3_fp16_t));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
if (nb10 == sizeof(float)) {
if (dst->type == GGML_V3_TYPE_F16) {
for (int ir = ir0; ir < ir1; ++ir) {
// src0, src1 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
ggml_v3_fp16_t * dst_ptr = (ggml_v3_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
ggml_v3_fp16_t * src0_ptr = (ggml_v3_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
for (int i = 0; i < ne0; i++) {
dst_ptr[i] = GGML_V3_FP32_TO_FP16(GGML_V3_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
}
}
} else {
for (int ir = ir0; ir < ir1; ++ir) {
// src0, src1 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
ggml_v3_fp16_t * src0_ptr = (ggml_v3_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
for (int i = 0; i < ne0; i++) {
dst_ptr[i] = GGML_V3_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
}
}
}
}
else {
// src1 is not contiguous
GGML_V3_ASSERT(false);
}
}
static void ggml_v3_compute_forward_add_f16_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, src1) && ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_v3_nrows(src0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
GGML_V3_ASSERT(src0->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT(dst->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT( nb0 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nb00 == sizeof(ggml_v3_fp16_t));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
if (nb10 == sizeof(ggml_v3_fp16_t)) {
for (int ir = ir0; ir < ir1; ++ir) {
// src0, src1 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
ggml_v3_fp16_t * dst_ptr = (ggml_v3_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
ggml_v3_fp16_t * src0_ptr = (ggml_v3_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
ggml_v3_fp16_t * src1_ptr = (ggml_v3_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
for (int i = 0; i < ne0; i++) {
dst_ptr[i] = GGML_V3_FP32_TO_FP16(GGML_V3_FP16_TO_FP32(src0_ptr[i]) + GGML_V3_FP16_TO_FP32(src1_ptr[i]));
}
}
}
else {
// src1 is not contiguous
GGML_V3_ASSERT(false);
}
}
static void ggml_v3_compute_forward_add_q_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, src1) && ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int nr = ggml_v3_nrows(src0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const enum ggml_v3_type type = src0->type;
const enum ggml_v3_type dtype = dst->type;
ggml_v3_to_float_t const dequantize_row_q = type_traits[type].to_float;
ggml_v3_from_float_t const quantize_row_q = type_traits[dtype].from_float;
// we don't support permuted src0 or src1
GGML_V3_ASSERT(nb00 == ggml_v3_type_size(type));
GGML_V3_ASSERT(nb10 == sizeof(float));
// dst cannot be transposed or permuted
GGML_V3_ASSERT(nb0 <= nb1);
GGML_V3_ASSERT(nb1 <= nb2);
GGML_V3_ASSERT(nb2 <= nb3);
GGML_V3_ASSERT(ggml_v3_is_quantized(src0->type));
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F32);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
// src1 and dst are same shape as src0 => same indices
const int i13 = i03;
const int i12 = i02;
const int i11 = i01;
const int i3 = i03;
const int i2 = i02;
const int i1 = i01;
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
assert(ne00 % 32 == 0);
// unquantize row from src0 to temp buffer
dequantize_row_q(src0_row, wdata, ne00);
// add src1
ggml_v3_vec_acc_f32(ne00, wdata, src1_row);
// quantize row to dst
if (quantize_row_q != NULL) {
quantize_row_q(wdata, dst_row, ne00);
} else {
memcpy(dst_row, wdata, ne0*nb0);
}
}
}
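// Note on the quantized add above: each src0 row is dequantized into the per-thread
// scratch buffer wdata, the F32 src1 row is accumulated in place, and the row is
// re-quantized into dst with the destination type's from_float handler, or copied
// as raw F32 when dst has no from_float handler (e.g. an F32 destination).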
static void ggml_v3_compute_forward_add(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_add_f32(params, src0, src1, dst);
} break;
case GGML_V3_TYPE_F16:
{
if (src1->type == GGML_V3_TYPE_F16) {
ggml_v3_compute_forward_add_f16_f16(params, src0, src1, dst);
}
else if (src1->type == GGML_V3_TYPE_F32) {
ggml_v3_compute_forward_add_f16_f32(params, src0, src1, dst);
}
else {
GGML_V3_ASSERT(false);
}
} break;
case GGML_V3_TYPE_Q4_0:
case GGML_V3_TYPE_Q4_1:
case GGML_V3_TYPE_Q5_0:
case GGML_V3_TYPE_Q5_1:
case GGML_V3_TYPE_Q8_0:
case GGML_V3_TYPE_Q2_K:
case GGML_V3_TYPE_Q3_K:
case GGML_V3_TYPE_Q4_K:
case GGML_V3_TYPE_Q5_K:
case GGML_V3_TYPE_Q6_K:
case GGML_V3_TYPE_IQ2_XXS:
case GGML_V3_TYPE_IQ2_XS:
{
ggml_v3_compute_forward_add_q_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_add1
static void ggml_v3_compute_forward_add1_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
GGML_V3_ASSERT(ggml_v3_is_scalar(src1));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_v3_nrows(src0);
GGML_V3_TENSOR_UNARY_OP_LOCALS
GGML_V3_ASSERT( nb0 == sizeof(float));
GGML_V3_ASSERT(nb00 == sizeof(float));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
// src0 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
#ifdef GGML_USE_ACCELERATE
UNUSED(ggml_v3_vec_add1_f32);
vDSP_vadd(
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
(float *) ((char *) src1->data), 0,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
ne0);
#else
ggml_v3_vec_add1_f32(ne0,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
*(float *) src1->data);
#endif
}
}
static void ggml_v3_compute_forward_add1_f16_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
GGML_V3_ASSERT(ggml_v3_is_scalar(src1));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// scalar to add
const float v = *(float *) src1->data;
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_v3_nrows(src0);
GGML_V3_TENSOR_UNARY_OP_LOCALS
GGML_V3_ASSERT(src0->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F32);
GGML_V3_ASSERT(dst->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT( nb0 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nb00 == sizeof(ggml_v3_fp16_t));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
// src0 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
ggml_v3_fp16_t * dst_ptr = (ggml_v3_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
ggml_v3_fp16_t * src0_ptr = (ggml_v3_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
for (int i = 0; i < ne0; i++) {
dst_ptr[i] = GGML_V3_FP32_TO_FP16(GGML_V3_FP16_TO_FP32(src0_ptr[i]) + v);
}
}
}
static void ggml_v3_compute_forward_add1_f16_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
GGML_V3_ASSERT(ggml_v3_is_scalar(src1));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// scalar to add
const float v = GGML_V3_FP16_TO_FP32(*(ggml_v3_fp16_t *) src1->data);
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_v3_nrows(src0);
GGML_V3_TENSOR_UNARY_OP_LOCALS
GGML_V3_ASSERT(src0->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT(dst->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT( nb0 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nb00 == sizeof(ggml_v3_fp16_t));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
// src0 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
ggml_v3_fp16_t * dst_ptr = (ggml_v3_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
ggml_v3_fp16_t * src0_ptr = (ggml_v3_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
for (int i = 0; i < ne0; i++) {
dst_ptr[i] = GGML_V3_FP32_TO_FP16(GGML_V3_FP16_TO_FP32(src0_ptr[i]) + v);
}
}
}
static void ggml_v3_compute_forward_add1_q_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
GGML_V3_ASSERT(ggml_v3_is_scalar(src1));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// scalar to add
const float v = *(float *) src1->data;
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_v3_nrows(src0);
GGML_V3_TENSOR_UNARY_OP_LOCALS
const enum ggml_v3_type type = src0->type;
ggml_v3_to_float_t const dequantize_row_q = type_traits[type].to_float;
ggml_v3_from_float_t const quantize_row_q = type_traits[type].from_float;
// we don't support permuted src0
GGML_V3_ASSERT(nb00 == ggml_v3_type_size(type));
// dst cannot be transposed or permuted
GGML_V3_ASSERT(nb0 <= nb1);
GGML_V3_ASSERT(nb1 <= nb2);
GGML_V3_ASSERT(nb2 <= nb3);
GGML_V3_ASSERT(ggml_v3_is_quantized(src0->type));
GGML_V3_ASSERT(dst->type == src0->type);
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F32);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
for (int ir = ir0; ir < ir1; ++ir) {
// src0 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3 ));
assert(ne0 % 32 == 0);
// dequantize row from src0 to temp buffer
dequantize_row_q(src0_row, wdata, ne0);
// add src1
ggml_v3_vec_acc1_f32(ne0, wdata, v);
// quantize row to dst
quantize_row_q(wdata, dst_row, ne0);
}
}
static void ggml_v3_compute_forward_add1(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_add1_f32(params, src0, src1, dst);
} break;
case GGML_V3_TYPE_F16:
{
if (src1->type == GGML_V3_TYPE_F16) {
ggml_v3_compute_forward_add1_f16_f16(params, src0, src1, dst);
}
else if (src1->type == GGML_V3_TYPE_F32) {
ggml_v3_compute_forward_add1_f16_f32(params, src0, src1, dst);
}
else {
GGML_V3_ASSERT(false);
}
} break;
case GGML_V3_TYPE_Q4_0:
case GGML_V3_TYPE_Q4_1:
case GGML_V3_TYPE_Q5_0:
case GGML_V3_TYPE_Q5_1:
case GGML_V3_TYPE_Q8_0:
case GGML_V3_TYPE_Q8_1:
case GGML_V3_TYPE_Q2_K:
case GGML_V3_TYPE_Q3_K:
case GGML_V3_TYPE_Q4_K:
case GGML_V3_TYPE_Q5_K:
case GGML_V3_TYPE_Q6_K:
case GGML_V3_TYPE_IQ2_XXS:
case GGML_V3_TYPE_IQ2_XS:
{
ggml_v3_compute_forward_add1_q_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_acc
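// acc adds src1 into a strided view of src0/dst: op_params carries the view
// strides (nb1, nb2, nb3), the byte offset and an inplace flag. When not inplace,
// src0 is first copied into dst during the INIT phase so all threads see the
// same base data before accumulating.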
static void ggml_v3_compute_forward_acc_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
GGML_V3_ASSERT(ggml_v3_is_contiguous(dst) && ggml_v3_is_contiguous(src0));
// view src0 and dst with these strides and data offset in bytes during acc
// nb0 is implicitly element_size because src0 and dst are contiguous
size_t nb1 = ((int32_t *) dst->op_params)[0];
size_t nb2 = ((int32_t *) dst->op_params)[1];
size_t nb3 = ((int32_t *) dst->op_params)[2];
size_t offset = ((int32_t *) dst->op_params)[3];
bool inplace = (bool) ((int32_t *) dst->op_params)[4];
if (!inplace && (params->type == GGML_V3_TASK_INIT)) {
// memcpy needs to be synchronized across threads to avoid race conditions.
// => do it in INIT phase
memcpy(
((char *) dst->data),
((char *) src0->data),
ggml_v3_nbytes(dst));
}
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_v3_nrows(src1);
const int nc = src1->ne[0];
GGML_V3_TENSOR_LOCALS(int64_t, ne1, src1, ne)
GGML_V3_TENSOR_LOCALS(size_t, nb1, src1, nb)
// src0 and dst as viewed during acc
const size_t nb0 = ggml_v3_element_size(src0);
const size_t nb00 = nb0;
const size_t nb01 = nb1;
const size_t nb02 = nb2;
const size_t nb03 = nb3;
GGML_V3_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_v3_nbytes(dst));
GGML_V3_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_v3_nbytes(src0));
GGML_V3_ASSERT(nb10 == sizeof(float));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
// src0 and dst are viewed with shape of src1 and offset
// => same indices
const int i3 = ir/(ne12*ne11);
const int i2 = (ir - i3*ne12*ne11)/ne11;
const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
#ifdef GGML_USE_ACCELERATE
vDSP_vadd(
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc);
#else
ggml_v3_vec_add_f32(nc,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
#endif
}
}
static void ggml_v3_compute_forward_acc(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_acc_f32(params, src0, src1, dst);
} break;
case GGML_V3_TYPE_F16:
case GGML_V3_TYPE_Q4_0:
case GGML_V3_TYPE_Q4_1:
case GGML_V3_TYPE_Q5_0:
case GGML_V3_TYPE_Q5_1:
case GGML_V3_TYPE_Q8_0:
case GGML_V3_TYPE_Q8_1:
case GGML_V3_TYPE_Q2_K:
case GGML_V3_TYPE_Q3_K:
case GGML_V3_TYPE_Q4_K:
case GGML_V3_TYPE_Q5_K:
case GGML_V3_TYPE_Q6_K:
case GGML_V3_TYPE_IQ2_XXS:
case GGML_V3_TYPE_IQ2_XS:
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_sub
static void ggml_v3_compute_forward_sub_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, src1) && ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int nr = ggml_v3_nrows(src0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
GGML_V3_ASSERT( nb0 == sizeof(float));
GGML_V3_ASSERT(nb00 == sizeof(float));
if (nb10 == sizeof(float)) {
for (int ir = 0; ir < nr; ++ir) {
// src0, src1 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
#ifdef GGML_USE_ACCELERATE
vDSP_vsub(
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
ne0);
#else
ggml_v3_vec_sub_f32(ne0,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
#endif
}
} else {
// src1 is not contiguous
for (int ir = 0; ir < nr; ++ir) {
// src0, src1 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
for (int i0 = 0; i0 < ne0; i0++) {
float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
}
}
}
}
static void ggml_v3_compute_forward_sub(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_sub_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_mul
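// element-wise multiply with broadcasting: src1 only has to be repeatable over
// src0 (ggml_v3_can_repeat), so the i1/i2/i3 indices of src1 wrap around with a
// modulo and each src1 row is reused nr0 = ne00/ne10 times along dim 0.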
static void ggml_v3_compute_forward_mul_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_can_repeat(src1, src0) && ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
#ifdef GGML_USE_CLBLAST
if (src1->backend == GGML_V3_BACKEND_GPU) {
// TODO: the OpenCL kernel does not support full broadcast yet
GGML_V3_ASSERT(ggml_v3_can_repeat_rows(src1, src0));
if (ith == 0) {
ggml_v3_cl_mul(src0, src1, dst);
}
return;
}
#endif
const int64_t nr = ggml_v3_nrows(src0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
GGML_V3_ASSERT( nb0 == sizeof(float));
GGML_V3_ASSERT(nb00 == sizeof(float));
if (nb10 == sizeof(float)) {
for (int64_t ir = ith; ir < nr; ir += nth) {
// src0 and dst are same shape => same indices
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
const int64_t nr0 = ne00 / ne10;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
for (int64_t r = 0 ; r < nr0; ++r) {
#ifdef GGML_USE_ACCELERATE
UNUSED(ggml_v3_vec_mul_f32);
vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
#else
ggml_v3_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
#endif
}
}
} else {
// src1 is not contiguous
for (int64_t ir = ith; ir < nr; ir += nth) {
// src0 and dst are same shape => same indices
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
for (int64_t i0 = 0; i0 < ne00; ++i0) {
const int64_t i10 = i0 % ne10;
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
}
}
}
}
static void ggml_v3_compute_forward_mul(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F32 && "only f32 src1 supported for now");
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_mul_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_div
static void ggml_v3_compute_forward_div_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_can_repeat(src1, src0) && ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ggml_v3_nrows(src0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
GGML_V3_ASSERT( nb0 == sizeof(float));
GGML_V3_ASSERT(nb00 == sizeof(float));
if (nb10 == sizeof(float)) {
for (int64_t ir = ith; ir < nr; ir += nth) {
// src0 and dst are same shape => same indices
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
const int64_t nr0 = ne00 / ne10;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
for (int64_t r = 0; r < nr0; ++r) {
#ifdef GGML_USE_ACCELERATE
UNUSED(ggml_v3_vec_div_f32);
vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
#else
ggml_v3_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
#endif
}
}
} else {
// src1 is not contiguous
for (int64_t ir = ith; ir < nr; ir += nth) {
// src0 and dst are same shape => same indices
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
for (int64_t i0 = 0; i0 < ne00; ++i0) {
const int64_t i10 = i0 % ne10;
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
}
}
}
}
static void ggml_v3_compute_forward_div(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_div_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_sqr
static void ggml_v3_compute_forward_sqr_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
assert( dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_sqr_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_v3_compute_forward_sqr(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_sqr_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_sqrt
static void ggml_v3_compute_forward_sqrt_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
assert( dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_sqrt_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_v3_compute_forward_sqrt(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_sqrt_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_log
static void ggml_v3_compute_forward_log_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(params->ith == 0);
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
GGML_V3_ASSERT( dst->nb[0] == sizeof(float));
GGML_V3_ASSERT(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_log_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_v3_compute_forward_log(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_log_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_sum
static void ggml_v3_compute_forward_sum_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_is_scalar(dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
assert(ggml_v3_is_scalar(dst));
assert(src0->nb[0] == sizeof(float));
GGML_V3_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_V3_TENSOR_LOCALS(size_t, nb0, src0, nb)
ggml_v3_float sum = 0;
ggml_v3_float row_sum = 0;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
ggml_v3_vec_sum_f32_ggf(ne00,
&row_sum,
(float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
sum += row_sum;
}
}
}
((float *) dst->data)[0] = sum;
}
static void ggml_v3_compute_forward_sum_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_is_scalar(dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
assert(src0->nb[0] == sizeof(ggml_v3_fp16_t));
GGML_V3_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_V3_TENSOR_LOCALS(size_t, nb0, src0, nb)
float sum = 0;
float row_sum = 0;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
ggml_v3_vec_sum_f16_ggf(ne00,
&row_sum,
(ggml_v3_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
sum += row_sum;
}
}
}
((ggml_v3_fp16_t *) dst->data)[0] = GGML_V3_FP32_TO_FP16(sum);
}
static void ggml_v3_compute_forward_sum(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_sum_f32(params, src0, dst);
} break;
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_sum_f16(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_sum_rows
static void ggml_v3_compute_forward_sum_rows_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_ASSERT(src0->nb[0] == sizeof(float));
GGML_V3_ASSERT(dst->nb[0] == sizeof(float));
GGML_V3_TENSOR_UNARY_OP_LOCALS
GGML_V3_ASSERT(ne0 == 1);
GGML_V3_ASSERT(ne1 == ne01);
GGML_V3_ASSERT(ne2 == ne02);
GGML_V3_ASSERT(ne3 == ne03);
for (int64_t i3 = 0; i3 < ne03; i3++) {
for (int64_t i2 = 0; i2 < ne02; i2++) {
for (int64_t i1 = 0; i1 < ne01; i1++) {
float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
ggml_v3_vec_sum_f32(ne00, &row_sum, src_row);
dst_row[0] = row_sum;
}
}
}
}
static void ggml_v3_compute_forward_sum_rows(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_sum_rows_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_mean
static void ggml_v3_compute_forward_mean_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
assert(src0->nb[0] == sizeof(float));
GGML_V3_TENSOR_UNARY_OP_LOCALS
assert(ne0 == 1);
assert(ne1 == ne01);
assert(ne2 == ne02);
assert(ne3 == ne03);
UNUSED(ne0);
UNUSED(ne1);
UNUSED(ne2);
UNUSED(ne3);
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
ggml_v3_vec_sum_f32(ne00,
(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
(float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
*(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
}
}
}
}
static void ggml_v3_compute_forward_mean(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_mean_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_argmax
static void ggml_v3_compute_forward_argmax_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
assert(src0->nb[0] == sizeof(float));
assert(dst->nb[0] == sizeof(float));
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const size_t nb01 = src0->nb[1];
const size_t nb0 = dst->nb[0];
for (int64_t i1 = 0; i1 < ne01; i1++) {
float * src = (float *) ((char *) src0->data + i1*nb01);
int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0);
int v = 0;
ggml_v3_vec_argmax_f32(ne00, &v, src);
dst_[0] = v;
}
}
static void ggml_v3_compute_forward_argmax(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_argmax_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_repeat
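// repeat tiles src0 across dst: nrX = neX/ne0X repetitions along each dimension,
// guaranteed to be whole numbers by ggml_v3_can_repeat. For example, repeating a
// src0 of shape {2,3,1,1} into a dst of shape {4,6,1,1} gives nr = {2,2,1,1}.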
static void ggml_v3_compute_forward_repeat_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(params->ith == 0);
GGML_V3_ASSERT(ggml_v3_can_repeat(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_TENSOR_UNARY_OP_LOCALS
// guaranteed to be an integer due to the check in ggml_v3_can_repeat
const int nr0 = (int)(ne0/ne00);
const int nr1 = (int)(ne1/ne01);
const int nr2 = (int)(ne2/ne02);
const int nr3 = (int)(ne3/ne03);
// TODO: support for transposed / permuted tensors
GGML_V3_ASSERT(nb0 == sizeof(float));
GGML_V3_ASSERT(nb00 == sizeof(float));
// TODO: maybe this is not optimal?
for (int i3 = 0; i3 < nr3; i3++) {
for (int k3 = 0; k3 < ne03; k3++) {
for (int i2 = 0; i2 < nr2; i2++) {
for (int k2 = 0; k2 < ne02; k2++) {
for (int i1 = 0; i1 < nr1; i1++) {
for (int k1 = 0; k1 < ne01; k1++) {
for (int i0 = 0; i0 < nr0; i0++) {
ggml_v3_vec_cpy_f32(ne00,
(float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
(float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01));
}
}
}
}
}
}
}
}
static void ggml_v3_compute_forward_repeat_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(params->ith == 0);
GGML_V3_ASSERT(ggml_v3_can_repeat(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_TENSOR_UNARY_OP_LOCALS
// guaranteed to be an integer due to the check in ggml_v3_can_repeat
const int nr0 = (int)(ne0/ne00);
const int nr1 = (int)(ne1/ne01);
const int nr2 = (int)(ne2/ne02);
const int nr3 = (int)(ne3/ne03);
// TODO: support for transposed / permuted tensors
GGML_V3_ASSERT(nb0 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nb00 == sizeof(ggml_v3_fp16_t));
// TODO: maybe this is not optimal?
for (int i3 = 0; i3 < nr3; i3++) {
for (int k3 = 0; k3 < ne03; k3++) {
for (int i2 = 0; i2 < nr2; i2++) {
for (int k2 = 0; k2 < ne02; k2++) {
for (int i1 = 0; i1 < nr1; i1++) {
for (int k1 = 0; k1 < ne01; k1++) {
for (int i0 = 0; i0 < nr0; i0++) {
ggml_v3_fp16_t * y = (ggml_v3_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
ggml_v3_fp16_t * x = (ggml_v3_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
// ggml_v3_vec_cpy_f16(ne00, y, x)
for (int i = 0; i < ne00; ++i) {
y[i] = x[i];
}
}
}
}
}
}
}
}
}
static void ggml_v3_compute_forward_repeat(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F16:
case GGML_V3_TYPE_I16:
{
ggml_v3_compute_forward_repeat_f16(params, src0, dst);
} break;
case GGML_V3_TYPE_F32:
case GGML_V3_TYPE_I32:
{
ggml_v3_compute_forward_repeat_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_repeat_back
static void ggml_v3_compute_forward_repeat_back_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(params->ith == 0);
GGML_V3_ASSERT(ggml_v3_can_repeat(dst, src0));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_TENSOR_UNARY_OP_LOCALS
// guaranteed to be an integer due to the check in ggml_v3_can_repeat
const int nr0 = (int)(ne00/ne0);
const int nr1 = (int)(ne01/ne1);
const int nr2 = (int)(ne02/ne2);
const int nr3 = (int)(ne03/ne3);
// TODO: support for transposed / permuted tensors
GGML_V3_ASSERT(nb0 == sizeof(float));
GGML_V3_ASSERT(nb00 == sizeof(float));
if (ggml_v3_is_contiguous(dst)) {
ggml_v3_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
} else {
for (int k3 = 0; k3 < ne3; k3++) {
for (int k2 = 0; k2 < ne2; k2++) {
for (int k1 = 0; k1 < ne1; k1++) {
ggml_v3_vec_set_f32(ne0,
(float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
0);
}
}
}
}
// TODO: maybe this is not optimal?
for (int i3 = 0; i3 < nr3; i3++) {
for (int k3 = 0; k3 < ne3; k3++) {
for (int i2 = 0; i2 < nr2; i2++) {
for (int k2 = 0; k2 < ne2; k2++) {
for (int i1 = 0; i1 < nr1; i1++) {
for (int k1 = 0; k1 < ne1; k1++) {
for (int i0 = 0; i0 < nr0; i0++) {
ggml_v3_vec_acc_f32(ne0,
(float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
(float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
}
}
}
}
}
}
}
}
static void ggml_v3_compute_forward_repeat_back(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_repeat_back_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_concat
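// concat joins src0 and src1 along dim 2: dst slices with i2 < ne02 are copied
// from src0 and the remaining slices from src1 (indexed at i2 - ne02); the work
// is split across threads over i2.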
static void ggml_v3_compute_forward_concat_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_V3_TENSOR_BINARY_OP_LOCALS
// TODO: support for transposed / permuted tensors
GGML_V3_ASSERT(nb0 == sizeof(float));
GGML_V3_ASSERT(nb00 == sizeof(float));
GGML_V3_ASSERT(nb10 == sizeof(float));
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = ith; i2 < ne2; i2 += nth) {
if (i2 < ne02) { // src0
for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) {
const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
*y = *x;
}
}
} else { // src1
for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) {
const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
*y = *x;
}
}
}
}
}
}
static void ggml_v3_compute_forward_concat(
const struct ggml_v3_compute_params* params,
const struct ggml_v3_tensor* src0,
const struct ggml_v3_tensor* src1,
struct ggml_v3_tensor* dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
case GGML_V3_TYPE_I32:
{
ggml_v3_compute_forward_concat_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_abs
static void ggml_v3_compute_forward_abs_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_abs_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_v3_compute_forward_abs(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_abs_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_sgn
static void ggml_v3_compute_forward_sgn_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_sgn_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_v3_compute_forward_sgn(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_sgn_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_neg
static void ggml_v3_compute_forward_neg_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_neg_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_v3_compute_forward_neg(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_neg_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_step
static void ggml_v3_compute_forward_step_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_step_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_v3_compute_forward_step(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_step_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_tanh
static void ggml_v3_compute_forward_tanh_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_tanh_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_v3_compute_forward_tanh(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_tanh_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_elu
static void ggml_v3_compute_forward_elu_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_elu_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_v3_compute_forward_elu(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_elu_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_relu
static void ggml_v3_compute_forward_relu_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_relu_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_v3_compute_forward_relu(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_relu_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_gelu
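// gelu splits the rows across threads like the binary ops above; debug builds
// additionally assert that every output value is finite (no NaN/Inf).
// gelu_quick and silu below follow the same pattern.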
static void ggml_v3_compute_forward_gelu_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_is_contiguous_except_dim_1(src0));
GGML_V3_ASSERT(ggml_v3_is_contiguous_except_dim_1(dst));
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_v3_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_v3_vec_gelu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_v3_compute_forward_gelu(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_gelu_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_gelu_quick
static void ggml_v3_compute_forward_gelu_quick_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_is_contiguous_except_dim_1(src0));
GGML_V3_ASSERT(ggml_v3_is_contiguous_except_dim_1(dst));
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_v3_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_v3_vec_gelu_quick_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_v3_compute_forward_gelu_quick(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_gelu_quick_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_silu
static void ggml_v3_compute_forward_silu_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_is_contiguous_except_dim_1(src0));
GGML_V3_ASSERT(ggml_v3_is_contiguous_except_dim_1(dst));
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_v3_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_v3_vec_silu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_v3_compute_forward_silu(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_silu_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_leaky_relu
static void ggml_v3_compute_forward_leaky_relu_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
float negative_slope;
memcpy(&negative_slope, dst->op_params, sizeof(float));
assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_v3_vec_leaky_relu_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
}
}
static void ggml_v3_compute_forward_leaky_relu(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_leaky_relu_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_silu_back
static void ggml_v3_compute_forward_silu_back_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * grad,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_is_contiguous_except_dim_1(grad));
GGML_V3_ASSERT(ggml_v3_is_contiguous_except_dim_1(src0));
GGML_V3_ASSERT(ggml_v3_is_contiguous_except_dim_1(dst));
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, grad));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_v3_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_v3_vec_silu_backward_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])),
(float *) ((char *) grad->data + i1*(grad->nb[1])));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_v3_compute_forward_silu_back(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * grad,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_silu_back_f32(params, src0, grad, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_norm
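// standard layer normalization over the innermost dimension (ne00), without
// affine parameters: y = (x - mean(x)) / sqrt(var(x) + eps), with eps read
// from op_params.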
static void ggml_v3_compute_forward_norm_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_V3_TENSOR_UNARY_OP_LOCALS
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_V3_ASSERT(eps > 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_v3_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_v3_float)x[i00];
}
float mean = sum/ne00;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
ggml_v3_float sum2 = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
float v = x[i00] - mean;
y[i00] = v;
sum2 += (ggml_v3_float)(v*v);
}
float variance = sum2/ne00;
const float scale = 1.0f/sqrtf(variance + eps);
ggml_v3_vec_scale_f32(ne00, y, scale);
}
}
}
}
static void ggml_v3_compute_forward_norm(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_norm_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_rms_norm
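// rms_norm scales each row by 1/sqrt(mean(x^2) + eps) without subtracting the
// mean, i.e. y = x / sqrt(sum(x^2)/ne00 + eps).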
static void ggml_v3_compute_forward_rms_norm_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_V3_TENSOR_UNARY_OP_LOCALS
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_V3_ASSERT(eps > 0.0f);
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_v3_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_v3_float)(x[i00] * x[i00]);
}
const float mean = sum/ne00;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
// for (int i00 = 0; i00 < ne00; i00++) {
// y[i00] = x[i00];
// }
const float scale = 1.0f/sqrtf(mean + eps);
ggml_v3_vec_scale_f32(ne00, y, scale);
}
}
}
}
static void ggml_v3_compute_forward_rms_norm(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_rms_norm_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
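// ggml_v3_compute_forward_rms_norm_back
// computes the gradient of rms_norm w.r.t. its input: per row,
// dx = (dz + x * (-sum_xdz / sum_eps)) * rrms, as derived in the long comment
// inside the function below.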
static void ggml_v3_compute_forward_rms_norm_back_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst) && ggml_v3_are_same_shape(src0, src1));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_V3_TENSOR_BINARY_OP_LOCALS
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
// src1 is same shape as src0 => same indices
const int64_t i11 = i01;
const int64_t i12 = i02;
const int64_t i13 = i03;
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
ggml_v3_float sum_xx = 0.0;
ggml_v3_float sum_xdz = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum_xx += (ggml_v3_float)(x[i00] * x[i00]);
sum_xdz += (ggml_v3_float)(x[i00] * dz[i00]);
}
//const float mean = (float)(sum_xx)/ne00;
const float mean_eps = (float)(sum_xx)/ne00 + eps;
const float sum_eps = (float)(sum_xx) + eps*ne00;
//const float mean_xdz = (float)(sum_xdz)/ne00;
// we could cache rms from forward pass to improve performance.
// to do this implement ggml_v3_rms and compose ggml_v3_rms_norm using ggml_v3_rms.
//const float rms = sqrtf(mean_eps);
const float rrms = 1.0f / sqrtf(mean_eps);
//const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
{
// z = rms_norm(x)
//
// rms_norm(src0) =
// scale(
// src0,
// div(
// 1,
// sqrt(
// add(
// scale(
// sum(
// sqr(
// src0)),
// (1.0/N)),
// eps))));
// postorder:
// ## op args grad
// 00 param src0 grad[#00]
// 01 const 1
// 02 sqr (#00) grad[#02]
// 03 sum (#02) grad[#03]
// 04 const 1/N
// 05 scale (#03, #04) grad[#05]
// 06 const eps
// 07 add (#05, #06) grad[#07]
// 08 sqrt (#07) grad[#08]
// 09 div (#01,#08) grad[#09]
// 10 scale (#00,#09) grad[#10]
//
// backward pass, given grad[#10]
// #10: scale
// grad[#00] += scale(grad[#10],#09)
// grad[#09] += sum(mul(grad[#10],#00))
// #09: div
// grad[#08] += neg(mul(grad[#09], div(#09,#08)))
// #08: sqrt
// grad[#07] += mul(grad[#08], div(0.5, #08))
// #07: add
// grad[#05] += grad[#07]
// #05: scale
// grad[#03] += scale(grad[#05],#04)
// #03: sum
// grad[#02] += repeat(grad[#03], #02)
// #02: sqr
// grad[#00] += scale(mul(#00, grad[#02]), 2.0)
//
// substitute and simplify:
// grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
// grad[#02] = repeat(grad[#03], #02)
// grad[#02] = repeat(scale(grad[#05],#04), #02)
// grad[#02] = repeat(scale(grad[#07],#04), #02)
// grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
// grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
// grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
// grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
// grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
// grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
// grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
// grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
// grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
// grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
// grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
// a = b*c + d*e
// a = b*c*f/f + d*e*f/f
// a = (b*c*f + d*e*f)*(1/f)
// a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
// a = (b + d*e/c)*c
// b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
// a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
// a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
// a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
// a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
// a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
// a = (dz + x*div(-mean_xdz,mean_eps))*rrms
// grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
// grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
// dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
}
// dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
// post-order:
// dx := x
// dx := scale(dx,-mean_xdz/mean_eps)
// dx := add(dx, dz)
// dx := scale(dx, rrms)
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
ggml_v3_vec_cpy_f32 (ne00, dx, x);
// ggml_v3_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
ggml_v3_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
ggml_v3_vec_acc_f32 (ne00, dx, dz);
ggml_v3_vec_scale_f32(ne00, dx, rrms);
}
}
}
}
static void ggml_v3_compute_forward_rms_norm_back(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_rms_norm_back_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_group_norm
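// group normalization: the channels (dim 2) are split into n_groups
// (op_params[0]); mean and variance are taken over all ne00*ne01 elements of the
// channels belonging to a group, and each group is normalized with a fixed
// eps of 1e-6.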
static void ggml_v3_compute_forward_group_norm_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_V3_TENSOR_UNARY_OP_LOCALS
const float eps = 1e-6f; // TODO: make this a parameter
// TODO: optimize
int n_channels = src0->ne[2];
int n_groups = dst->op_params[0];
int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
for (int i = ith; i < n_groups; i+=nth) {
int start = i * n_channels_per_group;
int end = start + n_channels_per_group;
if (end > n_channels) {
end = n_channels;
}
int step = end - start;
for (int64_t i03 = 0; i03 < ne03; i03++) {
ggml_v3_float sum = 0.0;
for (int64_t i02 = start; i02 < end; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_v3_float)x[i00];
}
}
}
float mean = sum / (ne00 * ne01 * step);
ggml_v3_float sum2 = 0.0;
for (int64_t i02 = start; i02 < end; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
for (int64_t i00 = 0; i00 < ne00; i00++) {
float v = x[i00] - mean;
y[i00] = v;
sum2 += (ggml_v3_float)(v * v);
}
}
}
float variance = sum2 / (ne00 * ne01 * step);
const float scale = 1.0f / sqrtf(variance + eps);
for (int64_t i02 = start; i02 < end; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
ggml_v3_vec_scale_f32(ne00, y, scale);
}
}
}
}
}
static void ggml_v3_compute_forward_group_norm(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_group_norm_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_mul_mat
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
// helper function to determine if it is better to use BLAS or not
// for large matrices, BLAS is faster
static bool ggml_v3_compute_forward_mul_mat_use_blas(struct ggml_v3_tensor * dst) {
const struct ggml_v3_tensor * src0 = dst->src[0];
const struct ggml_v3_tensor * src1 = dst->src[1];
//const int64_t ne00 = src0->ne[0];
//const int64_t ne01 = src0->ne[1];
const int64_t ne10 = src1->ne[0];
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
// NOTE: with GGML_V3_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
// all the experts for each batch element and the processing would become incredibly slow
// TODO: find the optimal values for these
if (dst->op != GGML_V3_OP_MUL_MAT_ID &&
ggml_v3_is_contiguous(src0) &&
ggml_v3_is_contiguous(src1) &&
//src0->type == GGML_V3_TYPE_F32 &&
src1->type == GGML_V3_TYPE_F32 &&
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
/*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
return true;
}
return false;
}
#endif
static void ggml_v3_compute_forward_mul_mat(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
int64_t t0 = ggml_v3_perf_time_us();
UNUSED(t0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
if (ith == 1 && g_imatrix_collect_v3) {
g_imatrix_collect_v3(src0, src1);
}
const enum ggml_v3_type type = src0->type;
const bool src1_cont = ggml_v3_is_contiguous(src1);
ggml_v3_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_v3_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_v3_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
GGML_V3_ASSERT(ne0 == ne01);
GGML_V3_ASSERT(ne1 == ne11);
GGML_V3_ASSERT(ne2 == ne12);
GGML_V3_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1
GGML_V3_ASSERT(nb00 == ggml_v3_type_size(type));
GGML_V3_ASSERT(nb10 == ggml_v3_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_V3_ASSERT(nb0 == sizeof(float));
GGML_V3_ASSERT(nb0 <= nb1);
GGML_V3_ASSERT(nb1 <= nb2);
GGML_V3_ASSERT(nb2 <= nb3);
// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
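// illustrative example: if src0 has ne02 = 1, ne03 = 1 and src1 has ne12 = 8, ne13 = 2, then
// r2 = 8 and r3 = 2, i.e. the single src0 matrix is reused for every (i12, i13) slice of src1
// through the i02 = i12/r2 and i03 = i13/r3 mappings used below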
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
#if defined(GGML_USE_CLBLAST)
if (ggml_v3_cl_can_mul_mat(src0, src1, dst)) {
if (params->ith == 0 && params->type == GGML_V3_TASK_COMPUTE) {
ggml_v3_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
}
return;
}
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_v3_compute_forward_mul_mat_use_blas(dst)) {
if (params->ith != 0) {
return;
}
if (params->type == GGML_V3_TASK_INIT) {
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) {
// broadcast src0 into src1 across 2nd,3rd dimension
const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2;
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
if (type != GGML_V3_TYPE_F32) {
float * const wdata = params->wdata;
ggml_v3_to_float_t const to_float = type_traits[type].to_float;
size_t id = 0;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
to_float((const char *) x + i01*nb01, wdata + id, ne00);
id += ne00;
}
assert(id*sizeof(float) <= params->wsize);
x = wdata;
}
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne1, ne01, ne10,
1.0f, y, ne10,
x, ne00,
0.0f, d, ne01);
}
}
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_v3_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
return;
}
#endif
if (params->type == GGML_V3_TASK_INIT) {
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
const size_t row_size = ggml_v3_row_size(vec_dot_type, ne10);
assert(params->wsize >= ne11*ne12*ne13*row_size);
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F32);
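// note (illustrative): the loop below packs ne11*ne12*ne13 converted rows, each row_size bytes,
// into params->wdata with i11 varying fastest - the same (i11 + i12*ne11 + i13*ne12*ne11) order
// the compute phase later uses to index it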
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
wdata += row_size;
}
}
}
}
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_v3_row_size(vec_dot_type, ne10);
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = ne1*ne12*ne13; // src1 rows
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
// distribute the thread work across the inner or outer loop based on which one is larger
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
const int64_t ith0 = ith % nth0;
const int64_t ith1 = ith / nth0;
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
const int64_t ir010 = dr0*ith0;
const int64_t ir011 = MIN(ir010 + dr0, nr0);
const int64_t ir110 = dr1*ith1;
const int64_t ir111 = MIN(ir110 + dr1, nr1);
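// worked example (illustrative): with nr0 = 4096, nr1 = 1 and nth = 8, nr0 > nr1 so nth0 = 8,
// nth1 = 1, dr0 = 512 and dr1 = 1; thread ith = 3 then gets ith0 = 3, ith1 = 0 and processes
// src0 rows [ir010, ir011) = [1536, 2048) against the single src1 row [ir110, ir111) = [0, 1)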
//printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
// threads with no work simply yield (not sure if it helps)
if (ir010 >= ir011 || ir110 >= ir111) {
sched_yield();
return;
}
assert(ne12 % ne02 == 0);
assert(ne13 % ne03 == 0);
// block-tiling attempt
const int64_t blck_0 = 16;
const int64_t blck_1 = 16;
// attempt to reduce false-sharing (does not seem to make a difference)
float tmp[16];
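// note: tmp must be able to hold blck_0 results, since the inner loop below writes up to
// blck_0 = 16 vec_dot outputs per block before the single memcpy into dst_col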
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
const int64_t i13 = (ir1/(ne12*ne1));
const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
// broadcast src0 into src1
const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2;
const int64_t i1 = i11;
const int64_t i2 = i12;
const int64_t i3 = i13;
const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
// desc: when src1 is not contiguous and has not been converted to vec_dot_type, the offset has to be
// computed from its strides; otherwise we are reading either the original contiguous src1 data
// or the row-packed copy in params->wdata, so we can index it directly by row
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
: (i11*nb11 + i12*nb12 + i13*nb13));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
}
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
}
}
}
}
// ggml_v3_compute_forward_mul_mat_id
static void ggml_v3_compute_forward_mul_mat_id(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * ids,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
const struct ggml_v3_tensor * src0 = dst->src[2]; // only for GGML_V3_TENSOR_BINARY_OP_LOCALS
GGML_V3_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const enum ggml_v3_type type = src0->type;
const bool src1_cont = ggml_v3_is_contiguous(src1);
ggml_v3_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_v3_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_v3_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
GGML_V3_ASSERT(ne0 == ne01);
GGML_V3_ASSERT(ne1 == ne11);
GGML_V3_ASSERT(ne2 == ne12);
GGML_V3_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1
GGML_V3_ASSERT(nb00 == ggml_v3_type_size(type));
GGML_V3_ASSERT(nb10 == ggml_v3_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_V3_ASSERT(nb0 == sizeof(float));
GGML_V3_ASSERT(nb0 <= nb1);
GGML_V3_ASSERT(nb1 <= nb2);
GGML_V3_ASSERT(nb2 <= nb3);
// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
// row groups
const int id = ggml_v3_get_op_params_i32(dst, 0);
const int n_as = ggml_v3_get_op_params_i32(dst, 1);
char * wdata_src1_end = (src1->type == vec_dot_type) ?
(char *) params->wdata :
(char *) params->wdata + GGML_V3_PAD(ggml_v3_row_size(vec_dot_type, ggml_v3_nelements(src1)), sizeof(int64_t));
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
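// illustrative example: with n_as = 4 experts and ne11 = 3 src1 rows routed by ids as
// {row 0 -> expert 2, row 1 -> expert 0, row 2 -> expert 2}, the INIT phase below yields
// matrix_row_counts = {1, 0, 2, 0}, MMID_MATRIX_ROW(0, 0) = 1, MMID_MATRIX_ROW(2, 0) = 0
// and MMID_MATRIX_ROW(2, 1) = 2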
if (params->type == GGML_V3_TASK_INIT) {
char * wdata = params->wdata;
if (src1->type != vec_dot_type) {
const size_t row_size = ggml_v3_row_size(vec_dot_type, ne10);
assert(params->wsize >= ne11*ne12*ne13*row_size);
assert(src1->type == GGML_V3_TYPE_F32);
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
wdata += row_size;
}
}
}
}
// initialize matrix_row_counts
GGML_V3_ASSERT(wdata == wdata_src1_end);
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
// group rows by src0 matrix
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
GGML_V3_ASSERT(row_id >= 0 && row_id < n_as);
MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
matrix_row_counts[row_id] += 1;
}
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// compute each matrix multiplication in sequence
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
const int64_t cne1 = matrix_row_counts[cur_a];
if (cne1 == 0) {
continue;
}
const struct ggml_v3_tensor * src0_cur = dst->src[cur_a + 2];
if (ith == 1 && g_imatrix_collect_v3) {
g_imatrix_collect_v3(src0_cur, src1);
}
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_v3_row_size(vec_dot_type, ne10);
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1*ne12*ne13; // src1 rows
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
// distribute the thread work across the inner or outer loop based on which one is larger
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
const int64_t ith0 = ith % nth0;
const int64_t ith1 = ith / nth0;
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
const int64_t ir010 = dr0*ith0;
const int64_t ir011 = MIN(ir010 + dr0, nr0);
const int64_t ir110 = dr1*ith1;
const int64_t ir111 = MIN(ir110 + dr1, nr1);
//printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
// threads with no work simply yield (not sure if it helps)
if (ir010 >= ir011 || ir110 >= ir111) {
sched_yield();
continue;
}
assert(ne12 % ne02 == 0);
assert(ne13 % ne03 == 0);
// block-tiling attempt
const int64_t blck_0 = 16;
const int64_t blck_1 = 16;
// attempt to reduce false-sharing (does not seem to make a difference)
float tmp[16];
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
// broadcast src0 into src1
const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2;
const int64_t i1 = i11;
const int64_t i2 = i12;
const int64_t i3 = i13;
const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
// desc: when src1 is not contiguous and has not been converted to vec_dot_type, the offset has to be
// computed from its strides; otherwise we are reading either the original contiguous src1 data
// or the row-packed copy in params->wdata, so we can index it directly by row
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
: (i11*nb11 + i12*nb12 + i13*nb13));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
}
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
}
}
}
}
#undef MMID_MATRIX_ROW
}
// ggml_v3_compute_forward_out_prod
static void ggml_v3_compute_forward_out_prod_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
// int64_t t0 = ggml_v3_perf_time_us();
// UNUSED(t0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
GGML_V3_ASSERT(ne0 == ne00);
GGML_V3_ASSERT(ne1 == ne10);
GGML_V3_ASSERT(ne2 == ne02);
GGML_V3_ASSERT(ne02 == ne12);
GGML_V3_ASSERT(ne3 == ne13);
GGML_V3_ASSERT(ne03 == ne13);
// we don't support permuted src0 or src1
GGML_V3_ASSERT(nb00 == sizeof(float));
// dst cannot be transposed or permuted
GGML_V3_ASSERT(nb0 == sizeof(float));
// GGML_V3_ASSERT(nb0 <= nb1);
// GGML_V3_ASSERT(nb1 <= nb2);
// GGML_V3_ASSERT(nb2 <= nb3);
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
// TODO: #if defined(GGML_USE_CUDA) ggml_v3_cuda_out_prod
// TODO: #if defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
bool use_blas = ggml_v3_is_matrix(src0) &&
ggml_v3_is_matrix(src1) &&
ggml_v3_is_contiguous(src0) &&
(ggml_v3_is_contiguous(src1) || ggml_v3_is_transposed(src1));
#endif
if (params->type == GGML_V3_TASK_INIT) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
if (use_blas) {
return;
}
#endif
ggml_v3_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (use_blas) {
if (params->ith != 0) { // All threads other than the first do no work.
return;
}
// Arguments to ggml_v3_compute_forward_out_prod (expressed as major,minor)
// src0: (k,n)
// src1: (k,m)
// dst: (m,n)
//
// Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
// Also expressed as (major,minor)
// a: (m,k): so src1 transposed
// b: (k,n): so src0
// c: (m,n)
//
// However, if ggml_v3_is_transposed(src1) is true, then
// src1->data already contains a transposed version, so sgemm mustn't
// transpose it further.
int n = src0->ne[0];
int k = src0->ne[1];
int m = src1->ne[0];
int transposeA, lda;
if (!ggml_v3_is_transposed(src1)) {
transposeA = CblasTrans;
lda = m;
} else {
transposeA = CblasNoTrans;
lda = k;
}
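// worked example (illustrative): src0 ne = {64, 128}, src1 ne = {32, 128}, dst ne = {64, 32}
// gives n = 64, k = 128, m = 32; src1 is stored as a row-major (k x m) block, so passing
// transposeA = CblasTrans with lda = m makes sgemm compute the (m x n) product c = src1^T * src0,
// i.e. dst[i0,i1] = sum over i01 of src0[i0,i01] * src1[i1,i01]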
float * a = (float *) ((char *) src1->data);
float * b = (float *) ((char *) src0->data);
float * c = (float *) ((char *) dst->data);
cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
return;
}
#endif
// dst[:,:,:,:] = 0
// for i2,i3:
// for i1:
// for i01:
// for i0:
// dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
// parallelize by last three dimensions
// total rows in dst
const int64_t nr = ne1*ne2*ne3;
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
// row range for this thread
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
// block-tiling attempt
const int64_t blck_0 = MAX(GGML_V3_VEC_MAD_UNROLL, 32);
const int64_t blck_1 = 16;
for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
const int64_t bir1 = MIN(bir + blck_1, ir1);
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
const int64_t bne01 = MIN(bi01 + blck_0, ne01);
for (int64_t ir = bir; ir < bir1; ++ir) {
// dst indices
const int64_t i3 = ir/(ne2*ne1);
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int64_t i02 = i2;
const int64_t i03 = i3;
//const int64_t i10 = i1;
const int64_t i12 = i2;
const int64_t i13 = i3;
#if GGML_V3_VEC_MAD_UNROLL > 2
const int64_t bne01_unroll = bne01 - (bne01 % GGML_V3_VEC_MAD_UNROLL);
for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_V3_VEC_MAD_UNROLL) {
const int64_t i11 = i01;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
ggml_v3_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
}
for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
const int64_t i11 = i01;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
ggml_v3_vec_mad_f32(ne0, d, s0, *s1);
}
#else
for (int64_t i01 = bi01; i01 < bne01; ++i01) {
const int64_t i11 = i01;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
ggml_v3_vec_mad_f32(ne0, d, s0, *s1);
}
#endif
}
}
}
//int64_t t1 = ggml_v3_perf_time_us();
//static int64_t acc = 0;
//acc += t1 - t0;
//if (t1 - t0 > 10) {
// printf("\n");
// printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
// printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
// printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
// printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
// printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
//}
}
static void ggml_v3_compute_forward_out_prod_q_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
// int64_t t0 = ggml_v3_perf_time_us();
// UNUSED(t0);
GGML_V3_TENSOR_BINARY_OP_LOCALS;
const int ith = params->ith;
const int nth = params->nth;
const enum ggml_v3_type type = src0->type;
ggml_v3_to_float_t const dequantize_row_q = type_traits[type].to_float;
GGML_V3_ASSERT(ne02 == ne12);
GGML_V3_ASSERT(ne03 == ne13);
GGML_V3_ASSERT(ne2 == ne12);
GGML_V3_ASSERT(ne3 == ne13);
// we don't support permuted src0 dim0
GGML_V3_ASSERT(nb00 == ggml_v3_type_size(type));
// dst dim0 cannot be transposed or permuted
GGML_V3_ASSERT(nb0 == sizeof(float));
// GGML_V3_ASSERT(nb0 <= nb1);
// GGML_V3_ASSERT(nb1 <= nb2);
// GGML_V3_ASSERT(nb2 <= nb3);
GGML_V3_ASSERT(ne0 == ne00);
GGML_V3_ASSERT(ne1 == ne10);
GGML_V3_ASSERT(ne2 == ne02);
GGML_V3_ASSERT(ne3 == ne03);
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
// TODO: #if defined(GGML_USE_CUDA) ggml_v3_cuda_out_prod
// TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (params->type == GGML_V3_TASK_INIT) {
ggml_v3_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// parallelize by last three dimensions
// total rows in dst
const int64_t nr = ne1*ne2*ne3;
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
// row range for this thread
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
// dst[:,:,:,:] = 0
// for i2,i3:
// for i1:
// for i01:
// for i0:
// dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
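// note (illustrative): wdata is a per-thread scratch row of ne0 floats (plus CACHE_LINE_SIZE_F32
// floats of padding between threads); each quantized src0 row is dequantized into it right
// before the ggml_v3_vec_mad_f32 accumulation below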
for (int64_t ir = ir0; ir < ir1; ++ir) {
// dst indices
const int64_t i3 = ir/(ne2*ne1);
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
const int64_t i02 = i2;
const int64_t i03 = i3;
//const int64_t i10 = i1;
const int64_t i12 = i2;
const int64_t i13 = i3;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
const int64_t i11 = i01;
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
dequantize_row_q(s0, wdata, ne0);
ggml_v3_vec_mad_f32(ne0, d, wdata, *s1);
}
}
//int64_t t1 = ggml_v3_perf_time_us();
//static int64_t acc = 0;
//acc += t1 - t0;
//if (t1 - t0 > 10) {
// printf("\n");
// printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
// printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
// printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
// printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
// printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
//}
}
static void ggml_v3_compute_forward_out_prod(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_Q4_0:
case GGML_V3_TYPE_Q4_1:
case GGML_V3_TYPE_Q5_0:
case GGML_V3_TYPE_Q5_1:
case GGML_V3_TYPE_Q8_0:
case GGML_V3_TYPE_Q2_K:
case GGML_V3_TYPE_Q3_K:
case GGML_V3_TYPE_Q4_K:
case GGML_V3_TYPE_Q5_K:
case GGML_V3_TYPE_Q6_K:
case GGML_V3_TYPE_IQ2_XXS:
case GGML_V3_TYPE_IQ2_XS:
{
ggml_v3_compute_forward_out_prod_q_f32(params, src0, src1, dst);
} break;
case GGML_V3_TYPE_F16:
{
GGML_V3_ASSERT(false); // todo
// ggml_v3_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
} break;
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_out_prod_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_scale
static void ggml_v3_compute_forward_scale_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(src0));
GGML_V3_ASSERT(ggml_v3_is_contiguous(dst));
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// scale factor
float v;
memcpy(&v, dst->op_params, sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_v3_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
const size_t nb01 = src0->nb[1];
const size_t nb1 = dst->nb[1];
for (int i1 = ir0; i1 < ir1; i1++) {
if (dst->data != src0->data) {
// src0 is same shape as dst => same indices
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
}
ggml_v3_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
}
}
static void ggml_v3_compute_forward_scale(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_scale_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_set
static void ggml_v3_compute_forward_set_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
GGML_V3_ASSERT(ggml_v3_is_contiguous(dst) && ggml_v3_is_contiguous(src0));
// view src0 and dst with these strides and data offset inbytes during set
// nb0 is implicitly element_size because src0 and dst are contiguous
size_t nb1 = ((int32_t *) dst->op_params)[0];
size_t nb2 = ((int32_t *) dst->op_params)[1];
size_t nb3 = ((int32_t *) dst->op_params)[2];
size_t offset = ((int32_t *) dst->op_params)[3];
bool inplace = (bool) ((int32_t *) dst->op_params)[4];
if (!inplace && (params->type == GGML_V3_TASK_INIT)) {
// memcpy needs to be synchronized across threads to avoid race conditions.
// => do it in INIT phase
memcpy(
((char *) dst->data),
((char *) src0->data),
ggml_v3_nbytes(dst));
}
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_v3_nrows(src1);
const int nc = src1->ne[0];
GGML_V3_TENSOR_LOCALS(int64_t, ne1, src1, ne)
GGML_V3_TENSOR_LOCALS(size_t, nb1, src1, nb)
// src0 and dst as viewed during set
const size_t nb0 = ggml_v3_element_size(src0);
const int im0 = (ne10 == 0 ? 0 : ne10-1);
const int im1 = (ne11 == 0 ? 0 : ne11-1);
const int im2 = (ne12 == 0 ? 0 : ne12-1);
const int im3 = (ne13 == 0 ? 0 : ne13-1);
GGML_V3_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_v3_nbytes(dst));
GGML_V3_ASSERT(nb10 == sizeof(float));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
// src0 and dst are viewed with shape of src1 and offset
// => same indices
const int i3 = ir/(ne12*ne11);
const int i2 = (ir - i3*ne12*ne11)/ne11;
const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
ggml_v3_vec_cpy_f32(nc,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
}
}
static void ggml_v3_compute_forward_set(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_set_f32(params, src0, src1, dst);
} break;
case GGML_V3_TYPE_F16:
case GGML_V3_TYPE_Q4_0:
case GGML_V3_TYPE_Q4_1:
case GGML_V3_TYPE_Q5_0:
case GGML_V3_TYPE_Q5_1:
case GGML_V3_TYPE_Q8_0:
case GGML_V3_TYPE_Q8_1:
case GGML_V3_TYPE_Q2_K:
case GGML_V3_TYPE_Q3_K:
case GGML_V3_TYPE_Q4_K:
case GGML_V3_TYPE_Q5_K:
case GGML_V3_TYPE_Q6_K:
case GGML_V3_TYPE_IQ2_XXS:
case GGML_V3_TYPE_IQ2_XS:
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_cpy
static void ggml_v3_compute_forward_cpy(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
ggml_v3_compute_forward_dup(params, src0, dst);
}
// ggml_v3_compute_forward_cont
static void ggml_v3_compute_forward_cont(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
ggml_v3_compute_forward_dup(params, src0, dst);
}
// ggml_v3_compute_forward_reshape
static void ggml_v3_compute_forward_reshape(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
// NOP
UNUSED(params);
UNUSED(src0);
UNUSED(dst);
}
// ggml_v3_compute_forward_view
static void ggml_v3_compute_forward_view(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0) {
// NOP
UNUSED(params);
UNUSED(src0);
}
// ggml_v3_compute_forward_permute
static void ggml_v3_compute_forward_permute(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0) {
// NOP
UNUSED(params);
UNUSED(src0);
}
// ggml_v3_compute_forward_transpose
static void ggml_v3_compute_forward_transpose(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0) {
// NOP
UNUSED(params);
UNUSED(src0);
}
// ggml_v3_compute_forward_get_rows
static void ggml_v3_compute_forward_get_rows_q(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_TENSOR_BINARY_OP_LOCALS
const int64_t nc = ne00;
const int64_t nr = ggml_v3_nelements(src1); GGML_V3_UNUSED(nr);
const enum ggml_v3_type type = src0->type;
ggml_v3_to_float_t const dequantize_row_q = type_traits[type].to_float;
assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == ggml_v3_type_size(type));
assert(ggml_v3_nrows(dst) == nr);
// TODO: multi-thread
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
for (int64_t i10 = 0; i10 < ne10; ++i10) {
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
dequantize_row_q(
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
}
}
}
}
static void ggml_v3_compute_forward_get_rows_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_TENSOR_BINARY_OP_LOCALS
const int64_t nc = ne00;
const int64_t nr = ggml_v3_nelements(src1); GGML_V3_UNUSED(nr);
assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == sizeof(ggml_v3_fp16_t));
assert(ggml_v3_nrows(dst) == nr);
// TODO: multi-thread
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
for (int64_t i10 = 0; i10 < ne10; ++i10) {
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
ggml_v3_fp16_to_fp32_row(
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
}
}
}
}
static void ggml_v3_compute_forward_get_rows_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_TENSOR_BINARY_OP_LOCALS
const int64_t nc = ne00;
const int64_t nr = ggml_v3_nelements(src1); GGML_V3_UNUSED(nr);
assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == sizeof(float));
assert(ggml_v3_nrows(dst) == nr);
// TODO: multi-thread
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
for (int64_t i10 = 0; i10 < ne10; ++i10) {
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
ggml_v3_vec_cpy_f32(nc,
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
(float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
}
}
}
}
static void ggml_v3_compute_forward_get_rows(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_Q4_0:
case GGML_V3_TYPE_Q4_1:
case GGML_V3_TYPE_Q5_0:
case GGML_V3_TYPE_Q5_1:
case GGML_V3_TYPE_Q8_0:
case GGML_V3_TYPE_Q8_1:
case GGML_V3_TYPE_Q2_K:
case GGML_V3_TYPE_Q3_K:
case GGML_V3_TYPE_Q4_K:
case GGML_V3_TYPE_Q5_K:
case GGML_V3_TYPE_Q6_K:
case GGML_V3_TYPE_IQ2_XXS:
case GGML_V3_TYPE_IQ2_XS:
{
ggml_v3_compute_forward_get_rows_q(params, src0, src1, dst);
} break;
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_get_rows_f16(params, src0, src1, dst);
} break;
case GGML_V3_TYPE_F32:
case GGML_V3_TYPE_I32:
{
ggml_v3_compute_forward_get_rows_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
//static bool first = true;
//printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
//if (first) {
// first = false;
//} else {
// for (int k = 0; k < dst->ne[1]; ++k) {
// for (int j = 0; j < dst->ne[0]/16; ++j) {
// for (int i = 0; i < 16; ++i) {
// printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
// }
// printf("\n");
// }
// printf("\n");
// }
// printf("\n");
// exit(0);
//}
}
// ggml_v3_compute_forward_get_rows_back
static void ggml_v3_compute_forward_get_rows_back_f32_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(params->ith == 0);
GGML_V3_ASSERT(ggml_v3_is_contiguous(dst));
// ggml_v3_compute_forward_dup_same_cont(params, opt0, dst);
if (params->type == GGML_V3_TASK_INIT) {
memset(dst->data, 0, ggml_v3_nbytes(dst));
}
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int nc = src0->ne[0];
const int nr = ggml_v3_nelements(src1);
GGML_V3_ASSERT( dst->ne[0] == nc);
GGML_V3_ASSERT(src0->nb[0] == sizeof(ggml_v3_fp16_t));
for (int i = 0; i < nr; ++i) {
const int r = ((int32_t *) src1->data)[i];
for (int j = 0; j < nc; ++j) {
ggml_v3_fp16_t v = ((ggml_v3_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_V3_FP16_TO_FP32(v);
}
}
}
static void ggml_v3_compute_forward_get_rows_back_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(params->ith == 0);
GGML_V3_ASSERT(ggml_v3_is_contiguous(dst));
// ggml_v3_compute_forward_dup_same_cont(params, opt0, dst);
if (params->type == GGML_V3_TASK_INIT) {
memset(dst->data, 0, ggml_v3_nbytes(dst));
}
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int nc = src0->ne[0];
const int nr = ggml_v3_nelements(src1);
GGML_V3_ASSERT( dst->ne[0] == nc);
GGML_V3_ASSERT(src0->nb[0] == sizeof(float));
for (int i = 0; i < nr; ++i) {
const int r = ((int32_t *) src1->data)[i];
ggml_v3_vec_add_f32(nc,
(float *) ((char *) dst->data + r*dst->nb[1]),
(float *) ((char *) dst->data + r*dst->nb[1]),
(float *) ((char *) src0->data + i*src0->nb[1]));
}
}
static void ggml_v3_compute_forward_get_rows_back(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
} break;
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_get_rows_back_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
//static bool first = true;
//printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
//if (first) {
// first = false;
//} else {
// for (int k = 0; k < dst->ne[1]; ++k) {
// for (int j = 0; j < dst->ne[0]/16; ++j) {
// for (int i = 0; i < 16; ++i) {
// printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
// }
// printf("\n");
// }
// printf("\n");
// }
// printf("\n");
// exit(0);
//}
}
// ggml_v3_compute_forward_diag
static void ggml_v3_compute_forward_diag_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// TODO: handle transposed/permuted matrices
GGML_V3_TENSOR_UNARY_OP_LOCALS
GGML_V3_ASSERT(ne00 == ne0);
GGML_V3_ASSERT(ne00 == ne1);
GGML_V3_ASSERT(ne01 == 1);
GGML_V3_ASSERT(ne02 == ne2);
GGML_V3_ASSERT(ne03 == ne3);
GGML_V3_ASSERT(nb00 == sizeof(float));
GGML_V3_ASSERT(nb0 == sizeof(float));
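// illustrative example: src0 holds one row of ne00 values per (i2, i3) slice (ne01 == 1), and
// dst becomes the ne0 x ne1 square matrix with that row on its diagonal and zeros elsewhere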
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = 0; i2 < ne2; i2++) {
for (int i1 = 0; i1 < ne1; i1++) {
float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
for (int i0 = 0; i0 < i1; i0++) {
d[i0] = 0;
}
d[i1] = s[i1];
for (int i0 = i1+1; i0 < ne0; i0++) {
d[i0] = 0;
}
}
}
}
}
static void ggml_v3_compute_forward_diag(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_diag_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_diag_mask_inf
static void ggml_v3_compute_forward_diag_mask_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst,
const float value) {
const int ith = params->ith;
const int nth = params->nth;
const int n_past = ((int32_t *) dst->op_params)[0];
const bool inplace = src0->data == dst->data;
GGML_V3_ASSERT(n_past >= 0);
if (!inplace && (params->type == GGML_V3_TASK_INIT)) {
// memcpy needs to be synchronized across threads to avoid race conditions.
// => do it in INIT phase
GGML_V3_ASSERT(ggml_v3_nelements(dst) == ggml_v3_nelements(src0));
GGML_V3_ASSERT(ggml_v3_is_contiguous(dst) && ggml_v3_is_contiguous(src0));
memcpy(
((char *) dst->data),
((char *) src0->data),
ggml_v3_nbytes(dst));
}
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// TODO: handle transposed/permuted matrices
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
const int nr = src0->ne[1];
const int nz = n/nr;
GGML_V3_ASSERT( dst->nb[0] == sizeof(float));
GGML_V3_ASSERT(src0->nb[0] == sizeof(float));
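// illustrative example: with n_past = 0, the loops below write value at every position
// (row j, column i) with i > j, i.e. strictly above the diagonal - the usual causal
// attention mask when value = -INFINITY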
for (int k = 0; k < nz; k++) {
for (int j = ith; j < nr; j += nth) {
for (int i = n_past; i < nc; i++) {
if (i > n_past + j) {
*(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
}
}
}
}
}
static void ggml_v3_compute_forward_diag_mask_inf(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_diag_mask_f32(params, src0, dst, -INFINITY);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
static void ggml_v3_compute_forward_diag_mask_zero(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_diag_mask_f32(params, src0, dst, 0);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_soft_max
static void ggml_v3_compute_forward_soft_max_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
assert(ggml_v3_is_contiguous(dst));
assert(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
float scale = 1.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
// TODO: handle transposed/permuted matrices
const int ith = params->ith;
const int nth = params->nth;
const int64_t ne11 = src1 ? src1->ne[1] : 1;
const int nc = src0->ne[0];
const int nr = ggml_v3_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
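// note (illustrative): each thread gets its own scratch of (nc + CACHE_LINE_SIZE_F32) floats in
// params->wdata; the extra CACHE_LINE_SIZE_F32 floats of padding keep neighbouring threads'
// buffers on separate cache lines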
for (int i1 = ir0; i1 < ir1; i1++) {
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
// broadcast the mask across rows
float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
ggml_v3_vec_cpy_f32 (nc, wp, sp);
ggml_v3_vec_scale_f32(nc, wp, scale);
if (mp) {
ggml_v3_vec_acc_f32(nc, wp, mp);
}
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(wp[i]));
}
#endif
float max = -INFINITY;
ggml_v3_vec_max_f32(nc, &max, wp);
ggml_v3_float sum = 0.0;
uint16_t scvt;
for (int i = 0; i < nc; i++) {
if (wp[i] == -INFINITY) {
dp[i] = 0.0f;
} else {
// const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
ggml_v3_fp16_t s = GGML_V3_FP32_TO_FP16(wp[i] - max);
memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_V3_FP16_TO_FP32(ggml_v3_table_exp_f16[scvt]);
sum += (ggml_v3_float)val;
dp[i] = val;
}
}
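// note (illustrative): the loop above evaluates exp(wp[i] - max) via ggml_v3_table_exp_f16,
// a precomputed table indexed by the raw fp16 bit pattern of (wp[i] - max); subtracting the
// row maximum first keeps every argument <= 0, so the exponentials cannot overflow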
assert(sum > 0.0);
sum = 1.0/sum;
ggml_v3_vec_scale_f32(nc, dp, sum);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
assert(!isnan(dp[i]));
assert(!isinf(dp[i]));
}
#endif
}
}
static void ggml_v3_compute_forward_soft_max(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_soft_max_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_soft_max_back
static void ggml_v3_compute_forward_soft_max_back_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(src0));
GGML_V3_ASSERT(ggml_v3_is_contiguous(src1));
GGML_V3_ASSERT(ggml_v3_is_contiguous(dst));
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
GGML_V3_ASSERT(ggml_v3_are_same_shape(src1, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// TODO: handle transposed/permuted matrices
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0];
const int nr = ggml_v3_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(dy[i]));
assert(!isnan(y[i]));
}
#endif
// Jii = yi - yi*yi
// Jij = -yi*yj
// J = diag(y)-y.T*y
// dx = J * dy
// dxk = sum_i(Jki * dyi)
// dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
// dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
// dxk = sum_i(-yk*yi * dyi) + yk*dyk
// dxk = -yk * sum_i(yi * dyi) + yk*dyk
// dxk = -yk * dot(y, dy) + yk*dyk
// dxk = yk * (- dot(y, dy) + dyk)
// dxk = yk * (dyk - dot(y, dy))
//
// post-order:
// dot_y_dy := dot(y, dy)
// dx := dy
// dx := dx - dot_y_dy
// dx := dx * y
// linear runtime, no additional memory
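// sanity check (illustrative): for y = softmax(x) = {0.5, 0.5} and dy = {1, 0},
// dot(y, dy) = 0.5 and dx = y*(dy - dot(y, dy)) = {0.25, -0.25}, which sums to zero
// as expected for a softmax Jacobian-vector product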
float dot_y_dy = 0;
ggml_v3_vec_dot_f32 (nc, &dot_y_dy, y, dy);
ggml_v3_vec_cpy_f32 (nc, dx, dy);
ggml_v3_vec_acc1_f32(nc, dx, -dot_y_dy);
ggml_v3_vec_mul_f32 (nc, dx, dx, y);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
assert(!isnan(dx[i]));
assert(!isinf(dx[i]));
}
#endif
}
}
static void ggml_v3_compute_forward_soft_max_back(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_soft_max_back_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_alibi
static void ggml_v3_compute_forward_alibi_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
const int64_t ne2 = src0->ne[2]; // n_head -> this is k
//const int64_t ne3 = src0->ne[3]; // 1 -> bsz
const int64_t n = ggml_v3_nrows(src0);
const int64_t ne2_ne3 = n/ne1; // ne2*ne3
const size_t nb0 = src0->nb[0];
const size_t nb1 = src0->nb[1];
const size_t nb2 = src0->nb[2];
//const int nb3 = src0->nb[3];
GGML_V3_ASSERT(nb0 == sizeof(float));
GGML_V3_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
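// worked example (illustrative): with n_head = 12 and max_bias = 8.0, n_heads_log2_floor = 8,
// m0 = 2^(-8/8) = 0.5 and m1 = 2^(-4/8) ~ 0.7071; heads 0..7 get slopes m0^(k+1) = 0.5, 0.25, ...
// and heads 8..11 get m1^(2*(k-8)+1) ~ 0.7071, 0.3536, 0.1768, 0.0884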
for (int64_t i = 0; i < ne0; i++) {
for (int64_t j = 0; j < ne1; j++) {
for (int64_t k = 0; k < ne2_ne3; k++) {
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
// TODO: k*nb2 or k*nb3
float m_k;
if (k < n_heads_log2_floor) {
m_k = powf(m0, k + 1);
} else {
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
}
pdst[0] = i * m_k + src[0];
}
}
}
}
static void ggml_v3_compute_forward_alibi_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past
const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_v3_nrows(src0);
const int ne2_ne3 = n/ne1; // ne2*ne3
const int nb0 = src0->nb[0];
const int nb1 = src0->nb[1];
const int nb2 = src0->nb[2];
//const int nb3 = src0->nb[3];
GGML_V3_ASSERT(nb0 == sizeof(ggml_v3_fp16_t));
//GGML_V3_ASSERT(ne1 + n_past == ne0); (void) n_past;
GGML_V3_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
for (int i = 0; i < ne0; i++) {
for (int j = 0; j < ne1; j++) {
for (int k = 0; k < ne2_ne3; k++) {
ggml_v3_fp16_t * const src = (ggml_v3_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
// TODO: k*nb2 or k*nb3
float m_k;
if (k < n_heads_log2_floor) {
m_k = powf(m0, k + 1);
} else {
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
}
// we return F32
pdst[0] = i * m_k + GGML_V3_FP16_TO_FP32(src[0]);
}
}
}
}
static void ggml_v3_compute_forward_alibi(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_alibi_f16(params, src0, dst);
} break;
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_alibi_f32(params, src0, dst);
} break;
case GGML_V3_TYPE_Q4_0:
case GGML_V3_TYPE_Q4_1:
case GGML_V3_TYPE_Q5_0:
case GGML_V3_TYPE_Q5_1:
case GGML_V3_TYPE_Q8_0:
case GGML_V3_TYPE_Q8_1:
case GGML_V3_TYPE_Q2_K:
case GGML_V3_TYPE_Q3_K:
case GGML_V3_TYPE_Q4_K:
case GGML_V3_TYPE_Q5_K:
case GGML_V3_TYPE_Q6_K:
case GGML_V3_TYPE_IQ2_XXS:
case GGML_V3_TYPE_IQ2_XS:
case GGML_V3_TYPE_Q8_K:
case GGML_V3_TYPE_I8:
case GGML_V3_TYPE_I16:
case GGML_V3_TYPE_I32:
case GGML_V3_TYPE_COUNT:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_clamp
static void ggml_v3_compute_forward_clamp_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
float min;
float max;
memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1];
const size_t nb0 = dst->nb[0];
const size_t nb1 = dst->nb[1];
GGML_V3_ASSERT( nb0 == sizeof(float));
GGML_V3_ASSERT(nb00 == sizeof(float));
for (int j = ith; j < n; j += nth) {
float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
for (int i = 0; i < nc; i++) {
dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
}
}
}
static void ggml_v3_compute_forward_clamp(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_clamp_f32(params, src0, dst);
} break;
case GGML_V3_TYPE_F16:
case GGML_V3_TYPE_Q4_0:
case GGML_V3_TYPE_Q4_1:
case GGML_V3_TYPE_Q5_0:
case GGML_V3_TYPE_Q5_1:
case GGML_V3_TYPE_Q8_0:
case GGML_V3_TYPE_Q8_1:
case GGML_V3_TYPE_Q2_K:
case GGML_V3_TYPE_Q3_K:
case GGML_V3_TYPE_Q4_K:
case GGML_V3_TYPE_Q5_K:
case GGML_V3_TYPE_Q6_K:
case GGML_V3_TYPE_IQ2_XXS:
case GGML_V3_TYPE_IQ2_XS:
case GGML_V3_TYPE_Q8_K:
case GGML_V3_TYPE_I8:
case GGML_V3_TYPE_I16:
case GGML_V3_TYPE_I32:
case GGML_V3_TYPE_COUNT:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_rope
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
return 1 - MIN(1, MAX(0, y));
}
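// illustrative example: with corr_dims = {low, high} = {4, 12}, a pair index i0/2 = 4 returns 1
// (full ramp, so rope_yarn below leans toward the extrapolated theta) and i0/2 = 12 returns 0
// (pure interpolation); the MAX(0.001f, ...) guard avoids a division by zero when low == high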
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
}
// Solving `n_rot = max_pos_emb / (2pi * base^(2*x / n_dims))` for the dimension index x, we get
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float ggml_v3_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
}
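// worked example (illustrative): with n_dims = 128, n_orig_ctx = 4096, base = 10000 and
// n_rot = 32, this gives 128 * log(4096 / (32 * 2*pi)) / (2 * log(10000)) ~ 20.9, which
// ggml_v3_rope_yarn_corr_dims below floors to 20 for the start correction dim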
void ggml_v3_rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
// start and end correction dims
dims[0] = MAX(0, floorf(ggml_v3_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
dims[1] = MIN(n_dims - 1, ceilf(ggml_v3_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
}
static void ggml_v3_compute_forward_rope_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst,
const bool forward) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
// these two only relevant for xPos RoPE:
float xpos_base;
bool xpos_down;
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
GGML_V3_TENSOR_UNARY_OP_LOCALS
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
GGML_V3_ASSERT(nb00 == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_v3_nrows(dst);
GGML_V3_ASSERT(n_dims <= ne0);
GGML_V3_ASSERT(n_dims % 2 == 0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
// running row index; each thread only processes rows in its own range [ir0, ir1)
int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;
float corr_dims[2];
ggml_v3_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f;
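// minimal sketch of the remark above: the forward pass applies the 2x2 rotation
//   [ cos_theta  -sin_theta ]
//   [ sin_theta   cos_theta ]
// whose inverse is its transpose, i.e. the same matrix with sin_theta negated - hence
// sin_sign = -1.0f for the backward pass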
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
float theta_base = (float)p;
if (is_glm) {
theta_base = MIN(p, n_ctx - 2);
float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta_base) * sin_sign;
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta) * sin_sign;
theta_base *= theta_scale;
block_theta *= theta_scale;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
const float x2 = src[n_dims];
const float x3 = src[n_dims/2*3];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
}
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
float cos_theta, sin_theta;
rope_yarn(
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
);
sin_theta *= sin_sign;
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta;
theta_base *= theta_scale;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
float cos_theta, sin_theta;
rope_yarn(
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;
theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const int64_t i0 = ic;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
}
}
}
}
static void ggml_v3_compute_forward_rope_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst,
const bool forward) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
GGML_V3_TENSOR_UNARY_OP_LOCALS
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
GGML_V3_ASSERT(nb0 == sizeof(ggml_v3_fp16_t));
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_v3_nrows(dst);
GGML_V3_ASSERT(n_dims <= ne0);
GGML_V3_ASSERT(n_dims % 2 == 0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
    // running row counter, used to select the rows handled by this thread
int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;
float corr_dims[2];
ggml_v3_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
float theta_base = (float)p;
if (is_glm) {
theta_base = MIN(p, n_ctx - 2);
float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta_base) * sin_sign;
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta) * sin_sign;
theta_base *= theta_scale;
block_theta *= theta_scale;
const ggml_v3_fp16_t * const src = (ggml_v3_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_v3_fp16_t * dst_data = (ggml_v3_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = GGML_V3_FP16_TO_FP32(src[0]);
const float x1 = GGML_V3_FP16_TO_FP32(src[n_dims/2]);
const float x2 = GGML_V3_FP16_TO_FP32(src[n_dims]);
const float x3 = GGML_V3_FP16_TO_FP32(src[n_dims/2*3]);
dst_data[0] = GGML_V3_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_V3_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
dst_data[n_dims] = GGML_V3_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
dst_data[n_dims/2*3] = GGML_V3_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
}
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
float cos_theta, sin_theta;
rope_yarn(
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
);
sin_theta *= sin_sign;
theta_base *= theta_scale;
const ggml_v3_fp16_t * const src = (ggml_v3_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_v3_fp16_t * dst_data = (ggml_v3_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = GGML_V3_FP16_TO_FP32(src[0]);
const float x1 = GGML_V3_FP16_TO_FP32(src[1]);
dst_data[0] = GGML_V3_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[1] = GGML_V3_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
float cos_theta, sin_theta;
rope_yarn(
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;
theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
const ggml_v3_fp16_t * const src = (ggml_v3_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_v3_fp16_t * dst_data = (ggml_v3_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = GGML_V3_FP16_TO_FP32(src[0]);
const float x1 = GGML_V3_FP16_TO_FP32(src[n_dims/2]);
dst_data[0] = GGML_V3_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_V3_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
} else {
const int64_t i0 = ic;
const ggml_v3_fp16_t * const src = (ggml_v3_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_v3_fp16_t * dst_data = (ggml_v3_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
}
}
}
}
static void ggml_v3_compute_forward_rope(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_rope_f16(params, src0, src1, dst, true);
} break;
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_rope_f32(params, src0, src1, dst, true);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_rope_back
static void ggml_v3_compute_forward_rope_back(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_rope_f16(params, src0, src1, dst, false);
} break;
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_rope_f32(params, src0, src1, dst, false);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_conv_transpose_1d
static void ggml_v3_compute_forward_conv_transpose_1d_f16_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(src0->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F32);
GGML_V3_ASSERT( dst->type == GGML_V3_TYPE_F32);
int64_t t0 = ggml_v3_perf_time_us();
UNUSED(t0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const int nk = ne00*ne01*ne02;
GGML_V3_ASSERT(nb00 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nb10 == sizeof(float));
if (params->type == GGML_V3_TASK_INIT) {
memset(params->wdata, 0, params->wsize);
// permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
{
ggml_v3_fp16_t * const wdata = (ggml_v3_fp16_t *) params->wdata + 0;
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
const ggml_v3_fp16_t * const src = (ggml_v3_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
ggml_v3_fp16_t * dst_data = wdata + i01*ne00*ne02;
for (int64_t i00 = 0; i00 < ne00; i00++) {
dst_data[i00*ne02 + i02] = src[i00];
}
}
}
}
// permute source data (src1) from (L x Cin) to (Cin x L)
{
ggml_v3_fp16_t * const wdata = (ggml_v3_fp16_t *) params->wdata + nk;
ggml_v3_fp16_t * dst_data = wdata;
for (int64_t i11 = 0; i11 < ne11; i11++) {
const float * const src = (float *)((char *) src1->data + i11*nb11);
for (int64_t i10 = 0; i10 < ne10; i10++) {
dst_data[i10*ne11 + i11] = GGML_V3_FP32_TO_FP16(src[i10]);
}
}
}
// need to zero dst since we are accumulating into it
memset(dst->data, 0, ggml_v3_nbytes(dst));
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
// total rows in dst
const int nr = ne1;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
ggml_v3_fp16_t * const wdata = (ggml_v3_fp16_t *) params->wdata + 0;
ggml_v3_fp16_t * const wdata_src = wdata + nk;
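    // scatter-accumulate: for each output channel row i1, input position i10 contributes a
    // kernel-weighted value to output position i10*s0 + i00 (hence dst was zeroed during INIT)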
for (int i1 = ir0; i1 < ir1; i1++) {
float * dst_data = (float *)((char *) dst->data + i1*nb1);
ggml_v3_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
for (int i10 = 0; i10 < ne10; i10++) {
const int i1n = i10*ne11;
for (int i00 = 0; i00 < ne00; i00++) {
float v = 0;
ggml_v3_vec_dot_f16(ne02, &v,
(ggml_v3_fp16_t *) wdata_src + i1n,
(ggml_v3_fp16_t *) wdata_kernel + i00*ne02);
dst_data[i10*s0 + i00] += v;
}
}
}
}
static void ggml_v3_compute_forward_conv_transpose_1d_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(src0->type == GGML_V3_TYPE_F32);
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F32);
GGML_V3_ASSERT( dst->type == GGML_V3_TYPE_F32);
int64_t t0 = ggml_v3_perf_time_us();
UNUSED(t0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const int nk = ne00*ne01*ne02;
GGML_V3_ASSERT(nb00 == sizeof(float));
GGML_V3_ASSERT(nb10 == sizeof(float));
if (params->type == GGML_V3_TASK_INIT) {
memset(params->wdata, 0, params->wsize);
// prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
{
float * const wdata = (float *) params->wdata + 0;
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
float * dst_data = wdata + i01*ne00*ne02;
for (int64_t i00 = 0; i00 < ne00; i00++) {
dst_data[i00*ne02 + i02] = src[i00];
}
}
}
}
// prepare source data (src1)
{
float * const wdata = (float *) params->wdata + nk;
float * dst_data = wdata;
for (int64_t i11 = 0; i11 < ne11; i11++) {
const float * const src = (float *)((char *) src1->data + i11*nb11);
for (int64_t i10 = 0; i10 < ne10; i10++) {
dst_data[i10*ne11 + i11] = src[i10];
}
}
}
// need to zero dst since we are accumulating into it
memset(dst->data, 0, ggml_v3_nbytes(dst));
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
// total rows in dst
const int nr = ne1;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
float * const wdata = (float *) params->wdata + 0;
float * const wdata_src = wdata + nk;
for (int i1 = ir0; i1 < ir1; i1++) {
float * dst_data = (float *)((char *) dst->data + i1*nb1);
float * wdata_kernel = wdata + i1*ne02*ne00;
for (int i10 = 0; i10 < ne10; i10++) {
const int i1n = i10*ne11;
for (int i00 = 0; i00 < ne00; i00++) {
float v = 0;
ggml_v3_vec_dot_f32(ne02, &v,
wdata_src + i1n,
wdata_kernel + i00*ne02);
dst_data[i10*s0 + i00] += v;
}
}
}
}
static void ggml_v3_compute_forward_conv_transpose_1d(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
} break;
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// src0: kernel [OC, IC, KH, KW]
// src1: image [N, IC, IH, IW]
// dst: result [N, OH, OW, IC*KH*KW]
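// worked example (illustrative sizes only): KW = KH = 3, IC = 8, IH = IW = 32 with
// stride 1, padding 1 and dilation 1 gives OH = OW = 32, so dst is [N, 32, 32, 8*3*3]
// and each dst row holds one flattened 72-element patch ready for a single mat-mul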
static void ggml_v3_compute_forward_im2col_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(src0->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F32);
GGML_V3_ASSERT( dst->type == GGML_V3_TYPE_F16);
int64_t t0 = ggml_v3_perf_time_us();
UNUSED(t0);
GGML_V3_TENSOR_BINARY_OP_LOCALS;
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
const int ith = params->ith;
const int nth = params->nth;
const int64_t N = is_2D ? ne13 : ne12;
const int64_t IC = is_2D ? ne12 : ne11;
const int64_t IH = is_2D ? ne11 : 1;
const int64_t IW = ne10;
const int64_t KH = is_2D ? ne01 : 1;
const int64_t KW = ne00;
const int64_t OH = is_2D ? ne2 : 1;
const int64_t OW = ne1;
    const int64_t ofs0 = is_2D ? nb13 : nb12;
    const int64_t ofs1 = is_2D ? nb12 : nb11;
GGML_V3_ASSERT(nb00 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nb10 == sizeof(float));
if (params->type == GGML_V3_TASK_INIT) {
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
{
ggml_v3_fp16_t * const wdata = (ggml_v3_fp16_t *) dst->data;
for (int64_t in = 0; in < N; in++) {
            for (int64_t ioh = 0; ioh < OH; ioh++) { // single iteration when !is_2D
for (int64_t iow = 0; iow < OW; iow++) {
for (int64_t iic = ith; iic < IC; iic += nth) {
// micro kernel
ggml_v3_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
                        for (int64_t ikh = 0; ikh < KH; ikh++) { // single iteration when !is_2D
for (int64_t ikw = 0; ikw < KW; ikw++) {
const int64_t iiw = iow*s0 + ikw*d0 - p0;
const int64_t iih = ioh*s1 + ikh*d1 - p1;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
} else {
dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_V3_FP32_TO_FP16(src_data[iih*IW + iiw]);
}
}
}
}
}
}
}
}
}
static void ggml_v3_compute_forward_im2col(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_im2col_f16(params, src0, src1, dst);
} break;
case GGML_V3_TYPE_F32:
{
GGML_V3_ASSERT(false);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_conv_transpose_2d
static void ggml_v3_compute_forward_conv_transpose_2d(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(src0->type == GGML_V3_TYPE_F16);
GGML_V3_ASSERT(src1->type == GGML_V3_TYPE_F32);
GGML_V3_ASSERT( dst->type == GGML_V3_TYPE_F32);
int64_t t0 = ggml_v3_perf_time_us();
UNUSED(t0);
GGML_V3_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const int nk = ne00*ne01*ne02*ne03;
GGML_V3_ASSERT(nb00 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nb10 == sizeof(float));
if (params->type == GGML_V3_TASK_INIT) {
memset(params->wdata, 0, params->wsize);
// permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
{
ggml_v3_fp16_t * const wdata = (ggml_v3_fp16_t *) params->wdata + 0;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
const ggml_v3_fp16_t * const src = (ggml_v3_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
ggml_v3_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
for (int64_t i01 = 0; i01 < ne01; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
}
}
}
}
}
// permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
{
ggml_v3_fp16_t * const wdata = (ggml_v3_fp16_t *) params->wdata + nk;
for (int i12 = 0; i12 < ne12; i12++) {
for (int i11 = 0; i11 < ne11; i11++) {
const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
ggml_v3_fp16_t * dst_data = wdata + i11*ne10*ne12;
for (int i10 = 0; i10 < ne10; i10++) {
dst_data[i10*ne12 + i12] = GGML_V3_FP32_TO_FP16(src[i10]);
}
}
}
}
memset(dst->data, 0, ggml_v3_nbytes(dst));
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int32_t stride = ggml_v3_get_op_params_i32(dst, 0);
// total patches in dst
const int np = ne2;
// patches per thread
const int dp = (np + nth - 1)/nth;
// patch range for this thread
const int ip0 = dp*ith;
const int ip1 = MIN(ip0 + dp, np);
ggml_v3_fp16_t * const wdata = (ggml_v3_fp16_t *) params->wdata + 0;
ggml_v3_fp16_t * const wdata_src = wdata + nk;
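    // scatter-accumulate per output channel: input position (i10, i11) contributes kernel-weighted
    // values to output positions (i10*stride + i00, i11*stride + i01), hence dst was zeroed during INIT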
for (int i2 = ip0; i2 < ip1; i2++) { // Cout
float * dst_data = (float *)((char *) dst->data + i2*nb2);
ggml_v3_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
for (int i11 = 0; i11 < ne11; i11++) {
for (int i10 = 0; i10 < ne10; i10++) {
const int i1n = i11*ne10*ne12 + i10*ne12;
for (int i01 = 0; i01 < ne01; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
float v = 0;
ggml_v3_vec_dot_f16(ne03, &v,
wdata_src + i1n,
wdata_kernel + i01*ne00*ne03 + i00*ne03);
dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
}
}
}
}
}
}
// ggml_v3_compute_forward_pool_1d_sk_p0
static void ggml_v3_compute_forward_pool_1d_sk_p0(
const struct ggml_v3_compute_params * params,
const enum ggml_v3_op_pool op,
const struct ggml_v3_tensor * src,
const int k,
struct ggml_v3_tensor * dst) {
assert(src->type == GGML_V3_TYPE_F32);
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const char * cdata = (const char *)src->data;
const char * const data_end = cdata + ggml_v3_nbytes(src);
float * drow = (float *)dst->data;
const int64_t rs = dst->ne[0];
while (cdata < data_end) {
const float * const srow = (const float *)cdata;
int j = 0;
for (int64_t i = 0; i < rs; ++i) {
switch (op) {
case GGML_V3_OP_POOL_AVG: drow[i] = 0; break;
case GGML_V3_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
case GGML_V3_OP_POOL_COUNT: GGML_V3_ASSERT(false); break;
}
for (int ki = 0; ki < k; ++ki) {
switch (op) {
case GGML_V3_OP_POOL_AVG: drow[i] += srow[j]; break;
case GGML_V3_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break;
case GGML_V3_OP_POOL_COUNT: GGML_V3_ASSERT(false); break;
}
++j;
}
switch (op) {
case GGML_V3_OP_POOL_AVG: drow[i] /= k; break;
case GGML_V3_OP_POOL_MAX: break;
case GGML_V3_OP_POOL_COUNT: GGML_V3_ASSERT(false); break;
}
}
cdata += src->nb[1];
drow += rs;
}
}
// ggml_v3_compute_forward_pool_1d
static void ggml_v3_compute_forward_pool_1d(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_v3_op_pool op = opts[0];
const int k0 = opts[1];
const int s0 = opts[2];
const int p0 = opts[3];
GGML_V3_ASSERT(p0 == 0); // padding not supported
GGML_V3_ASSERT(k0 == s0); // only s = k supported
ggml_v3_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst);
}
// ggml_v3_compute_forward_pool_2d
static void ggml_v3_compute_forward_pool_2d(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src,
struct ggml_v3_tensor * dst) {
assert(src->type == GGML_V3_TYPE_F32);
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_v3_op_pool op = opts[0];
const int k0 = opts[1];
const int k1 = opts[2];
const int s0 = opts[3];
const int s1 = opts[4];
const int p0 = opts[5];
const int p1 = opts[6];
const char * cdata = (const char*)src->data;
const char * const data_end = cdata + ggml_v3_nbytes(src);
const int64_t px = dst->ne[0];
const int64_t py = dst->ne[1];
const int64_t pa = px * py;
float * dplane = (float *)dst->data;
const int ka = k0 * k1;
const int offset0 = -p0;
const int offset1 = -p1;
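    // offset0/offset1 shift the kernel window left/up by the padding amount; out-of-range taps are
    // skipped by the bounds checks below instead of being read from memory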
while (cdata < data_end) {
for (int oy = 0; oy < py; ++oy) {
float * const drow = dplane + oy * px;
for (int ox = 0; ox < px; ++ox) {
float * const out = drow + ox;
switch (op) {
case GGML_V3_OP_POOL_AVG: *out = 0; break;
case GGML_V3_OP_POOL_MAX: *out = -FLT_MAX; break;
case GGML_V3_OP_POOL_COUNT: GGML_V3_ASSERT(false); break;
}
const int ix = offset0 + ox * s0;
const int iy = offset1 + oy * s1;
for (int ky = 0; ky < k1; ++ky) {
if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
for (int kx = 0; kx < k0; ++kx) {
int j = ix + kx;
if (j < 0 || j >= src->ne[0]) continue;
switch (op) {
case GGML_V3_OP_POOL_AVG: *out += srow[j]; break;
case GGML_V3_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
case GGML_V3_OP_POOL_COUNT: GGML_V3_ASSERT(false); break;
}
}
}
switch (op) {
case GGML_V3_OP_POOL_AVG: *out /= ka; break;
case GGML_V3_OP_POOL_MAX: break;
case GGML_V3_OP_POOL_COUNT: GGML_V3_ASSERT(false); break;
}
}
}
cdata += src->nb[2];
dplane += pa;
}
}
// ggml_v3_compute_forward_upscale
static void ggml_v3_compute_forward_upscale_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_V3_TENSOR_UNARY_OP_LOCALS
const int scale_factor = dst->op_params[0];
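    // nearest-neighbour upscaling: output element (i0, i1) reads source element (i0/scale_factor, i1/scale_factor)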
// TODO: optimize
for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t i03 = i3;
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
const int64_t i02 = i2;
for (int64_t i1 = 0; i1 < ne1; i1++) {
const int64_t i01 = i1 / scale_factor;
for (int64_t i0 = 0; i0 < ne0; i0++) {
const int64_t i00 = i0 / scale_factor;
const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
*y = *x;
}
}
}
}
}
static void ggml_v3_compute_forward_upscale(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_upscale_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_pad
static void ggml_v3_compute_forward_pad_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_ASSERT(src0->nb[0] == sizeof(float));
GGML_V3_ASSERT( dst->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_V3_TENSOR_UNARY_OP_LOCALS
float * dst_ptr = (float *) dst->data;
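    // dst is written as a dense [ne0, ne1, ne2, ne3] array: positions inside the src extent are
    // copied, the padded remainder is zero-filled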
// TODO: optimize
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;
}
}
}
}
}
}
static void ggml_v3_compute_forward_pad(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_pad_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_argsort
static void ggml_v3_compute_forward_argsort_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_TENSOR_UNARY_OP_LOCALS
GGML_V3_ASSERT(nb0 == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ggml_v3_nrows(src0);
enum ggml_v3_sort_order order = (enum ggml_v3_sort_order) ggml_v3_get_op_params_i32(dst, 0);
for (int64_t i = ith; i < nr; i += nth) {
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
const float * src_data = (float *)((char *) src0->data + i*nb01);
for (int64_t j = 0; j < ne0; j++) {
dst_data[j] = j;
}
        // plain C has no closure-style sort that could capture src_data, so do a simple O(ne0^2) exchange sort instead
for (int64_t j = 0; j < ne0; j++) {
for (int64_t k = j + 1; k < ne0; k++) {
if ((order == GGML_V3_SORT_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
(order == GGML_V3_SORT_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
int32_t tmp = dst_data[j];
dst_data[j] = dst_data[k];
dst_data[k] = tmp;
}
}
}
}
}
static void ggml_v3_compute_forward_argsort(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_argsort_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_flash_attn
static void ggml_v3_compute_forward_flash_attn_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * q,
const struct ggml_v3_tensor * k,
const struct ggml_v3_tensor * v,
const bool masked,
struct ggml_v3_tensor * dst) {
int64_t t0 = ggml_v3_perf_time_us();
UNUSED(t0);
GGML_V3_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_V3_TENSOR_LOCALS(int64_t, nek, k, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbk, k, nb)
GGML_V3_TENSOR_LOCALS(int64_t, nev, v, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbv, v, nb)
GGML_V3_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_V3_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
const int64_t D = neq0;
const int64_t N = neq1;
const int64_t P = nek1 - N;
const int64_t M = P + N;
const int Mup = ggml_v3_up(M, GGML_V3_SOFT_MAX_UNROLL);
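    // M rounded up to a multiple of the softmax unroll factor; the padded tail is filled with -INF
    // below so it cannot affect the softmax result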
GGML_V3_ASSERT(ne0 == D);
GGML_V3_ASSERT(ne1 == N);
GGML_V3_ASSERT(P >= 0);
GGML_V3_ASSERT(nbq0 == sizeof(float));
GGML_V3_ASSERT(nbk0 == sizeof(float));
GGML_V3_ASSERT(nbv0 == sizeof(float));
GGML_V3_ASSERT(neq0 == D);
GGML_V3_ASSERT(nek0 == D);
GGML_V3_ASSERT(nev1 == D);
GGML_V3_ASSERT(neq1 == N);
GGML_V3_ASSERT(nek1 == N + P);
GGML_V3_ASSERT(nev1 == D);
// dst cannot be transposed or permuted
GGML_V3_ASSERT(nb0 == sizeof(float));
GGML_V3_ASSERT(nb0 <= nb1);
GGML_V3_ASSERT(nb1 <= nb2);
GGML_V3_ASSERT(nb2 <= nb3);
if (params->type == GGML_V3_TASK_INIT) {
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// parallelize by q rows using ggml_v3_vec_dot_f32
// total rows in q
const int nr = neq1*neq2*neq3;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
const float scale = 1.0f/sqrtf(D);
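    // standard scaled dot-product attention factor: 1/sqrt(head_dim)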
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
const int iq3 = ir/(neq2*neq1);
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
for (int i = M; i < Mup; ++i) {
S[i] = -INFINITY;
}
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
for (int64_t ic = 0; ic < masked_begin; ++ic) {
// k indices
const int ik3 = iq3;
const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
const int i1 = ik1;
ggml_v3_vec_dot_f32(neq0,
S + i1,
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
}
// scale
ggml_v3_vec_scale_f32(masked_begin, S, scale);
for (int64_t i = masked_begin; i < M; i++) {
S[i] = -INFINITY;
}
// softmax
// exclude known -INF S[..] values from max and loop
        // don't forget to set their S values to zero
{
float max = -INFINITY;
ggml_v3_vec_max_f32(masked_begin, &max, S);
ggml_v3_float sum = 0.0;
{
#ifdef GGML_V3_SOFT_MAX_ACCELERATE
max = -max;
vDSP_vsadd(S, 1, &max, S, 1, Mup);
vvexpf(S, S, &Mup);
ggml_v3_vec_sum_f32(Mup, &sum, S);
#else
uint16_t scvt[GGML_V3_SOFT_MAX_UNROLL]; UNUSED(scvt);
ggml_v3_float sump[GGML_V3_SOFT_MAX_UNROLL] = { 0.0 };
for (int i = 0; i < Mup; i += GGML_V3_SOFT_MAX_UNROLL) {
if (i >= masked_begin) {
break;
}
float * SS = S + i;
for (int j = 0; j < GGML_V3_SOFT_MAX_UNROLL; ++j) {
if (i + j >= masked_begin) {
break;
} else if (SS[j] == -INFINITY) {
SS[j] = 0.0f;
} else {
#ifndef GGML_V3_FLASH_ATTN_EXP_FP16
const float val = expf(SS[j] - max);
#else
ggml_v3_fp16_t s = GGML_V3_FP32_TO_FP16(SS[j] - max);
memcpy(&scvt[j], &s, sizeof(uint16_t));
const float val = GGML_V3_FP16_TO_FP32(ggml_v3_table_exp_f16[scvt[j]]);
#endif
sump[j] += (ggml_v3_float)val;
SS[j] = val;
}
}
}
for (int i = 0; i < GGML_V3_SOFT_MAX_UNROLL; i++) {
sum += sump[i];
}
#endif
}
assert(sum > 0.0);
sum = 1.0/sum;
ggml_v3_vec_scale_f32(masked_begin, S, sum);
#ifndef NDEBUG
for (int i = 0; i < masked_begin; ++i) {
assert(!isnan(S[i]));
assert(!isinf(S[i]));
}
#endif
}
for (int64_t ic = 0; ic < nev1; ++ic) {
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// v indices
const int iv2 = iq2 % nev2;
const int iv3 = iq3;
ggml_v3_vec_dot_f32(masked_begin,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S);
}
}
}
static void ggml_v3_compute_forward_flash_attn_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * q,
const struct ggml_v3_tensor * k,
const struct ggml_v3_tensor * v,
const bool masked,
struct ggml_v3_tensor * dst) {
int64_t t0 = ggml_v3_perf_time_us();
UNUSED(t0);
GGML_V3_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_V3_TENSOR_LOCALS(int64_t, nek, k, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbk, k, nb)
GGML_V3_TENSOR_LOCALS(int64_t, nev, v, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbv, v, nb)
GGML_V3_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_V3_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
const int64_t D = neq0;
const int64_t N = neq1;
const int64_t P = nek1 - N;
const int64_t M = P + N;
const int Mup = ggml_v3_up(M, GGML_V3_SOFT_MAX_UNROLL);
GGML_V3_ASSERT(ne0 == D);
GGML_V3_ASSERT(ne1 == N);
GGML_V3_ASSERT(P >= 0);
GGML_V3_ASSERT(nbq0 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nbk0 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nbv0 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(neq0 == D);
GGML_V3_ASSERT(nek0 == D);
GGML_V3_ASSERT(nev1 == D);
GGML_V3_ASSERT(neq1 == N);
GGML_V3_ASSERT(nek1 == N + P);
GGML_V3_ASSERT(nev1 == D);
// dst cannot be transposed or permuted
GGML_V3_ASSERT(nb0 == sizeof(float));
GGML_V3_ASSERT(nb0 <= nb1);
GGML_V3_ASSERT(nb1 <= nb2);
GGML_V3_ASSERT(nb2 <= nb3);
if (params->type == GGML_V3_TASK_INIT) {
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
    // parallelize by q rows using ggml_v3_vec_dot_f16
// total rows in q
const int nr = neq1*neq2*neq3;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
const float scale = 1.0f/sqrtf(D);
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
const int iq3 = ir/(neq2*neq1);
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
for (int i = M; i < Mup; ++i) {
S[i] = -INFINITY;
}
if (GGML_V3_VEC_DOT_UNROLL > 2 || nek1 % GGML_V3_VEC_DOT_UNROLL != 0) {
for (int64_t ic = 0; ic < nek1; ++ic) {
// k indices
const int ik3 = iq3;
const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
const int i1 = ik1;
ggml_v3_vec_dot_f16(neq0,
S + i1,
(ggml_v3_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
(ggml_v3_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
}
} else {
for (int64_t ic = 0; ic < nek1; ic += GGML_V3_VEC_DOT_UNROLL) {
// k indices
const int ik3 = iq3;
const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
const int i1 = ik1;
ggml_v3_vec_dot_f16_unroll(neq0, nbk1,
S + i1,
((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
(ggml_v3_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
}
}
// scale
ggml_v3_vec_scale_f32(nek1, S, scale);
if (masked) {
for (int64_t i = P; i < M; i++) {
if (i > P + iq1) {
S[i] = -INFINITY;
}
}
}
// softmax
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
        // don't forget to set their S values to zero
{
float max = -INFINITY;
ggml_v3_vec_max_f32(M, &max, S);
ggml_v3_float sum = 0.0;
{
#ifdef GGML_V3_SOFT_MAX_ACCELERATE
max = -max;
vDSP_vsadd(S, 1, &max, S, 1, Mup);
vvexpf(S, S, &Mup);
ggml_v3_vec_sum_f32(Mup, &sum, S);
#else
uint16_t scvt[GGML_V3_SOFT_MAX_UNROLL];
ggml_v3_float sump[GGML_V3_SOFT_MAX_UNROLL] = { 0.0 };
for (int i = 0; i < Mup; i += GGML_V3_SOFT_MAX_UNROLL) {
float * SS = S + i;
for (int j = 0; j < GGML_V3_SOFT_MAX_UNROLL; ++j) {
if (SS[j] == -INFINITY) {
SS[j] = 0.0f;
} else {
ggml_v3_fp16_t s = GGML_V3_FP32_TO_FP16(SS[j] - max);
memcpy(&scvt[j], &s, sizeof(uint16_t));
const float val = GGML_V3_FP16_TO_FP32(ggml_v3_table_exp_f16[scvt[j]]);
sump[j] += (ggml_v3_float)val;
SS[j] = val;
}
}
}
for (int i = 0; i < GGML_V3_SOFT_MAX_UNROLL; i++) {
sum += sump[i];
}
#endif
}
assert(sum > 0.0);
sum = 1.0/sum;
ggml_v3_vec_scale_f32(M, S, sum);
#ifndef NDEBUG
for (int i = 0; i < M; ++i) {
assert(!isnan(S[i]));
assert(!isinf(S[i]));
}
#endif
}
ggml_v3_fp16_t * S16 = (ggml_v3_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
for (int64_t i = 0; i < M; i++) {
S16[i] = GGML_V3_FP32_TO_FP16(S[i]);
}
// todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
if (GGML_V3_VEC_DOT_UNROLL == 1 || (nev1 % GGML_V3_VEC_DOT_UNROLL != 0)) {
for (int64_t ic = 0; ic < nev1; ++ic) {
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// v indices
const int iv2 = iq2 % nev2;
const int iv3 = iq3;
ggml_v3_vec_dot_f16(nev0,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
(ggml_v3_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S16);
}
} else {
for (int64_t ic = 0; ic < nev1; ic += GGML_V3_VEC_DOT_UNROLL) {
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// v indices
const int iv2 = iq2 % nev2;
const int iv3 = iq3;
ggml_v3_vec_dot_f16_unroll(nev0, nbv1,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S16);
}
}
}
}
static void ggml_v3_compute_forward_flash_attn(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * q,
const struct ggml_v3_tensor * k,
const struct ggml_v3_tensor * v,
const bool masked,
struct ggml_v3_tensor * dst) {
switch (q->type) {
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_flash_attn_f16(params, q, k, v, masked, dst);
} break;
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_flash_ff
static void ggml_v3_compute_forward_flash_ff_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * a, // F16
const struct ggml_v3_tensor * b0, // F16 fc_w
const struct ggml_v3_tensor * b1, // F32 fc_b
const struct ggml_v3_tensor * c0, // F16 proj_w
const struct ggml_v3_tensor * c1, // F32 proj_b
struct ggml_v3_tensor * dst) {
int64_t t0 = ggml_v3_perf_time_us();
UNUSED(t0);
GGML_V3_TENSOR_LOCALS(int64_t, nea, a, ne)
GGML_V3_TENSOR_LOCALS(size_t, nba, a, nb)
GGML_V3_TENSOR_LOCALS(int64_t, neb0, b0, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbb0, b0, nb)
GGML_V3_TENSOR_LOCALS(int64_t, neb1, b1, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbb1, b1, nb)
GGML_V3_TENSOR_LOCALS(int64_t, nec0, c0, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbc0, c0, nb)
GGML_V3_TENSOR_LOCALS(int64_t, nec1, c1, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbc1, c1, nb)
GGML_V3_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_V3_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
const int64_t D = nea0;
//const int64_t N = nea1;
const int64_t M = neb01;
GGML_V3_ASSERT(ne0 == nea0);
GGML_V3_ASSERT(ne1 == nea1);
GGML_V3_ASSERT(ne2 == nea2);
GGML_V3_ASSERT(nba0 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nbb00 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nbb10 == sizeof(float));
GGML_V3_ASSERT(nbc00 == sizeof(ggml_v3_fp16_t));
GGML_V3_ASSERT(nbc10 == sizeof(float));
GGML_V3_ASSERT(neb00 == D);
GGML_V3_ASSERT(neb01 == M);
GGML_V3_ASSERT(neb10 == M);
GGML_V3_ASSERT(neb11 == 1);
GGML_V3_ASSERT(nec00 == M);
GGML_V3_ASSERT(nec01 == D);
GGML_V3_ASSERT(nec10 == D);
GGML_V3_ASSERT(nec11 == 1);
// dst cannot be transposed or permuted
GGML_V3_ASSERT(nb0 == sizeof(float));
GGML_V3_ASSERT(nb0 <= nb1);
GGML_V3_ASSERT(nb1 <= nb2);
GGML_V3_ASSERT(nb2 <= nb3);
if (params->type == GGML_V3_TASK_INIT) {
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
    // parallelize by a rows using ggml_v3_vec_dot_f16
// total rows in a
const int nr = nea1*nea2*nea3;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
// a indices
const int ia3 = ir/(nea2*nea1);
const int ia2 = (ir - ia3*nea2*nea1)/nea1;
const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
for (int64_t ic = 0; ic < neb01; ++ic) {
// b0 indices
const int ib03 = ia3;
const int ib02 = ia2;
const int ib01 = ic;
// S indices
const int i1 = ib01;
ggml_v3_vec_dot_f16(nea0,
S + i1,
(ggml_v3_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
(ggml_v3_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)));
}
ggml_v3_vec_add_f32(neb01, S, S, (float *) b1->data);
//ggml_v3_vec_gelu_f32(neb01, S, S);
ggml_v3_fp16_t * S16 = (ggml_v3_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
for (int64_t i = 0; i < M; i++) {
S16[i] = GGML_V3_FP32_TO_FP16(S[i]);
}
ggml_v3_vec_gelu_f16(neb01, S16, S16);
{
// dst indices
const int i1 = ia1;
const int i2 = ia2;
const int i3 = ia3;
for (int64_t ic = 0; ic < nec01; ++ic) {
ggml_v3_vec_dot_f16(neb01,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
(ggml_v3_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)),
S16);
}
ggml_v3_vec_add_f32(nec01,
(float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
(float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
(float *) c1->data);
}
}
}
static void ggml_v3_compute_forward_flash_ff(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * a,
const struct ggml_v3_tensor * b0,
const struct ggml_v3_tensor * b1,
const struct ggml_v3_tensor * c0,
const struct ggml_v3_tensor * c1,
struct ggml_v3_tensor * dst) {
switch (b0->type) {
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst);
} break;
case GGML_V3_TYPE_F32:
{
GGML_V3_ASSERT(false); // TODO
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_flash_attn_back
static void ggml_v3_compute_forward_flash_attn_back_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * q,
const struct ggml_v3_tensor * k,
const struct ggml_v3_tensor * v,
const struct ggml_v3_tensor * d,
const bool masked,
struct ggml_v3_tensor * dst) {
int64_t t0 = ggml_v3_perf_time_us();
UNUSED(t0);
GGML_V3_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_V3_TENSOR_LOCALS(int64_t, nek, k, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbk, k, nb)
GGML_V3_TENSOR_LOCALS(int64_t, nev, v, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbv, v, nb)
GGML_V3_TENSOR_LOCALS(int64_t, ned, d, ne)
GGML_V3_TENSOR_LOCALS(size_t, nbd, d, nb)
GGML_V3_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_V3_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
const int64_t D = neq0;
const int64_t N = neq1;
const int64_t P = nek1 - N;
const int64_t M = P + N;
const int Mup = ggml_v3_up(M, GGML_V3_SOFT_MAX_UNROLL);
const int mxDM = MAX(D, Mup);
// GGML_V3_ASSERT(ne0 == D);
// GGML_V3_ASSERT(ne1 == N);
GGML_V3_ASSERT(P >= 0);
GGML_V3_ASSERT(nbq0 == sizeof(float));
GGML_V3_ASSERT(nbk0 == sizeof(float));
GGML_V3_ASSERT(nbv0 == sizeof(float));
GGML_V3_ASSERT(neq0 == D);
GGML_V3_ASSERT(nek0 == D);
GGML_V3_ASSERT(nev1 == D);
GGML_V3_ASSERT(ned0 == D);
GGML_V3_ASSERT(neq1 == N);
GGML_V3_ASSERT(nek1 == N + P);
GGML_V3_ASSERT(nev1 == D);
GGML_V3_ASSERT(ned1 == N);
// dst cannot be transposed or permuted
GGML_V3_ASSERT(nb0 == sizeof(float));
GGML_V3_ASSERT(nb0 <= nb1);
GGML_V3_ASSERT(nb1 <= nb2);
GGML_V3_ASSERT(nb2 <= nb3);
if (params->type == GGML_V3_TASK_INIT) {
if (ith == 0) {
memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
}
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int64_t elem_q = ggml_v3_nelements(q);
const int64_t elem_k = ggml_v3_nelements(k);
enum ggml_v3_type result_type = dst->type;
GGML_V3_ASSERT(ggml_v3_blck_size(result_type) == 1);
const size_t tsize = ggml_v3_type_size(result_type);
const size_t offs_q = 0;
const size_t offs_k = offs_q + GGML_V3_PAD(elem_q * tsize, GGML_V3_MEM_ALIGN);
const size_t offs_v = offs_k + GGML_V3_PAD(elem_k * tsize, GGML_V3_MEM_ALIGN);
void * grad_q = (char *) dst->data;
void * grad_k = (char *) dst->data + offs_k;
void * grad_v = (char *) dst->data + offs_v;
const size_t nbgq1 = nb0*neq0;
const size_t nbgq2 = nb0*neq0*neq1;
const size_t nbgq3 = nb0*neq0*neq1*neq2;
const size_t nbgk1 = nb0*nek0;
const size_t nbgk2 = nb0*nek0*nek1;
const size_t nbgk3 = nb0*nek0*nek1*neq2;
const size_t nbgv1 = nb0*nev0;
const size_t nbgv2 = nb0*nev0*nev1;
const size_t nbgv3 = nb0*nev0*nev1*neq2;
// parallelize by k rows using ggml_v3_vec_dot_f32
// total rows in k
const int nr = nek2*nek3;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
const float scale = 1.0f/sqrtf(D);
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
// how often k2 (and v2) is repeated in q2
int nrep = neq2/nek2;
for (int ir = ir0; ir < ir1; ++ir) {
        // k indices (the matching q/d/v indices are derived from them below)
const int ik3 = ir/(nek2);
const int ik2 = ir - ik3*nek2;
const int iq3 = ik3;
const int id3 = ik3;
const int iv3 = ik3;
const int iv2 = ik2;
for (int irep = 0; irep < nrep; ++irep) {
const int iq2 = ik2 + irep*nek2;
const int id2 = iq2;
// (ik2 + irep*nek2) % nek2 == ik2
for (int iq1 = 0; iq1 < neq1; ++iq1) {
const int id1 = iq1;
                // not sure about CACHE_LINE_SIZE_F32..
                // - maybe it should not be multiplied by 2, and left out of the 1*(..) offset used for SM?
float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
for (int i = M; i < Mup; ++i) {
S[i] = -INFINITY;
}
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
for (int64_t ic = 0; ic < masked_begin; ++ic) {
// k indices
const int ik1 = ic;
// S indices
const int i1 = ik1;
ggml_v3_vec_dot_f32(neq0,
S + i1,
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
}
// scale
ggml_v3_vec_scale_f32(masked_begin, S, scale);
for (int64_t i = masked_begin; i < M; i++) {
S[i] = -INFINITY;
}
// softmax
// exclude known -INF S[..] values from max and loop
                // don't forget to set their SM values to zero
{
float max = -INFINITY;
ggml_v3_vec_max_f32(masked_begin, &max, S);
ggml_v3_float sum = 0.0;
{
#ifdef GGML_V3_SOFT_MAX_ACCELERATE
max = -max;
vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
vvexpf(SM, SM, &Mup);
ggml_v3_vec_sum_f32(Mup, &sum, SM);
#else
uint16_t scvt[GGML_V3_SOFT_MAX_UNROLL]; UNUSED(scvt);
ggml_v3_float sump[GGML_V3_SOFT_MAX_UNROLL] = { 0.0 };
for (int i = 0; i < Mup; i += GGML_V3_SOFT_MAX_UNROLL) {
if (i >= masked_begin) {
break;
}
float * SR = S + i;
float * SW = SM + i;
for (int j = 0; j < GGML_V3_SOFT_MAX_UNROLL; ++j) {
if (i + j >= masked_begin) {
break;
} else if (SR[j] == -INFINITY) {
SW[j] = 0.0f;
} else {
#ifndef GGML_V3_FLASH_ATTN_EXP_FP16
const float val = expf(SR[j] - max);
#else
ggml_v3_fp16_t s = GGML_V3_FP32_TO_FP16(SR[j] - max);
memcpy(&scvt[j], &s, sizeof(uint16_t));
const float val = GGML_V3_FP16_TO_FP32(ggml_v3_table_exp_f16[scvt[j]]);
#endif
sump[j] += (ggml_v3_float)val;
SW[j] = val;
}
}
}
for (int i = 0; i < GGML_V3_SOFT_MAX_UNROLL; i++) {
sum += sump[i];
}
#endif
}
assert(sum > 0.0);
sum = 1.0/sum;
ggml_v3_vec_scale_f32(masked_begin, SM, sum);
}
// step-by-step explanation
{
// forward-process shape grads from backward process
// parallel_for ik2,ik3:
// for irep:
// iq2 = ik2 + irep*nek2
// k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
// q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
// v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
// for iq1:
// kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
// qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
// vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
// S0 = -Inf [D,1,1,1]
// ~S1[i] = dot(kcur[:D,i], qcur)
// S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
// S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
// ~S5[i] = dot(vcur[:,i], S4)
// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
// dst backward-/ grad[dst] = d
//
// output gradients with their dependencies:
//
// grad[kcur] = grad[S1].T @ qcur
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
// grad[S4] = grad[S5] @ vcur
// grad[S4] = d[:D,id1,id2,id3] @ vcur
// grad[qcur] = grad[S1] @ kcur
// grad[vcur] = grad[S5].T @ S4
// grad[vcur] = d[:D,id1,id2,id3].T @ S4
//
// in post-order:
//
// S1 = qcur @ kcur.T
// S2 = S1 * scale
// S3 = diag_mask_inf(S2, P)
// S4 = softmax(S3)
// grad[S4] = d[:D,id1,id2,id3] @ vcur
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
// grad[qcur] = grad[S1] @ kcur
// grad[kcur] = grad[S1].T @ qcur
// grad[vcur] = d[:D,id1,id2,id3].T @ S4
//
// using less variables (SM=S4):
//
// S = diag_mask_inf(qcur @ kcur.T * scale, P)
// SM = softmax(S)
// S = d[:D,iq1,iq2,iq3] @ vcur
// dot_SM_gradSM = dot(SM, S)
// S = SM * (S - dot(SM, S))
// S = diag_mask_zero(S, P) * scale
//
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
// grad[k][:D,:M,ik2,ik3] += S.T @ qcur
// grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
}
// S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
// S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
// for ic:
// S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
// exclude known future zero S[..] values from operation
ggml_v3_vec_set_f32(masked_begin, S, 0);
for (int64_t ic = 0; ic < D; ++ic) {
ggml_v3_vec_mad_f32(masked_begin,
S,
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
}
// S = SM * (S - dot(SM, S))
float dot_SM_gradSM = 0;
ggml_v3_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
ggml_v3_vec_acc1_f32(M, S, -dot_SM_gradSM);
ggml_v3_vec_mul_f32 (masked_begin, S, S, SM);
// S = diag_mask_zero(S, P) * scale
// already done by above ggml_v3_vec_set_f32
// exclude known zero S[..] values from operation
ggml_v3_vec_scale_f32(masked_begin, S, scale);
// S shape [M,1]
// SM shape [M,1]
// kcur shape [D,M]
// qcur shape [D,1]
// vcur shape [M,D]
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
// for ic:
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
// exclude known zero S[..] values from loop
for (int64_t ic = 0; ic < masked_begin; ++ic) {
ggml_v3_vec_mad_f32(D,
(float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
(float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
S[ic]);
}
                // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
// for ic:
// grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
// grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
// exclude known zero S[..] values from loop
for (int64_t ic = 0; ic < masked_begin; ++ic) {
ggml_v3_vec_mad_f32(D,
(float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
S[ic]);
}
// grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
// for ic:
// grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
// grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
// exclude known zero SM[..] values from mad
for (int64_t ic = 0; ic < D; ++ic) {
ggml_v3_vec_mad_f32(masked_begin,
(float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
SM,
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
}
}
}
}
}
static void ggml_v3_compute_forward_flash_attn_back(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * q,
const struct ggml_v3_tensor * k,
const struct ggml_v3_tensor * v,
const struct ggml_v3_tensor * d,
const bool masked,
struct ggml_v3_tensor * dst) {
switch (q->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_win_part
static void ggml_v3_compute_forward_win_part_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_V3_TENSOR_LOCALS(int64_t, ne, dst, ne)
const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
const int32_t w = ((const int32_t *)(dst->op_params))[2];
assert(ne00 == ne0);
assert(ne3 == nep0*nep1);
// TODO: optimize / multi-thread
for (int py = 0; py < nep1; ++py) {
for (int px = 0; px < nep0; ++px) {
const int64_t i3 = py*nep0 + px;
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int64_t i02 = py*w + i2;
const int64_t i01 = px*w + i1;
const int64_t i00 = i0;
const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;
if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
((float *) dst->data)[i] = 0.0f;
} else {
((float *) dst->data)[i] = ((float *) src0->data)[j];
}
}
}
}
}
}
}
static void ggml_v3_compute_forward_win_part(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_win_part_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_win_unpart
static void ggml_v3_compute_forward_win_unpart_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
GGML_V3_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_V3_TENSOR_LOCALS(int64_t, ne, dst, ne)
const int32_t w = ((const int32_t *)(dst->op_params))[0];
// padding
const int px = (w - ne1%w)%w;
//const int py = (w - ne2%w)%w;
const int npx = (px + ne1)/w;
//const int npy = (py + ne2)/w;
assert(ne0 == ne00);
// TODO: optimize / multi-thread
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
const int ip2 = i2/w;
const int ip1 = i1/w;
const int64_t i02 = i2%w;
const int64_t i01 = i1%w;
const int64_t i00 = i0;
const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
((float *) dst->data)[j] = ((float *) src0->data)[i];
}
}
}
}
static void ggml_v3_compute_forward_win_unpart(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_win_unpart_f32(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_unary
static void ggml_v3_compute_forward_unary(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
const enum ggml_v3_unary_op op = ggml_v3_get_unary_op(dst);
switch (op) {
case GGML_V3_UNARY_OP_ABS:
{
ggml_v3_compute_forward_abs(params, src0, dst);
} break;
case GGML_V3_UNARY_OP_SGN:
{
ggml_v3_compute_forward_sgn(params, src0, dst);
} break;
case GGML_V3_UNARY_OP_NEG:
{
ggml_v3_compute_forward_neg(params, src0, dst);
} break;
case GGML_V3_UNARY_OP_STEP:
{
ggml_v3_compute_forward_step(params, src0, dst);
} break;
case GGML_V3_UNARY_OP_TANH:
{
ggml_v3_compute_forward_tanh(params, src0, dst);
} break;
case GGML_V3_UNARY_OP_ELU:
{
ggml_v3_compute_forward_elu(params, src0, dst);
} break;
case GGML_V3_UNARY_OP_RELU:
{
ggml_v3_compute_forward_relu(params, src0, dst);
} break;
case GGML_V3_UNARY_OP_GELU:
{
ggml_v3_compute_forward_gelu(params, src0, dst);
} break;
case GGML_V3_UNARY_OP_GELU_QUICK:
{
ggml_v3_compute_forward_gelu_quick(params, src0, dst);
} break;
case GGML_V3_UNARY_OP_SILU:
{
ggml_v3_compute_forward_silu(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_get_rel_pos
static void ggml_v3_compute_forward_get_rel_pos_f16(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
GGML_V3_TENSOR_UNARY_OP_LOCALS
const int64_t w = ne1;
ggml_v3_fp16_t * src0_data = (ggml_v3_fp16_t *) src0->data;
ggml_v3_fp16_t * dst_data = (ggml_v3_fp16_t *) dst->data;
for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = 0; i1 < ne1; ++i1) {
const int64_t pos = (w - i1 - 1) + i2;
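            // relative-position lookup: pos == (w - 1) + (i2 - i1), selecting the table row for
            // index pair (i2, i1); the identity follows directly from the line above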
for (int64_t i0 = 0; i0 < ne0; ++i0) {
dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
}
}
}
}
static void ggml_v3_compute_forward_get_rel_pos(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F16:
{
ggml_v3_compute_forward_get_rel_pos_f16(params, src0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_add_rel_pos
static void ggml_v3_compute_forward_add_rel_pos_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
const struct ggml_v3_tensor * src2,
struct ggml_v3_tensor * dst) {
const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
if (!inplace && params->type == GGML_V3_TASK_INIT) {
memcpy((char *) dst->data, (char *) src0->data, ggml_v3_nbytes(dst));
return;
}
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
int64_t t0 = ggml_v3_perf_time_us();
UNUSED(t0);
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
float * src1_data = (float *) src1->data;
float * src2_data = (float *) src2->data;
float * dst_data = (float *) dst->data;
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];
const int ith = params->ith;
const int nth = params->nth;
// total patches in dst
const int np = ne13;
// patches per thread
const int dp = (np + nth - 1)/nth;
// patch range for this thread
const int ip0 = dp*ith;
const int ip1 = MIN(ip0 + dp, np);
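// each thread handles a contiguous slice [ip0, ip1) of the ne13 outer "patches";
// for every (i13, i12, i11, i10) element of src1/src2, dst is treated as having an
// extra innermost dimension of size ne10: the src2 term is added over a contiguous
// run of ne10 elements starting at jdh, while the src1 term is added with stride
// ne10 starting at jdw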
for (int64_t i13 = ip0; i13 < ip1; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
for (int64_t i10 = 0; i10 < ne10; ++i10) {
const int64_t jp0 = jp1 + i10;
const float src1_e = src1_data[jp0];
const float src2_e = src2_data[jp0];
const int64_t jdh = jp0 * ne10;
const int64_t jdw = jdh - (ne10 - 1) * i10;
for (int64_t j = 0; j < ne10; ++j) {
dst_data[jdh + j ] += src2_e;
dst_data[jdw + j*ne10] += src1_e;
}
}
}
}
}
}
static void ggml_v3_compute_forward_add_rel_pos(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
const struct ggml_v3_tensor * src2,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_add_rel_pos_f32(params, src0, src1, src2, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_map_unary
static void ggml_v3_compute_forward_map_unary_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst,
const ggml_v3_unary_op_f32_t fun) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
assert( dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
fun(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_v3_compute_forward_map_unary(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
struct ggml_v3_tensor * dst,
const ggml_v3_unary_op_f32_t fun) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_map_unary_f32(params, src0, dst, fun);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_map_binary
static void ggml_v3_compute_forward_map_binary_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst,
const ggml_v3_binary_op_f32_t fun) {
assert(params->ith == 0);
assert(ggml_v3_are_same_shape(src0, src1) && ggml_v3_are_same_shape(src0, dst));
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const int n = ggml_v3_nrows(src0);
const int nc = src0->ne[0];
assert( dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
assert(src1->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
fun(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])),
(float *) ((char *) src1->data + i*(src1->nb[1])));
}
}
static void ggml_v3_compute_forward_map_binary(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst,
const ggml_v3_binary_op_f32_t fun) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_map_custom1
static void ggml_v3_compute_forward_map_custom1_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * a,
struct ggml_v3_tensor * dst,
const ggml_v3_custom1_op_f32_t fun) {
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
fun(dst, a);
}
// ggml_v3_compute_forward_map_custom2
static void ggml_v3_compute_forward_map_custom2_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * a,
const struct ggml_v3_tensor * b,
struct ggml_v3_tensor * dst,
const ggml_v3_custom2_op_f32_t fun) {
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
fun(dst, a, b);
}
// ggml_v3_compute_forward_map_custom3
static void ggml_v3_compute_forward_map_custom3_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * a,
const struct ggml_v3_tensor * b,
const struct ggml_v3_tensor * c,
struct ggml_v3_tensor * dst,
const ggml_v3_custom3_op_f32_t fun) {
assert(params->ith == 0);
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
fun(dst, a, b, c);
}
// ggml_v3_compute_forward_map_custom1
static void ggml_v3_compute_forward_map_custom1(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * a,
struct ggml_v3_tensor * dst) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
struct ggml_v3_map_custom1_op_params * p = (struct ggml_v3_map_custom1_op_params *) dst->op_params;
p->fun(dst, a, params->ith, params->nth, p->userdata);
}
// ggml_v3_compute_forward_map_custom2
static void ggml_v3_compute_forward_map_custom2(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * a,
const struct ggml_v3_tensor * b,
struct ggml_v3_tensor * dst) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
struct ggml_v3_map_custom2_op_params * p = (struct ggml_v3_map_custom2_op_params *) dst->op_params;
p->fun(dst, a, b, params->ith, params->nth, p->userdata);
}
// ggml_v3_compute_forward_map_custom3
static void ggml_v3_compute_forward_map_custom3(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * a,
const struct ggml_v3_tensor * b,
const struct ggml_v3_tensor * c,
struct ggml_v3_tensor * dst) {
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
struct ggml_v3_map_custom3_op_params * p = (struct ggml_v3_map_custom3_op_params *) dst->op_params;
p->fun(dst, a, b, c, params->ith, params->nth, p->userdata);
}
// ggml_v3_compute_forward_cross_entropy_loss
static void ggml_v3_compute_forward_cross_entropy_loss_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(src0));
GGML_V3_ASSERT(ggml_v3_is_contiguous(src1));
GGML_V3_ASSERT(ggml_v3_is_scalar(dst));
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, src1));
const int ith = params->ith;
const int nth = params->nth;
float * sums = (float *) params->wdata;
// TODO: handle transposed/permuted matrices
const int nc = src0->ne[0];
const int nr = ggml_v3_nrows(src0);
GGML_V3_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
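// wdata layout: the first nth floats hold one partial sum per thread (reduced in the
// FINALIZE branch below), followed by nth*nc floats of per-thread softmax scratch rows
// (st = wdata + nth + ith*nc)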
if (params->type == GGML_V3_TASK_INIT) {
if (ith == 0) {
memset(sums, 0, sizeof(float) * (nth + nth * nc));
}
return;
}
if (params->type == GGML_V3_TASK_FINALIZE) {
if (ith == 0) {
float * dp = (float *) dst->data;
ggml_v3_vec_sum_f32(nth, dp, sums);
dp[0] *= -1.0f / (float) nr;
}
return;
}
const double eps = 1e-9;
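// per row: loss_row = sum_c s1[c] * log((1 - eps) * softmax(s0)[c] + eps);
// rows are accumulated into sums[ith] and scaled by -1/nr in the FINALIZE branch above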
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
float * st = ((float *) params->wdata) + nth + ith*nc;
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(s0[i]));
assert(!isnan(s1[i]));
}
#endif
// soft_max
ggml_v3_float sum = 0.0;
{
float max = -INFINITY;
ggml_v3_vec_max_f32(nc, &max, s0);
uint16_t scvt; UNUSED(scvt);
for (int i = 0; i < nc; i++) {
if (s0[i] == -INFINITY) {
st[i] = 0.0f;
} else {
#ifndef GGML_V3_CROSS_ENTROPY_EXP_FP16
const float s = s0[i] - max;
const float val = expf(s);
#else
ggml_v3_fp16_t s = GGML_V3_FP32_TO_FP16(s0[i] - max);
memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_V3_FP16_TO_FP32(ggml_v3_table_exp_f16[scvt]);
#endif
sum += (ggml_v3_float)val;
st[i] = val;
}
}
assert(sum > 0.0);
// sum = 1.0/sum;
}
// avoid log(0) by rescaling from [0..1] to [eps..1]
sum = (1.0 - eps) / sum;
ggml_v3_vec_scale_f32(nc, st, sum);
ggml_v3_vec_add1_f32(nc, st, st, eps);
ggml_v3_vec_log_f32(nc, st, st);
ggml_v3_vec_mul_f32(nc, st, st, s1);
float st_sum = 0;
ggml_v3_vec_sum_f32(nc, &st_sum, st);
sums[ith] += st_sum;
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
assert(!isnan(st[i]));
assert(!isinf(st[i]));
}
#endif
}
}
static void ggml_v3_compute_forward_cross_entropy_loss(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
// ggml_v3_compute_forward_cross_entropy_loss_back
static void ggml_v3_compute_forward_cross_entropy_loss_back_f32(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
const struct ggml_v3_tensor * opt0,
struct ggml_v3_tensor * dst) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(dst));
GGML_V3_ASSERT(ggml_v3_is_contiguous(src0));
GGML_V3_ASSERT(ggml_v3_is_contiguous(src1));
GGML_V3_ASSERT(ggml_v3_is_contiguous(opt0));
GGML_V3_ASSERT(ggml_v3_are_same_shape(src0, src1) && ggml_v3_are_same_shape(src0, dst));
const int64_t ith = params->ith;
const int64_t nth = params->nth;
if (params->type == GGML_V3_TASK_INIT || params->type == GGML_V3_TASK_FINALIZE) {
return;
}
const double eps = 1e-9;
// TODO: handle transposed/permuted matrices
const int64_t nc = src0->ne[0];
const int64_t nr = ggml_v3_nrows(src0);
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
// row range for this thread
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
float * d = (float *) opt0->data;
for (int64_t i1 = ir0; i1 < ir1; i1++) {
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(s0[i]));
assert(!isnan(s1[i]));
}
#endif
// soft_max
ggml_v3_float sum = 0.0;
{
float max = -INFINITY;
ggml_v3_vec_max_f32(nc, &max, s0);
uint16_t scvt; UNUSED(scvt);
for (int i = 0; i < nc; i++) {
if (s0[i] == -INFINITY) {
ds0[i] = 0.0f;
} else {
#ifndef GGML_V3_CROSS_ENTROPY_EXP_FP16
const float s = s0[i] - max;
const float val = expf(s);
#else
ggml_v3_fp16_t s = GGML_V3_FP32_TO_FP16(s0[i] - max);
memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_V3_FP16_TO_FP32(ggml_v3_table_exp_f16[scvt]);
#endif
sum += (ggml_v3_float)val;
ds0[i] = val;
}
}
assert(sum > 0.0);
sum = (1.0 - eps)/sum;
}
// grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
ggml_v3_vec_scale_f32(nc, ds0, sum);
ggml_v3_vec_add1_f32(nc, ds0, ds0, eps);
ggml_v3_vec_sub_f32(nc, ds0, ds0, s1);
ggml_v3_vec_scale_f32(nc, ds0, d[0] / (float) nr);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
assert(!isnan(ds0[i]));
assert(!isinf(ds0[i]));
}
#endif
}
}
static void ggml_v3_compute_forward_cross_entropy_loss_back(
const struct ggml_v3_compute_params * params,
const struct ggml_v3_tensor * src0,
const struct ggml_v3_tensor * src1,
const struct ggml_v3_tensor * opt0,
struct ggml_v3_tensor * dst) {
switch (src0->type) {
case GGML_V3_TYPE_F32:
{
ggml_v3_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
} break;
default:
{
GGML_V3_ASSERT(false);
} break;
}
}
/////////////////////////////////
static void ggml_v3_compute_forward(struct ggml_v3_compute_params * params, struct ggml_v3_tensor * tensor) {
GGML_V3_ASSERT(params);
if (tensor->op == GGML_V3_OP_NONE) {
return;
}
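// when built with CUDA, the CUDA backend may compute this node entirely; in that case
// the CPU implementation is skipped, otherwise the sources must reside on the CPU
// backend (asserted below)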
#ifdef GGML_USE_CUDA
bool skip_cpu = ggml_v3_cuda_compute_forward(params, tensor);
if (skip_cpu) {
return;
}
GGML_V3_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_V3_BACKEND_CPU);
GGML_V3_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_V3_BACKEND_CPU);
#endif // GGML_USE_CUDA
switch (tensor->op) {
case GGML_V3_OP_DUP:
{
ggml_v3_compute_forward_dup(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_ADD:
{
ggml_v3_compute_forward_add(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_ADD1:
{
ggml_v3_compute_forward_add1(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_ACC:
{
ggml_v3_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_SUB:
{
ggml_v3_compute_forward_sub(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_MUL:
{
ggml_v3_compute_forward_mul(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_DIV:
{
ggml_v3_compute_forward_div(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_SQR:
{
ggml_v3_compute_forward_sqr(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_SQRT:
{
ggml_v3_compute_forward_sqrt(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_LOG:
{
ggml_v3_compute_forward_log(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_SUM:
{
ggml_v3_compute_forward_sum(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_SUM_ROWS:
{
ggml_v3_compute_forward_sum_rows(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_MEAN:
{
ggml_v3_compute_forward_mean(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_ARGMAX:
{
ggml_v3_compute_forward_argmax(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_REPEAT:
{
ggml_v3_compute_forward_repeat(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_REPEAT_BACK:
{
ggml_v3_compute_forward_repeat_back(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_CONCAT:
{
ggml_v3_compute_forward_concat(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_SILU_BACK:
{
ggml_v3_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_NORM:
{
ggml_v3_compute_forward_norm(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_RMS_NORM:
{
ggml_v3_compute_forward_rms_norm(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_RMS_NORM_BACK:
{
ggml_v3_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_GROUP_NORM:
{
ggml_v3_compute_forward_group_norm(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_MUL_MAT:
{
ggml_v3_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_MUL_MAT_ID:
{
ggml_v3_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_OUT_PROD:
{
ggml_v3_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_SCALE:
{
ggml_v3_compute_forward_scale(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_SET:
{
ggml_v3_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_CPY:
{
ggml_v3_compute_forward_cpy(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_CONT:
{
ggml_v3_compute_forward_cont(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_RESHAPE:
{
ggml_v3_compute_forward_reshape(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_VIEW:
{
ggml_v3_compute_forward_view(params, tensor->src[0]);
} break;
case GGML_V3_OP_PERMUTE:
{
ggml_v3_compute_forward_permute(params, tensor->src[0]);
} break;
case GGML_V3_OP_TRANSPOSE:
{
ggml_v3_compute_forward_transpose(params, tensor->src[0]);
} break;
case GGML_V3_OP_GET_ROWS:
{
ggml_v3_compute_forward_get_rows(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_GET_ROWS_BACK:
{
ggml_v3_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_DIAG:
{
ggml_v3_compute_forward_diag(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_DIAG_MASK_INF:
{
ggml_v3_compute_forward_diag_mask_inf(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_DIAG_MASK_ZERO:
{
ggml_v3_compute_forward_diag_mask_zero(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_SOFT_MAX:
{
ggml_v3_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_SOFT_MAX_BACK:
{
ggml_v3_compute_forward_soft_max_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_ROPE:
{
ggml_v3_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_ROPE_BACK:
{
ggml_v3_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_ALIBI:
{
ggml_v3_compute_forward_alibi(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_CLAMP:
{
ggml_v3_compute_forward_clamp(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_CONV_TRANSPOSE_1D:
{
ggml_v3_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_IM2COL:
{
ggml_v3_compute_forward_im2col(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_CONV_TRANSPOSE_2D:
{
ggml_v3_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_V3_OP_POOL_1D:
{
ggml_v3_compute_forward_pool_1d(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_POOL_2D:
{
ggml_v3_compute_forward_pool_2d(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_UPSCALE:
{
ggml_v3_compute_forward_upscale(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_PAD:
{
ggml_v3_compute_forward_pad(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_ARGSORT:
{
ggml_v3_compute_forward_argsort(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_LEAKY_RELU:
{
ggml_v3_compute_forward_leaky_relu(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_FLASH_ATTN:
{
const int32_t t = ggml_v3_get_op_params_i32(tensor, 0);
GGML_V3_ASSERT(t == 0 || t == 1);
const bool masked = t != 0;
ggml_v3_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
} break;
case GGML_V3_OP_FLASH_FF:
{
ggml_v3_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
} break;
case GGML_V3_OP_FLASH_ATTN_BACK:
{
int32_t t = ggml_v3_get_op_params_i32(tensor, 0);
GGML_V3_ASSERT(t == 0 || t == 1);
bool masked = t != 0;
ggml_v3_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
} break;
case GGML_V3_OP_WIN_PART:
{
ggml_v3_compute_forward_win_part(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_WIN_UNPART:
{
ggml_v3_compute_forward_win_unpart(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_UNARY:
{
ggml_v3_compute_forward_unary(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_GET_REL_POS:
{
ggml_v3_compute_forward_get_rel_pos(params, tensor->src[0], tensor);
} break;
case GGML_V3_OP_ADD_REL_POS:
{
ggml_v3_compute_forward_add_rel_pos(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} break;
case GGML_V3_OP_MAP_UNARY:
{
ggml_v3_unary_op_f32_t fun;
memcpy(&fun, tensor->op_params, sizeof(fun));
ggml_v3_compute_forward_map_unary(params, tensor->src[0], tensor, fun);
}
break;
case GGML_V3_OP_MAP_BINARY:
{
ggml_v3_binary_op_f32_t fun;
memcpy(&fun, tensor->op_params, sizeof(fun));
ggml_v3_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun);
}
break;
case GGML_V3_OP_MAP_CUSTOM1_F32:
{
ggml_v3_custom1_op_f32_t fun;
memcpy(&fun, tensor->op_params, sizeof(fun));
ggml_v3_compute_forward_map_custom1_f32(params, tensor->src[0], tensor, fun);
}
break;
case GGML_V3_OP_MAP_CUSTOM2_F32:
{
ggml_v3_custom2_op_f32_t fun;
memcpy(&fun, tensor->op_params, sizeof(fun));
ggml_v3_compute_forward_map_custom2_f32(params, tensor->src[0], tensor->src[1], tensor, fun);
}
break;
case GGML_V3_OP_MAP_CUSTOM3_F32:
{
ggml_v3_custom3_op_f32_t fun;
memcpy(&fun, tensor->op_params, sizeof(fun));
ggml_v3_compute_forward_map_custom3_f32(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor, fun);
}
break;
case GGML_V3_OP_MAP_CUSTOM1:
{
ggml_v3_compute_forward_map_custom1(params, tensor->src[0], tensor);
}
break;
case GGML_V3_OP_MAP_CUSTOM2:
{
ggml_v3_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor);
}
break;
case GGML_V3_OP_MAP_CUSTOM3:
{
ggml_v3_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
}
break;
case GGML_V3_OP_CROSS_ENTROPY_LOSS:
{
ggml_v3_compute_forward_cross_entropy_loss(params, tensor->src[0], tensor->src[1], tensor);
}
break;
case GGML_V3_OP_CROSS_ENTROPY_LOSS_BACK:
{
ggml_v3_compute_forward_cross_entropy_loss_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
}
break;
case GGML_V3_OP_NONE:
{
// nop
} break;
case GGML_V3_OP_COUNT:
{
GGML_V3_ASSERT(false);
} break;
}
}
////////////////////////////////////////////////////////////////////////////////
static size_t ggml_v3_hash_size(size_t min_sz) {
// next primes after powers of two
static const size_t primes[] = {
2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031,
2053, 4099, 8209, 16411, 32771, 65537, 131101,
262147, 524309, 1048583, 2097169, 4194319, 8388617,
16777259, 33554467, 67108879, 134217757, 268435459,
536870923, 1073741827, 2147483659
};
static const size_t n_primes = sizeof(primes)/sizeof(primes[0]);
// find the smallest prime that is larger or equal to min_sz
size_t l = 0;
size_t r = n_primes;
while (l < r) {
size_t m = (l + r)/2;
if (primes[m] < min_sz) {
l = m + 1;
} else {
r = m;
}
}
size_t sz = l < n_primes ? primes[l] : min_sz | 1;
return sz;
}
static size_t ggml_v3_hash(const void * p) {
return (size_t)p;
}
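// open-addressing hash set with linear probing over a prime-sized table;
// ggml_v3_hash_find returns either the slot holding `key`, the first empty slot where
// it could be inserted, or GGML_V3_HASHTABLE_FULL after probing wraps all the way around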
size_t ggml_v3_hash_find(const struct ggml_v3_hash_set hash_set, struct ggml_v3_tensor * key) {
size_t h = ggml_v3_hash(key) % hash_set.size;
// linear probing
size_t i = h;
while (hash_set.keys[i] != NULL && hash_set.keys[i] != key) {
i = (i + 1) % hash_set.size;
if (i == h) {
// visited all hash table entries -> not found
return GGML_V3_HASHTABLE_FULL;
}
}
return i;
}
bool ggml_v3_hash_contains(struct ggml_v3_hash_set hash_set, struct ggml_v3_tensor * key) {
size_t i = ggml_v3_hash_find(hash_set, key);
return i != GGML_V3_HASHTABLE_FULL && hash_set.keys[i] == key;
}
size_t ggml_v3_hash_insert(struct ggml_v3_hash_set hash_set, struct ggml_v3_tensor * key) {
size_t i = ggml_v3_hash_find(hash_set, key);
GGML_V3_ASSERT(i != GGML_V3_HASHTABLE_FULL);
if (hash_set.keys[i] == key) {
return GGML_V3_HASHTABLE_ALREADY_EXISTS;
}
// insert
GGML_V3_ASSERT(hash_set.keys[i] == NULL);
hash_set.keys[i] = key;
return i;
}
size_t ggml_v3_hash_find_or_insert(struct ggml_v3_hash_set hash_set, struct ggml_v3_tensor * key) {
size_t i = ggml_v3_hash_find(hash_set, key);
GGML_V3_ASSERT(i != GGML_V3_HASHTABLE_FULL);
hash_set.keys[i] = key;
return i;
}
static struct ggml_v3_hash_set ggml_v3_hash_set_new(size_t size) {
size = ggml_v3_hash_size(size);
struct ggml_v3_hash_set result;
result.size = size;
result.keys = malloc(sizeof(struct ggml_v3_tensor *) * size);
memset(result.keys, 0, sizeof(struct ggml_v3_tensor *) * size);
return result;
}
static void ggml_v3_hash_set_free(struct ggml_v3_hash_set hash_set) {
free(hash_set.keys);
}
struct hash_map {
struct ggml_v3_hash_set set;
struct ggml_v3_tensor ** vals;
};
static struct hash_map * ggml_v3_new_hash_map(size_t size) {
struct hash_map * result = malloc(sizeof(struct hash_map));
result->set = ggml_v3_hash_set_new(size);
result->vals = malloc(sizeof(struct ggml_v3_tensor *) * result->set.size);
memset(result->vals, 0, sizeof(struct ggml_v3_tensor *) * result->set.size);
return result;
}
static void ggml_v3_hash_map_free(struct hash_map * map) {
ggml_v3_hash_set_free(map->set);
free(map->vals);
free(map);
}
// gradient checkpointing
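// ggml_v3_recompute_graph_node recursively clones a node together with its inputs so
// that they are recomputed inside the backward graph instead of being kept alive from
// the forward pass; the recursion stops at parameters, at tensors that are not part of
// the forward graph, at source-less leafs, and at tensors already present in
// `replacements` (the user-provided checkpoints and previously cloned nodes)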
static struct ggml_v3_tensor * ggml_v3_recompute_graph_node(
struct ggml_v3_context * ctx,
struct ggml_v3_cgraph * graph,
struct hash_map * replacements,
struct ggml_v3_tensor * node) {
if (node == NULL) {
return NULL;
}
if (node->is_param) {
return node;
}
if (!ggml_v3_hash_contains(graph->visited_hash_table, node)) {
return node;
}
int count_children = 0;
for (int k = 0; k < GGML_V3_MAX_SRC; ++k) {
if (node->src[k]) {
++count_children;
}
}
if (count_children == 0) {
return node;
}
size_t i = ggml_v3_hash_find(replacements->set, node);
GGML_V3_ASSERT(i != GGML_V3_HASHTABLE_FULL); // assert that not full
if (replacements->set.keys[i] == node) {
return replacements->vals[i];
}
struct ggml_v3_tensor * clone = ggml_v3_new_tensor(ctx, node->type, GGML_V3_MAX_DIMS, node->ne);
// insert clone into replacements
GGML_V3_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite
replacements->set.keys[i] = node;
replacements->vals[i] = clone;
clone->op = node->op;
clone->grad = node->grad;
clone->is_param = node->is_param;
clone->extra = node->extra;
for (int k = 0; k < GGML_V3_MAX_DIMS; ++k) {
clone->nb[k] = node->nb[k];
}
for (int k = 0; k < GGML_V3_MAX_SRC; ++k) {
clone->src[k] = ggml_v3_recompute_graph_node(ctx, graph, replacements, node->src[k]);
}
if (node->view_src != NULL) {
clone->data = (node->view_src->data == NULL)
? NULL // view_src not yet allocated
: (char *) node->view_src->data // view_src already allocated
+ node->view_offs;
clone->view_src = node->view_src;
clone->view_offs = node->view_offs;
}
GGML_V3_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_V3_MAX_OP_PARAMS / sizeof(int32_t)));
GGML_V3_ASSERT(sizeof(node->name) == GGML_V3_MAX_NAME);
memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
ggml_v3_format_name(clone, "%s (clone)", ggml_v3_get_name(node));
return clone;
}
void ggml_v3_build_backward_gradient_checkpointing(
struct ggml_v3_context * ctx,
struct ggml_v3_cgraph * gf,
struct ggml_v3_cgraph * gb,
struct ggml_v3_cgraph * gb_tmp,
struct ggml_v3_tensor * * checkpoints,
int n_checkpoints) {
ggml_v3_graph_cpy(gf, gb_tmp);
ggml_v3_build_backward_expand(ctx, gf, gb_tmp, true);
if (n_checkpoints <= 0) {
ggml_v3_graph_cpy(gb_tmp, gb);
return;
}
struct hash_map * replacements = ggml_v3_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints);
// insert checkpoints in replacements
for (int i = 0; i < n_checkpoints; ++i) {
size_t k = ggml_v3_hash_find(replacements->set, checkpoints[i]);
GGML_V3_ASSERT(k != GGML_V3_HASHTABLE_FULL); // assert that not full
GGML_V3_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite
replacements->set.keys[k] = checkpoints[i];
replacements->vals[k] = checkpoints[i];
}
ggml_v3_graph_cpy(gf, gb);
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
// by recomputing them from checkpoints
for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
struct ggml_v3_tensor * node = gb_tmp->nodes[i];
for (int k = 0; k < GGML_V3_MAX_SRC; ++k) {
// insert new tensors that recompute src, reusing replacements already made,
// and remember the new tensors in `replacements`, mapped from the corresponding gf nodes;
// recurse for input tensors,
// unless (i.e. terminating when) the input tensors are replacements (like checkpoints)
node->src[k] = ggml_v3_recompute_graph_node(ctx, gf, replacements, node->src[k]);
}
// insert rewritten backward node with replacements made into resulting backward graph gb
ggml_v3_build_forward_expand(gb, node);
}
ggml_v3_hash_map_free(replacements);
}
// functions to change gradients considering the case that input a might be initial gradient with zero value
static struct ggml_v3_tensor * ggml_v3_add_or_set(struct ggml_v3_context * ctx, struct ggml_v3_tensor * a, struct ggml_v3_tensor * b, struct ggml_v3_hash_set zero_table) {
if (ggml_v3_hash_contains(zero_table, a)) {
return b;
} else {
return ggml_v3_add_impl(ctx, a, b, false);
}
}
static struct ggml_v3_tensor * ggml_v3_acc_or_set(struct ggml_v3_context * ctx, struct ggml_v3_tensor * a, struct ggml_v3_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_v3_hash_set zero_table) {
if (ggml_v3_hash_contains(zero_table, a)) {
struct ggml_v3_tensor * a_zero = ggml_v3_scale(ctx, a, 0.0f);
return ggml_v3_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
} else {
return ggml_v3_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
}
}
static struct ggml_v3_tensor * ggml_v3_add1_or_set(struct ggml_v3_context * ctx, struct ggml_v3_tensor * a, struct ggml_v3_tensor * b, struct ggml_v3_hash_set zero_table) {
if (ggml_v3_hash_contains(zero_table, a)) {
return ggml_v3_repeat(ctx, b, a);
} else {
return ggml_v3_add1_impl(ctx, a, b, false);
}
}
static struct ggml_v3_tensor * ggml_v3_sub_or_set(struct ggml_v3_context * ctx, struct ggml_v3_tensor * a, struct ggml_v3_tensor * b, struct ggml_v3_hash_set zero_table) {
if (ggml_v3_hash_contains(zero_table, a)) {
return ggml_v3_neg(ctx, b);
} else {
return ggml_v3_sub_impl(ctx, a, b, false);
}
}
static void ggml_v3_compute_backward(struct ggml_v3_context * ctx, struct ggml_v3_tensor * tensor, struct ggml_v3_hash_set zero_table) {
struct ggml_v3_tensor * src0 = tensor->src[0];
struct ggml_v3_tensor * src1 = tensor->src[1];
switch (tensor->op) {
case GGML_V3_OP_DUP:
{
if (src0->grad) {
src0->grad = ggml_v3_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
} break;
case GGML_V3_OP_ADD:
{
if (src0->grad) {
src0->grad = ggml_v3_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
src1->grad = ggml_v3_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
}
} break;
case GGML_V3_OP_ADD1:
{
if (src0->grad) {
src0->grad = ggml_v3_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
src1->grad = ggml_v3_add_or_set(ctx,
src1->grad,
ggml_v3_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
zero_table);
}
} break;
case GGML_V3_OP_ACC:
{
if (src0->grad) {
src0->grad = ggml_v3_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
const size_t nb2 = ((int32_t *) tensor->op_params)[1];
const size_t nb3 = ((int32_t *) tensor->op_params)[2];
const size_t offset = ((int32_t *) tensor->op_params)[3];
struct ggml_v3_tensor * tensor_grad_view = ggml_v3_view_4d(ctx,
tensor->grad,
src1->grad->ne[0],
src1->grad->ne[1],
src1->grad->ne[2],
src1->grad->ne[3],
nb1, nb2, nb3, offset);
src1->grad =
ggml_v3_add_or_set(ctx,
src1->grad,
ggml_v3_reshape(ctx,
ggml_v3_cont(ctx, tensor_grad_view),
src1->grad),
zero_table);
}
} break;
case GGML_V3_OP_SUB:
{
if (src0->grad) {
src0->grad = ggml_v3_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
src1->grad = ggml_v3_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
}
} break;
case GGML_V3_OP_MUL:
{
if (src0->grad) {
src0->grad =
ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_mul(ctx, src1, tensor->grad),
zero_table);
}
if (src1->grad) {
src1->grad =
ggml_v3_add_or_set(ctx,
src1->grad,
ggml_v3_mul(ctx, src0, tensor->grad),
zero_table);
}
} break;
case GGML_V3_OP_DIV:
{
if (src0->grad) {
src0->grad =
ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_div(ctx, tensor->grad, src1),
zero_table);
}
if (src1->grad) {
src1->grad =
ggml_v3_sub_or_set(ctx,
src1->grad,
ggml_v3_mul(ctx,
tensor->grad,
ggml_v3_div(ctx, tensor, src1)),
zero_table);
}
} break;
case GGML_V3_OP_SQR:
{
if (src0->grad) {
src0->grad =
ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_scale(ctx,
ggml_v3_mul(ctx, src0, tensor->grad),
2.0f),
zero_table);
}
} break;
case GGML_V3_OP_SQRT:
{
if (src0->grad) {
src0->grad =
ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_scale(ctx,
ggml_v3_div(ctx,
tensor->grad,
tensor),
0.5f),
zero_table);
}
} break;
case GGML_V3_OP_LOG:
{
if (src0->grad) {
src0->grad =
ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_div(ctx,
tensor->grad,
src0),
zero_table);
}
} break;
case GGML_V3_OP_SUM:
{
if (src0->grad) {
src0->grad =
ggml_v3_add1_or_set(ctx,
src0->grad,
tensor->grad,
zero_table);
}
} break;
case GGML_V3_OP_SUM_ROWS:
{
if (src0->grad) {
src0->grad =
ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_repeat(ctx,
tensor->grad,
src0->grad),
zero_table);
}
} break;
case GGML_V3_OP_MEAN:
case GGML_V3_OP_ARGMAX:
{
GGML_V3_ASSERT(false); // TODO: implement
} break;
case GGML_V3_OP_REPEAT:
{
// necessary for llama
if (src0->grad) {
src0->grad = ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_repeat_back(ctx, tensor->grad, src0->grad),
zero_table);
}
} break;
case GGML_V3_OP_REPEAT_BACK:
{
if (src0->grad) {
// TODO: test this
src0->grad = ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_repeat(ctx, tensor->grad, src0->grad),
zero_table);
}
} break;
case GGML_V3_OP_CONCAT:
{
GGML_V3_ASSERT(false); // TODO: implement
} break;
case GGML_V3_OP_SILU_BACK:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_NORM:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_RMS_NORM:
{
// necessary for llama
if (src0->grad) {
float eps;
memcpy(&eps, tensor->op_params, sizeof(float));
src0->grad = ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_rms_norm_back(ctx, src0, tensor->grad, eps),
zero_table);
}
} break;
case GGML_V3_OP_RMS_NORM_BACK:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_GROUP_NORM:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_MUL_MAT:
{
// https://cs231n.github.io/optimization-2/#staged
// # forward pass
// s0 = np.random.randn(5, 10)
// s1 = np.random.randn(10, 3)
// t = s0.dot(s1)
// # now suppose we had the gradient on t from above in the circuit
// dt = np.random.randn(*t.shape) # same shape as t
// ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
// ds1 = t.T.dot(dt)
// tensor.shape [m,p,qq,rr]
// src0.shape [n,m,q1,r1]
// src1.shape [n,p,qq,rr]
// necessary for llama
if (src0->grad) {
struct ggml_v3_tensor * s1_tg =
ggml_v3_out_prod(ctx, // [n,m,qq,rr]
src1, // [n,p,qq,rr]
tensor->grad); // [m,p,qq,rr]
const int64_t qq = s1_tg->ne[2];
const int64_t rr = s1_tg->ne[3];
const int64_t q1 = src0->ne[2];
const int64_t r1 = src0->ne[3];
const bool ne2_broadcasted = qq > q1;
const bool ne3_broadcasted = rr > r1;
if (ne2_broadcasted || ne3_broadcasted) {
// sum broadcast repetitions of s1_tg into shape of src0
s1_tg = ggml_v3_repeat_back(ctx, s1_tg, src0);
}
src0->grad =
ggml_v3_add_or_set(ctx,
src0->grad, // [n,m,q1,r1]
s1_tg, // [n,m,q1,r1]
zero_table);
}
if (src1->grad) {
src1->grad =
ggml_v3_add_or_set(ctx,
src1->grad, // [n,p,qq,rr]
// ggml_v3_mul_mat(ctx, // [n,p,qq,rr]
// ggml_v3_cont(ctx, // [m,n,q1,r1]
// ggml_v3_transpose(ctx, src0)), // [m,n,q1,r1]
// tensor->grad), // [m,p,qq,rr]
// // when src0 is bigger than tensor->grad (this is mostly the case in llama),
// // avoid transpose of src0, rather transpose smaller tensor->grad
// // and then use ggml_v3_out_prod
ggml_v3_out_prod(ctx, // [n,p,qq,rr]
src0, // [n,m,q1,r1]
ggml_v3_transpose(ctx, // [p,m,qq,rr]
tensor->grad)), // [m,p,qq,rr]
zero_table);
}
} break;
case GGML_V3_OP_MUL_MAT_ID:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_OUT_PROD:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_SCALE:
{
// necessary for llama
if (src0->grad) {
float s;
memcpy(&s, tensor->op_params, sizeof(float));
src0->grad =
ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_scale_impl(ctx, tensor->grad, s, false),
zero_table);
}
} break;
case GGML_V3_OP_SET:
{
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
const size_t nb2 = ((int32_t *) tensor->op_params)[1];
const size_t nb3 = ((int32_t *) tensor->op_params)[2];
const size_t offset = ((int32_t *) tensor->op_params)[3];
struct ggml_v3_tensor * tensor_grad_view = NULL;
if (src0->grad || src1->grad) {
GGML_V3_ASSERT(src0->type == tensor->type);
GGML_V3_ASSERT(tensor->grad->type == tensor->type);
GGML_V3_ASSERT(tensor->grad->type == src1->grad->type);
tensor_grad_view = ggml_v3_view_4d(ctx,
tensor->grad,
src1->grad->ne[0],
src1->grad->ne[1],
src1->grad->ne[2],
src1->grad->ne[3],
nb1, nb2, nb3, offset);
}
if (src0->grad) {
src0->grad = ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_acc_impl(ctx,
tensor->grad,
ggml_v3_neg(ctx, tensor_grad_view),
nb1, nb2, nb3, offset, false),
zero_table);
}
if (src1->grad) {
src1->grad =
ggml_v3_add_or_set(ctx,
src1->grad,
ggml_v3_reshape(ctx,
ggml_v3_cont(ctx, tensor_grad_view),
src1->grad),
zero_table);
}
} break;
case GGML_V3_OP_CPY:
{
// necessary for llama
// cpy overwrites value of src1 by src0 and returns view(src1)
// the overwriting is mathematically equivalent to:
// tensor = src0 * 1 + src1 * 0
if (src0->grad) {
// dsrc0 = dtensor * 1
src0->grad = ggml_v3_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
if (src1->grad) {
// dsrc1 = dtensor * 0 -> noop
}
} break;
case GGML_V3_OP_CONT:
{
// same as cpy
if (src0->grad) {
GGML_V3_ASSERT(ggml_v3_is_contiguous(src0->grad));
GGML_V3_ASSERT(ggml_v3_is_contiguous(tensor->grad));
src0->grad = ggml_v3_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
} break;
case GGML_V3_OP_RESHAPE:
{
// necessary for llama
if (src0->grad) {
src0->grad =
ggml_v3_add_or_set(ctx, src0->grad,
ggml_v3_reshape(ctx,
ggml_v3_is_contiguous(tensor->grad)
? tensor->grad
: ggml_v3_cont(ctx, tensor->grad),
src0->grad),
zero_table);
}
} break;
case GGML_V3_OP_VIEW:
{
// necessary for llama
if (src0->grad) {
size_t offset;
memcpy(&offset, tensor->op_params, sizeof(offset));
size_t nb1 = tensor->nb[1];
size_t nb2 = tensor->nb[2];
size_t nb3 = tensor->nb[3];
if (src0->type != src0->grad->type) {
// gradient is typically F32, but src0 could be other type
size_t ng = ggml_v3_element_size(src0->grad);
size_t n0 = ggml_v3_element_size(src0);
GGML_V3_ASSERT(offset % n0 == 0);
GGML_V3_ASSERT(nb1 % n0 == 0);
GGML_V3_ASSERT(nb2 % n0 == 0);
GGML_V3_ASSERT(nb3 % n0 == 0);
offset = (offset / n0) * ng;
nb1 = (nb1 / n0) * ng;
nb2 = (nb2 / n0) * ng;
nb3 = (nb3 / n0) * ng;
}
src0->grad = ggml_v3_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
}
} break;
case GGML_V3_OP_PERMUTE:
{
// necessary for llama
if (src0->grad) {
int32_t * axes = (int32_t *) tensor->op_params;
int axis0 = axes[0] & 0x3;
int axis1 = axes[1] & 0x3;
int axis2 = axes[2] & 0x3;
int axis3 = axes[3] & 0x3;
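// invert the permutation: the forward op sent dimension k to position axes[k],
// so the backward permute must send position axes[k] back to k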
int axes_backward[4] = {0,0,0,0};
axes_backward[axis0] = 0;
axes_backward[axis1] = 1;
axes_backward[axis2] = 2;
axes_backward[axis3] = 3;
src0->grad =
ggml_v3_add_or_set(ctx, src0->grad,
ggml_v3_permute(ctx,
tensor->grad,
axes_backward[0],
axes_backward[1],
axes_backward[2],
axes_backward[3]),
zero_table);
}
} break;
case GGML_V3_OP_TRANSPOSE:
{
// necessary for llama
if (src0->grad) {
src0->grad =
ggml_v3_add_or_set(ctx, src0->grad,
ggml_v3_transpose(ctx, tensor->grad),
zero_table);
}
} break;
case GGML_V3_OP_GET_ROWS:
{
// necessary for llama (only for tokenizer)
if (src0->grad) {
src0->grad =
ggml_v3_add_or_set(ctx, src0->grad,
// the last ggml_v3_get_rows_back argument src0->grad is only
// necessary to set up the correct output shape
ggml_v3_get_rows_back(ctx, tensor->grad, src1, src0->grad),
zero_table);
}
if (src1->grad) {
// noop
}
} break;
case GGML_V3_OP_GET_ROWS_BACK:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_DIAG:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_DIAG_MASK_INF:
{
// necessary for llama
if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0];
src0->grad =
ggml_v3_add_or_set(ctx, src0->grad,
/* ggml_v3_diag_mask_inf_impl() shouldn't be here */
/* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
ggml_v3_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
zero_table);
}
} break;
case GGML_V3_OP_DIAG_MASK_ZERO:
{
// necessary for llama
if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0];
src0->grad =
ggml_v3_add_or_set(ctx, src0->grad,
ggml_v3_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
zero_table);
}
} break;
case GGML_V3_OP_SOFT_MAX:
{
// necessary for llama
if (src0->grad) {
src0->grad =
ggml_v3_add_or_set(ctx, src0->grad,
ggml_v3_soft_max_back(ctx, tensor->grad, tensor),
zero_table);
}
} break;
case GGML_V3_OP_SOFT_MAX_BACK:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_ROPE:
{
// necessary for llama
if (src0->grad) {
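// op_params layout, as read below: [0]=n_past (unused), [1]=n_dims, [2]=mode,
// [3]=n_ctx, [4]=n_orig_ctx, [5..11]=freq_base, freq_scale, ext_factor, attn_factor,
// beta_fast, beta_slow, xpos_base (floats), [12]=xpos_down (bool)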
//const int n_past = ((int32_t *) tensor->op_params)[0];
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3];
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
src0->grad = ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_rope_back(ctx,
tensor->grad,
src1,
n_dims,
mode,
n_ctx,
n_orig_ctx,
freq_base,
freq_scale,
ext_factor,
attn_factor,
beta_fast,
beta_slow,
xpos_base,
xpos_down),
zero_table);
}
} break;
case GGML_V3_OP_ROPE_BACK:
{
if (src0->grad) {
//const int n_past = ((int32_t *) tensor->op_params)[0];
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3];
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
src0->grad = ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_rope_impl(ctx,
tensor->grad,
src1,
n_dims,
mode,
n_ctx,
n_orig_ctx,
freq_base,
freq_scale,
ext_factor,
attn_factor,
beta_fast,
beta_slow,
xpos_base,
xpos_down,
false),
zero_table);
}
} break;
case GGML_V3_OP_ALIBI:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_CLAMP:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_CONV_TRANSPOSE_1D:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_IM2COL:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_CONV_TRANSPOSE_2D:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_POOL_1D:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_POOL_2D:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_UPSCALE:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_PAD:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_ARGSORT:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_LEAKY_RELU:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_OP_FLASH_ATTN:
{
struct ggml_v3_tensor * flash_grad = NULL;
if (src0->grad || src1->grad || tensor->src[2]->grad) {
int32_t t = ggml_v3_get_op_params_i32(tensor, 0);
GGML_V3_ASSERT(t == 0 || t == 1);
bool masked = t != 0;
flash_grad =
ggml_v3_flash_attn_back(ctx,
src0,
src1,
tensor->src[2],
tensor->grad,
masked);
}
struct ggml_v3_tensor * src2 = tensor->src[2];
const int64_t elem_q = ggml_v3_nelements(src0);
const int64_t elem_k = ggml_v3_nelements(src1);
const int64_t elem_v = ggml_v3_nelements(src2);
enum ggml_v3_type result_type = flash_grad->type;
GGML_V3_ASSERT(ggml_v3_blck_size(result_type) == 1);
const size_t tsize = ggml_v3_type_size(result_type);
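// flash_grad packs the gradients for q, k and v contiguously into a single tensor;
// offs_q/offs_k/offs_v locate each segment (padded to GGML_V3_MEM_ALIGN) and the 1d
// views below slice them out and reshape them to the shapes of src0, src1 and src2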
const size_t offs_q = 0;
const size_t offs_k = offs_q + GGML_V3_PAD(elem_q * tsize, GGML_V3_MEM_ALIGN);
const size_t offs_v = offs_k + GGML_V3_PAD(elem_k * tsize, GGML_V3_MEM_ALIGN);
if (src0->grad) {
struct ggml_v3_tensor * view_q = ggml_v3_view_1d(ctx, flash_grad, elem_q, offs_q);
struct ggml_v3_tensor * grad_q = ggml_v3_reshape(ctx, view_q, src0);
src0->grad = ggml_v3_add_or_set(ctx,
src0->grad,
grad_q,
zero_table);
}
if (src1->grad) {
struct ggml_v3_tensor * view_k = ggml_v3_view_1d(ctx, flash_grad, elem_k, offs_k);
struct ggml_v3_tensor * grad_k = ggml_v3_reshape(ctx, view_k, src1);
src1->grad = ggml_v3_add_or_set(ctx,
src1->grad,
grad_k,
zero_table);
}
if (src2->grad) {
struct ggml_v3_tensor * view_v = ggml_v3_view_1d(ctx, flash_grad, elem_v, offs_v);
struct ggml_v3_tensor * grad_v = ggml_v3_reshape(ctx, view_v, src2);
src2->grad = ggml_v3_add_or_set(ctx,
src2->grad,
grad_v,
zero_table);
}
} break;
case GGML_V3_OP_FLASH_FF:
{
GGML_V3_ASSERT(false); // not supported
} break;
case GGML_V3_OP_FLASH_ATTN_BACK:
{
GGML_V3_ASSERT(false); // not supported
} break;
case GGML_V3_OP_WIN_PART:
case GGML_V3_OP_WIN_UNPART:
case GGML_V3_OP_UNARY:
{
switch (ggml_v3_get_unary_op(tensor)) {
case GGML_V3_UNARY_OP_ABS:
{
if (src0->grad) {
src0->grad =
ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_mul(ctx,
ggml_v3_sgn(ctx, src0),
tensor->grad),
zero_table);
}
} break;
case GGML_V3_UNARY_OP_SGN:
{
if (src0->grad) {
// noop
}
} break;
case GGML_V3_UNARY_OP_NEG:
{
if (src0->grad) {
src0->grad = ggml_v3_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
}
} break;
case GGML_V3_UNARY_OP_STEP:
{
if (src0->grad) {
// noop
}
} break;
case GGML_V3_UNARY_OP_TANH:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_UNARY_OP_ELU:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_UNARY_OP_RELU:
{
if (src0->grad) {
src0->grad = ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_mul(ctx,
ggml_v3_step(ctx, src0),
tensor->grad),
zero_table);
}
} break;
case GGML_V3_UNARY_OP_GELU:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_UNARY_OP_GELU_QUICK:
{
GGML_V3_ASSERT(false); // TODO: not implemented
} break;
case GGML_V3_UNARY_OP_SILU:
{
// necessary for llama
if (src0->grad) {
src0->grad = ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_silu_back(ctx, src0, tensor->grad),
zero_table);
}
} break;
default:
GGML_V3_ASSERT(false);
}
} break;
case GGML_V3_OP_GET_REL_POS:
case GGML_V3_OP_ADD_REL_POS:
case GGML_V3_OP_MAP_UNARY:
case GGML_V3_OP_MAP_BINARY:
case GGML_V3_OP_MAP_CUSTOM1_F32:
case GGML_V3_OP_MAP_CUSTOM2_F32:
case GGML_V3_OP_MAP_CUSTOM3_F32:
case GGML_V3_OP_MAP_CUSTOM1:
case GGML_V3_OP_MAP_CUSTOM2:
case GGML_V3_OP_MAP_CUSTOM3:
{
GGML_V3_ASSERT(false); // not supported
} break;
case GGML_V3_OP_CROSS_ENTROPY_LOSS:
{
if (src0->grad) {
src0->grad = ggml_v3_add_or_set(ctx,
src0->grad,
ggml_v3_cross_entropy_loss_back(ctx,
src0,
src1,
tensor->grad),
zero_table);
}
} break;
case GGML_V3_OP_CROSS_ENTROPY_LOSS_BACK:
{
GGML_V3_ASSERT(false); // not supported
} break;
case GGML_V3_OP_NONE:
{
// nop
} break;
case GGML_V3_OP_COUNT:
{
GGML_V3_ASSERT(false);
} break;
}
for (int i = 0; i < GGML_V3_MAX_SRC; ++i) {
if (tensor->src[i] && tensor->src[i]->grad) {
GGML_V3_ASSERT(ggml_v3_are_same_shape(tensor->src[i], tensor->src[i]->grad));
}
}
}
static void ggml_v3_visit_parents(struct ggml_v3_cgraph * cgraph, struct ggml_v3_tensor * node) {
if (node->grad == NULL) {
// this usually happens when we generate intermediate nodes from constants in the backward pass
// it can also happen during forward pass, if the user performs computations with constants
if (node->op != GGML_V3_OP_NONE) {
//GGML_V3_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op);
}
}
// check if already visited
if (ggml_v3_hash_insert(cgraph->visited_hash_table, node) == GGML_V3_HASHTABLE_ALREADY_EXISTS) {
return;
}
for (int i = 0; i < GGML_V3_MAX_SRC; ++i) {
const int k =
(cgraph->order == GGML_V3_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
(cgraph->order == GGML_V3_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_V3_MAX_SRC-1-i) :
/* unknown order, just fall back to using i */ i;
if (node->src[k]) {
ggml_v3_visit_parents(cgraph, node->src[k]);
}
}
if (node->op == GGML_V3_OP_NONE && node->grad == NULL) {
// reached a leaf node, not part of the gradient graph (e.g. a constant)
GGML_V3_ASSERT(cgraph->n_leafs < cgraph->size);
if (strlen(node->name) == 0) {
ggml_v3_format_name(node, "leaf_%d", cgraph->n_leafs);
}
cgraph->leafs[cgraph->n_leafs] = node;
cgraph->n_leafs++;
} else {
GGML_V3_ASSERT(cgraph->n_nodes < cgraph->size);
if (strlen(node->name) == 0) {
ggml_v3_format_name(node, "node_%d", cgraph->n_nodes);
}
cgraph->nodes[cgraph->n_nodes] = node;
if (cgraph->grads) {
cgraph->grads[cgraph->n_nodes] = node->grad;
}
cgraph->n_nodes++;
}
}
static void ggml_v3_build_forward_impl(struct ggml_v3_cgraph * cgraph, struct ggml_v3_tensor * tensor, bool expand) {
if (!expand) {
// TODO: this branch isn't accessible anymore, maybe move this to ggml_v3_build_forward_expand
ggml_v3_graph_clear(cgraph);
}
const int n0 = cgraph->n_nodes;
UNUSED(n0);
ggml_v3_visit_parents(cgraph, tensor);
const int n_new = cgraph->n_nodes - n0;
GGML_V3_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
if (n_new > 0) {
// the last added node should always be the starting point
GGML_V3_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
}
}
void ggml_v3_build_forward_expand(struct ggml_v3_cgraph * cgraph, struct ggml_v3_tensor * tensor) {
ggml_v3_build_forward_impl(cgraph, tensor, true);
}
void ggml_v3_build_backward_expand(struct ggml_v3_context * ctx, struct ggml_v3_cgraph * gf, struct ggml_v3_cgraph * gb, bool keep) {
GGML_V3_ASSERT(gf->n_nodes > 0);
// if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
if (keep) {
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_v3_tensor * node = gf->nodes[i];
if (node->grad) {
node->grad = ggml_v3_dup_tensor(ctx, node);
gf->grads[i] = node->grad;
}
}
}
// remember original gradients which start with zero values
struct ggml_v3_hash_set zero_table = ggml_v3_hash_set_new(gf->size);
for (int i = 0; i < gf->n_nodes; i++) {
if (gf->grads[i]) {
ggml_v3_hash_insert(zero_table, gf->grads[i]);
}
}
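// gradients recorded in zero_table still hold their initial zero value, so
// ggml_v3_add_or_set and friends replace "0 + x" by just "x" on the first accumulation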
for (int i = gf->n_nodes - 1; i >= 0; i--) {
struct ggml_v3_tensor * node = gf->nodes[i];
// inplace operations to add gradients are not created by ggml_v3_compute_backward
// use allocator to automatically make inplace operations
if (node->grad) {
ggml_v3_compute_backward(ctx, node, zero_table);
}
}
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_v3_tensor * node = gf->nodes[i];
if (node->is_param) {
GGML_V3_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
ggml_v3_build_forward_expand(gb, node->grad);
}
}
ggml_v3_hash_set_free(zero_table);
}
static size_t ggml_v3_graph_nbytes(size_t size, bool grads) {
size_t nbytes = sizeof(struct ggml_v3_cgraph);
nbytes += size * sizeof(struct ggml_v3_tensor *) * 2; // leafs + nodes
if (grads) {
nbytes += size * sizeof(struct ggml_v3_tensor *); // grads
}
nbytes += ggml_v3_hash_size(size * 2) * sizeof(struct ggml_v3_tensor *); // hash set
return nbytes;
}
size_t ggml_v3_graph_overhead_custom(size_t size, bool grads) {
return GGML_V3_OBJECT_SIZE + GGML_V3_PAD(ggml_v3_graph_nbytes(size, grads), GGML_V3_MEM_ALIGN);
}
size_t ggml_v3_graph_overhead(void) {
return ggml_v3_graph_overhead_custom(GGML_V3_DEFAULT_GRAPH_SIZE, false);
}
struct ggml_v3_cgraph * ggml_v3_new_graph_custom(struct ggml_v3_context * ctx, size_t size, bool grads) {
const size_t obj_size = ggml_v3_graph_nbytes(size, grads);
struct ggml_v3_object * obj = ggml_v3_new_object(ctx, GGML_V3_OBJECT_GRAPH, obj_size);
struct ggml_v3_cgraph * cgraph = (struct ggml_v3_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
struct ggml_v3_tensor ** data_start = (struct ggml_v3_tensor **) (cgraph + 1);
size_t hash_size = ggml_v3_hash_size(size * 2);
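// memory layout inside the graph object: [cgraph header][nodes (size)][leafs (size)]
// [hash keys (hash_size)][grads (size), only if requested] - must stay in sync with
// ggml_v3_graph_nbytes above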
struct ggml_v3_tensor ** nodes_ptr = data_start;
struct ggml_v3_tensor ** leafs_ptr = nodes_ptr + size;
struct ggml_v3_tensor ** hash_keys_ptr = leafs_ptr + size;
struct ggml_v3_tensor ** grads_ptr = grads ? hash_keys_ptr + hash_size : NULL;
// check that we allocated the correct amount of memory
assert(obj_size == (size_t) (
(grads ? (char *)(grads_ptr + size) : (char *)(hash_keys_ptr + hash_size)) - (char *)cgraph));
memset(hash_keys_ptr, 0, hash_size * sizeof(struct ggml_v3_tensor *));
*cgraph = (struct ggml_v3_cgraph) {
/*.size =*/ size,
/*.n_nodes =*/ 0,
/*.n_leafs =*/ 0,
/*.nodes =*/ nodes_ptr,
/*.grads =*/ grads_ptr,
/*.leafs =*/ leafs_ptr,
/*.hash_table =*/ { hash_size, hash_keys_ptr },
/*.order =*/ GGML_V3_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
};
return cgraph;
}
struct ggml_v3_cgraph * ggml_v3_new_graph(struct ggml_v3_context * ctx) {
return ggml_v3_new_graph_custom(ctx, GGML_V3_DEFAULT_GRAPH_SIZE, false);
}
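// returns a lightweight, non-owning view over nodes [i0, i1) of cgraph0;
// the view has no leafs and no hash table, and does not need to be freed separately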
struct ggml_v3_cgraph ggml_v3_graph_view(struct ggml_v3_cgraph * cgraph0, int i0, int i1) {
struct ggml_v3_cgraph cgraph = {
/*.size =*/ 0,
/*.n_nodes =*/ i1 - i0,
/*.n_leafs =*/ 0,
/*.nodes =*/ cgraph0->nodes + i0,
/*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
/*.leafs =*/ NULL,
/*.hash_table =*/ { 0, NULL },
/*.order =*/ cgraph0->order,
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
};
return cgraph;
}
void ggml_v3_graph_cpy(struct ggml_v3_cgraph * src, struct ggml_v3_cgraph * dst) {
GGML_V3_ASSERT(dst->size >= src->n_leafs);
GGML_V3_ASSERT(dst->size >= src->n_nodes);
GGML_V3_ASSERT(dst->visited_hash_table.size >= src->visited_hash_table.size);
dst->n_leafs = src->n_leafs;
dst->n_nodes = src->n_nodes;
dst->order = src->order;
for (int i = 0; i < src->n_leafs; ++i) {
dst->leafs[i] = src->leafs[i];
}
for (int i = 0; i < src->n_nodes; ++i) {
dst->nodes[i] = src->nodes[i];
}
if (src->grads) {
GGML_V3_ASSERT(dst->grads != NULL);
for (int i = 0; i < src->n_nodes; ++i) {
dst->grads[i] = src->grads[i];
}
}
for (size_t i = 0; i < src->visited_hash_table.size; ++i) {
if (src->visited_hash_table.keys[i]) {
ggml_v3_hash_insert(dst->visited_hash_table, src->visited_hash_table.keys[i]);
}
}
}
struct ggml_v3_cgraph * ggml_v3_graph_dup(struct ggml_v3_context * ctx, struct ggml_v3_cgraph * cgraph) {
struct ggml_v3_cgraph * result = ggml_v3_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
ggml_v3_graph_cpy(cgraph, result);
return result;
}
void ggml_v3_graph_reset(struct ggml_v3_cgraph * cgraph) {
GGML_V3_ASSERT(cgraph->grads != NULL);
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_v3_tensor * grad = cgraph->grads[i];
if (grad) {
ggml_v3_set_zero(grad);
}
}
}
void ggml_v3_graph_clear(struct ggml_v3_cgraph * cgraph) {
cgraph->n_leafs = 0;
cgraph->n_nodes = 0;
memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct ggml_v3_tensor *));
}
//
// thread data
//
// synchronization is done via busy loops
// I tried using spin locks, but I'm not sure how to use them correctly - the things I tried were slower than busy loops
//
#ifdef __APPLE__
//#include <os/lock.h>
//
//typedef os_unfair_lock ggml_v3_lock_t;
//
//#define ggml_v3_lock_init(x) UNUSED(x)
//#define ggml_v3_lock_destroy(x) UNUSED(x)
//#define ggml_v3_lock_lock os_unfair_lock_lock
//#define ggml_v3_lock_unlock os_unfair_lock_unlock
//
//#define GGML_V3_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT
typedef int ggml_v3_lock_t;
#define ggml_v3_lock_init(x) UNUSED(x)
#define ggml_v3_lock_destroy(x) UNUSED(x)
#define ggml_v3_lock_lock(x) UNUSED(x)
#define ggml_v3_lock_unlock(x) UNUSED(x)
#define GGML_V3_LOCK_INITIALIZER 0
typedef pthread_t ggml_v3_thread_t;
#define ggml_v3_thread_create pthread_create
#define ggml_v3_thread_join pthread_join
#else
//typedef pthread_spinlock_t ggml_v3_lock_t;
//#define ggml_v3_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE)
//#define ggml_v3_lock_destroy pthread_spin_destroy
//#define ggml_v3_lock_lock pthread_spin_lock
//#define ggml_v3_lock_unlock pthread_spin_unlock
typedef int ggml_v3_lock_t;
#define ggml_v3_lock_init(x) UNUSED(x)
#define ggml_v3_lock_destroy(x) UNUSED(x)
#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
#define ggml_v3_lock_lock(x) _mm_pause()
#else
#define ggml_v3_lock_lock(x) UNUSED(x)
#endif
#define ggml_v3_lock_unlock(x) UNUSED(x)
#define GGML_V3_LOCK_INITIALIZER 0
typedef pthread_t ggml_v3_thread_t;
#define ggml_v3_thread_create pthread_create
#define ggml_v3_thread_join pthread_join
#endif
// Android's libc implementation "bionic" does not support setting affinity
#if defined(__linux__) && !defined(__BIONIC__)
static void set_numa_thread_affinity(int thread_n, int n_threads) {
if (!ggml_v3_is_numa()) {
return;
}
// pin this thread to NUMA node node_num = thread_n / (threads per node)
const int node_num = thread_n / ((n_threads + g_state.numa.n_nodes - 1) / g_state.numa.n_nodes);
struct ggml_v3_numa_node * node = &g_state.numa.nodes[node_num];
size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
CPU_ZERO_S(setsize, cpus);
for (size_t i = 0; i < node->n_cpus; ++i) {
CPU_SET_S(node->cpus[i], setsize, cpus);
}
int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
if (rv) {
fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",
strerror(rv));
}
CPU_FREE(cpus);
}
static void clear_numa_thread_affinity(void) {
if (!ggml_v3_is_numa()) {
return;
}
size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
CPU_ZERO_S(setsize, cpus);
for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) {
CPU_SET_S(i, setsize, cpus);
}
int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
if (rv) {
fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",
strerror(rv));
}
CPU_FREE(cpus);
}
#else
// TODO: Windows etc.
// (the linux implementation may also work on BSD, someone should test)
static void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads); }
static void clear_numa_thread_affinity(void) {}
#endif
struct ggml_v3_compute_state_shared {
const struct ggml_v3_cgraph * cgraph;
const struct ggml_v3_cplan * cplan;
int64_t perf_node_start_cycles;
int64_t perf_node_start_time_us;
const int n_threads;
// synchronization primitives
atomic_int n_active; // num active threads
atomic_int node_n; // active graph node
bool (*abort_callback)(void * data); // abort ggml_v3_graph_compute when true
void * abort_callback_data;
};
struct ggml_v3_compute_state {
ggml_v3_thread_t thrd;
int ith;
struct ggml_v3_compute_state_shared * shared;
};
static void ggml_v3_graph_compute_perf_stats_node(struct ggml_v3_tensor * node, const struct ggml_v3_compute_state_shared * st) {
int64_t cycles_cur = ggml_v3_perf_cycles() - st->perf_node_start_cycles;
int64_t time_us_cur = ggml_v3_perf_time_us() - st->perf_node_start_time_us;
node->perf_runs++;
node->perf_cycles += cycles_cur;
node->perf_time_us += time_us_cur;
}
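// returns the number of tasks (threads) that should cooperate on computing a given node;
// cheap or inherently serial ops are assigned a single task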
static int ggml_v3_get_n_tasks(struct ggml_v3_tensor * node, int n_threads) {
int n_tasks = 0;
switch (node->op) {
case GGML_V3_OP_CPY:
case GGML_V3_OP_DUP:
case GGML_V3_OP_ADD:
case GGML_V3_OP_ADD1:
case GGML_V3_OP_ACC:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_SUB:
case GGML_V3_OP_SQR:
case GGML_V3_OP_SQRT:
case GGML_V3_OP_LOG:
case GGML_V3_OP_SUM:
case GGML_V3_OP_SUM_ROWS:
case GGML_V3_OP_MEAN:
case GGML_V3_OP_ARGMAX:
case GGML_V3_OP_REPEAT:
case GGML_V3_OP_REPEAT_BACK:
case GGML_V3_OP_LEAKY_RELU:
{
n_tasks = 1;
} break;
case GGML_V3_OP_UNARY:
switch (ggml_v3_get_unary_op(node)) {
case GGML_V3_UNARY_OP_ABS:
case GGML_V3_UNARY_OP_SGN:
case GGML_V3_UNARY_OP_NEG:
case GGML_V3_UNARY_OP_STEP:
case GGML_V3_UNARY_OP_TANH:
case GGML_V3_UNARY_OP_ELU:
case GGML_V3_UNARY_OP_RELU:
{
n_tasks = 1;
} break;
case GGML_V3_UNARY_OP_GELU:
case GGML_V3_UNARY_OP_GELU_QUICK:
case GGML_V3_UNARY_OP_SILU:
{
n_tasks = n_threads;
} break;
default:
GGML_V3_ASSERT(false);
}
break;
case GGML_V3_OP_SILU_BACK:
case GGML_V3_OP_MUL:
case GGML_V3_OP_DIV:
case GGML_V3_OP_NORM:
case GGML_V3_OP_RMS_NORM:
case GGML_V3_OP_RMS_NORM_BACK:
case GGML_V3_OP_GROUP_NORM:
case GGML_V3_OP_CONCAT:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_MUL_MAT:
{
n_tasks = n_threads;
// TODO: use different scheduling for different matrix sizes
//const int nr0 = ggml_v3_nrows(node->src[0]);
//const int nr1 = ggml_v3_nrows(node->src[1]);
//n_tasks = MIN(n_threads, MAX(1, nr0/128));
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
} break;
case GGML_V3_OP_MUL_MAT_ID:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_OUT_PROD:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_SCALE:
case GGML_V3_OP_SET:
case GGML_V3_OP_CONT:
case GGML_V3_OP_RESHAPE:
case GGML_V3_OP_VIEW:
case GGML_V3_OP_PERMUTE:
case GGML_V3_OP_TRANSPOSE:
case GGML_V3_OP_GET_ROWS:
case GGML_V3_OP_GET_ROWS_BACK:
case GGML_V3_OP_DIAG:
{
n_tasks = 1;
} break;
case GGML_V3_OP_DIAG_MASK_ZERO:
case GGML_V3_OP_DIAG_MASK_INF:
case GGML_V3_OP_SOFT_MAX_BACK:
case GGML_V3_OP_ROPE:
case GGML_V3_OP_ROPE_BACK:
case GGML_V3_OP_ADD_REL_POS:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_ALIBI:
{
n_tasks = 1; //TODO
} break;
case GGML_V3_OP_CLAMP:
{
n_tasks = 1; //TODO
} break;
case GGML_V3_OP_SOFT_MAX:
{
n_tasks = MIN(MIN(4, n_threads), ggml_v3_nrows(node->src[0]));
} break;
case GGML_V3_OP_CONV_TRANSPOSE_1D:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_IM2COL:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_CONV_TRANSPOSE_2D:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_POOL_1D:
case GGML_V3_OP_POOL_2D:
{
n_tasks = 1;
} break;
case GGML_V3_OP_UPSCALE:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_PAD:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_ARGSORT:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_FLASH_ATTN:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_FLASH_FF:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_FLASH_ATTN_BACK:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_WIN_PART:
case GGML_V3_OP_WIN_UNPART:
case GGML_V3_OP_GET_REL_POS:
case GGML_V3_OP_MAP_UNARY:
case GGML_V3_OP_MAP_BINARY:
case GGML_V3_OP_MAP_CUSTOM1_F32:
case GGML_V3_OP_MAP_CUSTOM2_F32:
case GGML_V3_OP_MAP_CUSTOM3_F32:
{
n_tasks = 1;
} break;
case GGML_V3_OP_MAP_CUSTOM1:
{
struct ggml_v3_map_custom1_op_params * p = (struct ggml_v3_map_custom1_op_params *) node->op_params;
if (p->n_tasks == GGML_V3_N_TASKS_MAX) {
n_tasks = n_threads;
} else {
n_tasks = MIN(p->n_tasks, n_threads);
}
} break;
case GGML_V3_OP_MAP_CUSTOM2:
{
struct ggml_v3_map_custom2_op_params * p = (struct ggml_v3_map_custom2_op_params *) node->op_params;
if (p->n_tasks == GGML_V3_N_TASKS_MAX) {
n_tasks = n_threads;
} else {
n_tasks = MIN(p->n_tasks, n_threads);
}
} break;
case GGML_V3_OP_MAP_CUSTOM3:
{
struct ggml_v3_map_custom3_op_params * p = (struct ggml_v3_map_custom3_op_params *) node->op_params;
if (p->n_tasks == GGML_V3_N_TASKS_MAX) {
n_tasks = n_threads;
} else {
n_tasks = MIN(p->n_tasks, n_threads);
}
} break;
case GGML_V3_OP_CROSS_ENTROPY_LOSS:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_CROSS_ENTROPY_LOSS_BACK:
{
n_tasks = n_threads;
} break;
case GGML_V3_OP_NONE:
{
n_tasks = 1;
} break;
case GGML_V3_OP_COUNT:
{
GGML_V3_ASSERT(false);
} break;
default:
{
fprintf(stderr, "%s: op not implemented: ", __func__);
if (node->op < GGML_V3_OP_COUNT) {
fprintf(stderr, "%s\n", ggml_v3_op_name(node->op));
} else {
fprintf(stderr, "%d\n", node->op);
}
GGML_V3_ASSERT(false);
} break;
}
assert(n_tasks > 0);
return n_tasks;
}
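// worker entry point for graph computation:
// - the last thread to decrement n_active becomes the coordinator: it runs FINALIZE for the
//   previous node and INIT for the following ones, executing single-task nodes directly,
//   then publishes the new node_n and resets n_active
// - the remaining threads spin on node_n (optionally yielding) and join the COMPUTE phase
//   when their thread index is below the node's task count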
static thread_ret_t ggml_v3_graph_compute_thread(void * data) {
struct ggml_v3_compute_state * state = (struct ggml_v3_compute_state *) data;
const struct ggml_v3_cgraph * cgraph = state->shared->cgraph;
const struct ggml_v3_cplan * cplan = state->shared->cplan;
const int n_threads = state->shared->n_threads;
set_numa_thread_affinity(state->ith, n_threads);
int node_n = -1;
while (true) {
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
state->shared->node_n += 1;
return (thread_ret_t) GGML_V3_EXIT_ABORTED;
}
if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
// all other threads are finished and spinning
// do finalize and init here so we don't have to synchronize again
struct ggml_v3_compute_params params = {
/*.type =*/ GGML_V3_TASK_FINALIZE,
/*.ith =*/ 0,
/*.nth =*/ 0,
/*.wsize =*/ cplan->work_size,
/*.wdata =*/ cplan->work_data,
};
if (node_n != -1) {
/* FINALIZE */
struct ggml_v3_tensor * node = cgraph->nodes[node_n];
if (GGML_V3_OP_HAS_FINALIZE[node->op]) {
params.nth = ggml_v3_get_n_tasks(node, n_threads);
ggml_v3_compute_forward(&params, node);
}
ggml_v3_graph_compute_perf_stats_node(node, state->shared);
}
// distribute new work, or execute it directly if it only needs a single task
while (++node_n < cgraph->n_nodes) {
GGML_V3_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
struct ggml_v3_tensor * node = cgraph->nodes[node_n];
const int n_tasks = ggml_v3_get_n_tasks(node, n_threads);
state->shared->perf_node_start_cycles = ggml_v3_perf_cycles();
state->shared->perf_node_start_time_us = ggml_v3_perf_time_us();
params.nth = n_tasks;
/* INIT */
if (GGML_V3_OP_HAS_INIT[node->op]) {
params.type = GGML_V3_TASK_INIT;
ggml_v3_compute_forward(&params, node);
}
if (n_tasks == 1) {
// TODO: maybe still push node_n to the atomic, but have the other threads
//       do something more efficient than spinning when they see n_tasks == 1 (?)
params.type = GGML_V3_TASK_COMPUTE;
ggml_v3_compute_forward(&params, node);
if (GGML_V3_OP_HAS_FINALIZE[node->op]) {
params.type = GGML_V3_TASK_FINALIZE;
ggml_v3_compute_forward(&params, node);
}
ggml_v3_graph_compute_perf_stats_node(node, state->shared);
} else {
break;
}
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
break;
}
}
atomic_store(&state->shared->n_active, n_threads);
atomic_store(&state->shared->node_n, node_n);
} else {
// wait for other threads to finish
const int last = node_n;
const bool do_yield = last < 0 || cgraph->nodes[last]->op == GGML_V3_OP_MUL_MAT;
while (true) {
// TODO: this sched_yield can have a significant impact on performance - either positive or negative,
//       depending on the workload and the operating system.
//       since it is not clear what the best approach is, it should potentially become user-configurable
// ref: https://github.com/ggerganov/ggml/issues/291
// UPD: adding the do_yield flag seems to resolve the issue universally
if (do_yield) {
sched_yield();
}
node_n = atomic_load(&state->shared->node_n);
if (node_n != last) break;
}
}
// check if we should stop
if (node_n >= cgraph->n_nodes) break;
/* COMPUTE */
struct ggml_v3_tensor * node = cgraph->nodes[node_n];
const int n_tasks = ggml_v3_get_n_tasks(node, n_threads);
struct ggml_v3_compute_params params = {
/*.type =*/ GGML_V3_TASK_COMPUTE,
/*.ith =*/ state->ith,
/*.nth =*/ n_tasks,
/*.wsize =*/ cplan->work_size,
/*.wdata =*/ cplan->work_data,
};
if (state->ith < n_tasks) {
ggml_v3_compute_forward(&params, node);
}
}
return GGML_V3_EXIT_SUCCESS;
}
struct ggml_v3_cplan ggml_v3_graph_plan(struct ggml_v3_cgraph * cgraph, int n_threads) {
if (n_threads <= 0) {
n_threads = GGML_V3_DEFAULT_N_THREADS;
}
size_t work_size = 0;
struct ggml_v3_cplan cplan;
memset(&cplan, 0, sizeof(struct ggml_v3_cplan));
// thread scheduling for the different operations + work buffer size estimation
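// the work buffer is shared: only the maximum requirement across all nodes is kept, and it is
// padded below by one cache line per additional thread (presumably so that per-thread regions
// do not share cache lines)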
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_v3_tensor * node = cgraph->nodes[i];
const int n_tasks = ggml_v3_get_n_tasks(node, n_threads);
size_t cur = 0;
switch (node->op) {
case GGML_V3_OP_CPY:
case GGML_V3_OP_DUP:
{
if (ggml_v3_is_quantized(node->type)) {
cur = ggml_v3_type_size(GGML_V3_TYPE_F32) * node->ne[0] * n_tasks;
}
} break;
case GGML_V3_OP_ADD:
case GGML_V3_OP_ADD1:
{
if (ggml_v3_is_quantized(node->src[0]->type)) {
cur = ggml_v3_type_size(GGML_V3_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
}
} break;
case GGML_V3_OP_ACC:
{
if (ggml_v3_is_quantized(node->src[0]->type)) {
cur = ggml_v3_type_size(GGML_V3_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
}
} break;
case GGML_V3_OP_MUL_MAT:
{
const enum ggml_v3_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
#if defined(GGML_USE_CLBLAST)
if (ggml_v3_cl_can_mul_mat(node->src[0], node->src[1], node)) {
cur = ggml_v3_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
} else
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_v3_compute_forward_mul_mat_use_blas(node)) {
if (node->src[0]->type != GGML_V3_TYPE_F32) {
// here we need memory just for a single 2D matrix from src0
cur = ggml_v3_type_size(GGML_V3_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]);
}
} else
#endif
if (node->src[1]->type != vec_dot_type) {
cur = ggml_v3_row_size(vec_dot_type, ggml_v3_nelements(node->src[1]));
}
} break;
case GGML_V3_OP_MUL_MAT_ID:
{
const struct ggml_v3_tensor * src0 = node->src[2];
const struct ggml_v3_tensor * src1 = node->src[1];
const enum ggml_v3_type vec_dot_type = type_traits[src0->type].vec_dot_type;
if (src1->type != vec_dot_type) {
cur = ggml_v3_row_size(vec_dot_type, ggml_v3_nelements(src1));
}
const int n_as = ggml_v3_get_op_params_i32(node, 1);
cur = GGML_V3_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
} break;
case GGML_V3_OP_OUT_PROD:
{
if (ggml_v3_is_quantized(node->src[0]->type)) {
cur = ggml_v3_type_size(GGML_V3_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
}
} break;
case GGML_V3_OP_SOFT_MAX:
{
cur = ggml_v3_type_size(GGML_V3_TYPE_F32) * node->ne[0] * n_tasks;
} break;
case GGML_V3_OP_CONV_TRANSPOSE_1D:
{
GGML_V3_ASSERT(node->src[0]->ne[3] == 1);
GGML_V3_ASSERT(node->src[1]->ne[2] == 1);
GGML_V3_ASSERT(node->src[1]->ne[3] == 1);
const int64_t ne00 = node->src[0]->ne[0]; // K
const int64_t ne01 = node->src[0]->ne[1]; // Cout
const int64_t ne02 = node->src[0]->ne[2]; // Cin
const int64_t ne10 = node->src[1]->ne[0]; // L
const int64_t ne11 = node->src[1]->ne[1]; // Cin
if (node->src[0]->type == GGML_V3_TYPE_F16 &&
node->src[1]->type == GGML_V3_TYPE_F32) {
cur += sizeof(ggml_v3_fp16_t)*ne00*ne01*ne02;
cur += sizeof(ggml_v3_fp16_t)*ne10*ne11;
} else if (node->src[0]->type == GGML_V3_TYPE_F32 &&
node->src[1]->type == GGML_V3_TYPE_F32) {
cur += sizeof(float)*ne00*ne01*ne02;
cur += sizeof(float)*ne10*ne11;
} else {
GGML_V3_ASSERT(false);
}
} break;
case GGML_V3_OP_CONV_TRANSPOSE_2D:
{
const int64_t ne00 = node->src[0]->ne[0]; // W
const int64_t ne01 = node->src[0]->ne[1]; // H
const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
const int64_t ne03 = node->src[0]->ne[3]; // Channels In
const int64_t ne10 = node->src[1]->ne[0]; // W
const int64_t ne11 = node->src[1]->ne[1]; // H
const int64_t ne12 = node->src[1]->ne[2]; // Channels In
cur += sizeof(ggml_v3_fp16_t)*ne00*ne01*ne02*ne03;
cur += sizeof(ggml_v3_fp16_t)*ne10*ne11*ne12;
} break;
case GGML_V3_OP_FLASH_ATTN:
{
const int64_t ne11 = ggml_v3_up(node->src[1]->ne[1], GGML_V3_SOFT_MAX_UNROLL);
if (node->src[1]->type == GGML_V3_TYPE_F32) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
} else if (node->src[1]->type == GGML_V3_TYPE_F16) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}
} break;
case GGML_V3_OP_FLASH_FF:
{
if (node->src[1]->type == GGML_V3_TYPE_F32) {
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
} else if (node->src[1]->type == GGML_V3_TYPE_F16) {
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
}
} break;
case GGML_V3_OP_FLASH_ATTN_BACK:
{
const int64_t D = node->src[0]->ne[0];
const int64_t ne11 = ggml_v3_up(node->src[1]->ne[1], GGML_V3_SOFT_MAX_UNROLL);
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_v3_compute_forward_flash_attn_back
if (node->src[1]->type == GGML_V3_TYPE_F32) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
} else if (node->src[1]->type == GGML_V3_TYPE_F16) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}
} break;
case GGML_V3_OP_CROSS_ENTROPY_LOSS:
{
cur = ggml_v3_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
} break;
case GGML_V3_OP_COUNT:
{
GGML_V3_ASSERT(false);
} break;
default:
break;
}
work_size = MAX(work_size, cur);
}
if (work_size > 0) {
work_size += CACHE_LINE_SIZE*(n_threads - 1);
}
cplan.n_threads = n_threads;
cplan.work_size = work_size;
cplan.work_data = NULL;
return cplan;
}
int ggml_v3_graph_compute(struct ggml_v3_cgraph * cgraph, struct ggml_v3_cplan * cplan) {
{
GGML_V3_ASSERT(cplan);
GGML_V3_ASSERT(cplan->n_threads > 0);
if (cplan->work_size > 0) {
GGML_V3_ASSERT(cplan->work_data);
}
}
const int n_threads = cplan->n_threads;
struct ggml_v3_compute_state_shared state_shared = {
/*.cgraph =*/ cgraph,
/*.cplan =*/ cplan,
/*.perf_node_start_cycles =*/ 0,
/*.perf_node_start_time_us =*/ 0,
/*.n_threads =*/ n_threads,
/*.n_active =*/ n_threads,
/*.node_n =*/ -1,
/*.abort_callback =*/ NULL,
/*.abort_callback_data =*/ NULL,
};
struct ggml_v3_compute_state * workers = alloca(sizeof(struct ggml_v3_compute_state)*n_threads);
// create thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; ++j) {
workers[j] = (struct ggml_v3_compute_state) {
.thrd = 0,
.ith = j,
.shared = &state_shared,
};
const int rc = ggml_v3_thread_create(&workers[j].thrd, NULL, ggml_v3_graph_compute_thread, &workers[j]);
GGML_V3_ASSERT(rc == 0);
UNUSED(rc);
}
}
workers[0].ith = 0;
workers[0].shared = &state_shared;
const int64_t perf_start_cycles = ggml_v3_perf_cycles();
const int64_t perf_start_time_us = ggml_v3_perf_time_us();
// this is a work thread too
int compute_status = (size_t) ggml_v3_graph_compute_thread(&workers[0]);
// don't leave affinity set on the main thread
clear_numa_thread_affinity();
// join or kill thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) {
const int rc = ggml_v3_thread_join(workers[j].thrd, NULL);
GGML_V3_ASSERT(rc == 0);
}
}
// performance stats (graph)
{
int64_t perf_cycles_cur = ggml_v3_perf_cycles() - perf_start_cycles;
int64_t perf_time_us_cur = ggml_v3_perf_time_us() - perf_start_time_us;
cgraph->perf_runs++;
cgraph->perf_cycles += perf_cycles_cur;
cgraph->perf_time_us += perf_time_us_cur;
GGML_V3_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n",
__func__, cgraph->perf_runs,
(double) perf_cycles_cur / (double) ggml_v3_cycles_per_ms(),
(double) cgraph->perf_cycles / (double) ggml_v3_cycles_per_ms() / (double) cgraph->perf_runs,
(double) perf_time_us_cur / 1000.0,
(double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
}
return compute_status;
}
void ggml_v3_graph_compute_with_ctx(struct ggml_v3_context * ctx, struct ggml_v3_cgraph * cgraph, int n_threads) {
struct ggml_v3_cplan cplan = ggml_v3_graph_plan(cgraph, n_threads);
struct ggml_v3_object * obj = ggml_v3_new_object(ctx, GGML_V3_OBJECT_WORK_BUFFER, cplan.work_size);
cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
ggml_v3_graph_compute(cgraph, &cplan);
}
struct ggml_v3_tensor * ggml_v3_graph_get_tensor(struct ggml_v3_cgraph * cgraph, const char * name) {
for (int i = 0; i < cgraph->n_leafs; i++) {
struct ggml_v3_tensor * leaf = cgraph->leafs[i];
if (strcmp(leaf->name, name) == 0) {
return leaf;
}
}
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_v3_tensor * node = cgraph->nodes[i];
if (strcmp(node->name, name) == 0) {
return node;
}
}
return NULL;
}
static void ggml_v3_graph_export_leaf(const struct ggml_v3_tensor * tensor, FILE * fout) {
const int64_t * ne = tensor->ne;
const size_t * nb = tensor->nb;
fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
ggml_v3_type_name(tensor->type),
ggml_v3_op_name (tensor->op),
ggml_v3_n_dims(tensor),
ne[0], ne[1], ne[2], ne[3],
nb[0], nb[1], nb[2], nb[3],
tensor->data,
tensor->name);
}
static void ggml_v3_graph_export_node(const struct ggml_v3_tensor * tensor, const char * arg, FILE * fout) {
const int64_t * ne = tensor->ne;
const size_t * nb = tensor->nb;
fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
arg,
ggml_v3_type_name(tensor->type),
ggml_v3_op_name (tensor->op),
ggml_v3_n_dims(tensor),
ne[0], ne[1], ne[2], ne[3],
nb[0], nb[1], nb[2], nb[3],
tensor->data,
tensor->name);
}
void ggml_v3_graph_export(const struct ggml_v3_cgraph * cgraph, const char * fname) {
uint64_t size_eval = 0;
// compute size of intermediate results
// TODO: does not take into account scratch buffers !!!!
for (int i = 0; i < cgraph->n_nodes; ++i) {
size_eval += ggml_v3_nbytes_pad(cgraph->nodes[i]);
}
// print
{
FILE * fout = stdout;
fprintf(fout, "\n");
fprintf(fout, "%-16s %8x\n", "magic", GGML_V3_FILE_MAGIC);
fprintf(fout, "%-16s %8d\n", "version", GGML_V3_FILE_VERSION);
fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs);
fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes);
fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval);
// header
fprintf(fout, "\n");
fprintf(fout, "%-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %16s %16s\n",
"TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "DATA", "NAME");
for (int i = 0; i < cgraph->n_leafs; ++i) {
ggml_v3_graph_export_leaf(cgraph->leafs[i], fout);
GGML_V3_ASSERT(cgraph->leafs[i]->op == GGML_V3_OP_NONE);
GGML_V3_ASSERT(cgraph->leafs[i]->src[0] == NULL);
GGML_V3_ASSERT(cgraph->leafs[i]->src[1] == NULL);
}
// header
fprintf(fout, "\n");
fprintf(fout, "%-6s %-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %8s %16s %16s\n",
"ARG", "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "NTASKS", "DATA", "NAME");
for (int i = 0; i < cgraph->n_nodes; ++i) {
ggml_v3_graph_export_node(cgraph->nodes[i], "DST", fout);
for (int j = 0; j < GGML_V3_MAX_SRC; ++j) {
if (cgraph->nodes[i]->src[j]) {
ggml_v3_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout);
}
}
fprintf(fout, "\n");
}
fprintf(fout, "\n");
}
// write binary data
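// binary layout: magic, version, n_leafs, n_nodes, size_eval, followed by every leaf and
// node serialized as type, op, ne[4]/nb[4], name and op_params (leafs additionally dump
// their raw data, nodes their GGML_V3_MAX_SRC source indices, with -1 for unused slots)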
{
FILE * fout = fopen(fname, "wb");
if (!fout) {
fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
return;
}
// header
{
const uint32_t magic = GGML_V3_FILE_MAGIC;
const uint32_t version = GGML_V3_FILE_VERSION;
const uint32_t n_leafs = cgraph->n_leafs;
const uint32_t n_nodes = cgraph->n_nodes;
fwrite(&magic, sizeof(uint32_t), 1, fout);
fwrite(&version, sizeof(uint32_t), 1, fout);
fwrite(&n_leafs, sizeof(uint32_t), 1, fout);
fwrite(&n_nodes, sizeof(uint32_t), 1, fout);
fwrite(&size_eval, sizeof(uint64_t), 1, fout);
}
// leafs
{
for (int i = 0; i < cgraph->n_leafs; ++i) {
const struct ggml_v3_tensor * tensor = cgraph->leafs[i];
const uint32_t type = tensor->type;
const uint32_t op = tensor->op;
fwrite(&type, sizeof(uint32_t), 1, fout);
fwrite(&op, sizeof(uint32_t), 1, fout);
for (int j = 0; j < GGML_V3_MAX_DIMS; ++j) {
const uint64_t ne = tensor->ne[j];
const uint64_t nb = tensor->nb[j];
fwrite(&ne, sizeof(uint64_t), 1, fout);
fwrite(&nb, sizeof(uint64_t), 1, fout);
}
fwrite(tensor->name, sizeof(char), GGML_V3_MAX_NAME, fout);
fwrite(tensor->op_params, sizeof(char), GGML_V3_MAX_OP_PARAMS, fout);
// dump the data
// TODO: pad this to 32 byte boundary
{
const size_t size = ggml_v3_nbytes(tensor);
fwrite(tensor->data, sizeof(char), size, fout);
}
}
}
// nodes
{
for (int i = 0; i < cgraph->n_nodes; ++i) {
const struct ggml_v3_tensor * tensor = cgraph->nodes[i];
const uint32_t type = tensor->type;
const uint32_t op = tensor->op;
fwrite(&type, sizeof(uint32_t), 1, fout);
fwrite(&op, sizeof(uint32_t), 1, fout);
for (int j = 0; j < GGML_V3_MAX_DIMS; ++j) {
const uint64_t ne = tensor->ne[j];
const uint64_t nb = tensor->nb[j];
fwrite(&ne, sizeof(uint64_t), 1, fout);
fwrite(&nb, sizeof(uint64_t), 1, fout);
}
fwrite(tensor->name, sizeof(char), GGML_V3_MAX_NAME, fout);
fwrite(tensor->op_params, sizeof(char), GGML_V3_MAX_OP_PARAMS, fout);
// output the op arguments
{
struct ggml_v3_tensor * args[GGML_V3_MAX_SRC] = { NULL };
for (int j = 0; j < GGML_V3_MAX_SRC; ++j) {
args[j] = tensor->src[j];
}
for (int j = 0; j < GGML_V3_MAX_SRC; ++j) {
if (args[j]) {
int32_t idx = -1;
// check if leaf
{
for (int k = 0; k < cgraph->n_leafs; ++k) {
if (args[j] == cgraph->leafs[k]) {
idx = k;
break;
}
}
}
// check if node
if (idx == -1) {
for (int k = 0; k < cgraph->n_nodes; ++k) {
if (args[j] == cgraph->nodes[k]) {
idx = cgraph->n_leafs + k;
break;
}
}
}
if (idx == -1) {
fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i);
fclose(fout);
return;
}
fwrite(&idx, sizeof(int32_t), 1, fout);
} else {
const int32_t nul = -1;
fwrite(&nul, sizeof(int32_t), 1, fout);
}
}
}
}
}
fclose(fout);
}
}
struct ggml_v3_cgraph * ggml_v3_graph_import(const char * fname, struct ggml_v3_context ** ctx_data, struct ggml_v3_context ** ctx_eval) {
assert(*ctx_data == NULL);
assert(*ctx_eval == NULL);
struct ggml_v3_cgraph * result = NULL;
struct ggml_v3_tensor * data = NULL;
// read file into data
{
FILE * fin = fopen(fname, "rb");
if (!fin) {
fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
return result;
}
size_t fsize = 0;
fseek(fin, 0, SEEK_END);
fsize = ftell(fin);
fseek(fin, 0, SEEK_SET);
// create the data context
{
const size_t overhead = 1*ggml_v3_tensor_overhead();
struct ggml_v3_init_params params = {
.mem_size = fsize + overhead,
.mem_buffer = NULL,
.no_alloc = false,
};
*ctx_data = ggml_v3_init(params);
if (!*ctx_data) {
fprintf(stderr, "%s: failed to create ggml context\n", __func__);
fclose(fin);
return result;
}
}
data = ggml_v3_new_tensor_1d(*ctx_data, GGML_V3_TYPE_I8, fsize);
{
const size_t ret = fread(data->data, sizeof(char), fsize, fin);
if (ret != fsize) {
fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
fclose(fin);
return result;
}
}
fclose(fin);
}
// populate result
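// the parse mirrors ggml_v3_graph_export: read the header, then the leafs (whose data is
// referenced in place inside the loaded file buffer), then the nodes, resolving src
// pointers from the stored leaf/node indices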
{
char * ptr = (char *) data->data;
const uint32_t magic = *(const uint32_t *) ptr; ptr += sizeof(magic);
if (magic != GGML_V3_FILE_MAGIC) {
fprintf(stderr, "%s: invalid magic number, got %08x\n", __func__, magic);
return result;
}
const uint32_t version = *(const uint32_t *) ptr; ptr += sizeof(version);
if (version != GGML_V3_FILE_VERSION) {
fprintf(stderr, "%s: invalid version number\n", __func__);
return result;
}
const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs);
const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes);
const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval);
const int graph_size = MAX(n_leafs, n_nodes);
// create the data context
{
const size_t overhead = (n_leafs + n_nodes)*ggml_v3_tensor_overhead() + ggml_v3_graph_overhead_custom(graph_size, false);
struct ggml_v3_init_params params = {
.mem_size = size_eval + overhead,
.mem_buffer = NULL,
.no_alloc = true,
};
*ctx_eval = ggml_v3_init(params);
if (!*ctx_eval) {
fprintf(stderr, "%s: failed to create ggml context\n", __func__);
return result;
}
}
result = ggml_v3_new_graph_custom(*ctx_eval, graph_size, false);
result->n_leafs = n_leafs;
result->n_nodes = n_nodes;
// leafs
{
uint32_t type;
uint32_t op;
for (uint32_t i = 0; i < n_leafs; ++i) {
type = *(const uint32_t *) ptr; ptr += sizeof(type);
op = *(const uint32_t *) ptr; ptr += sizeof(op);
int64_t ne[GGML_V3_MAX_DIMS];
size_t nb[GGML_V3_MAX_DIMS];
for (int j = 0; j < GGML_V3_MAX_DIMS; ++j) {
uint64_t ne_cur;
uint64_t nb_cur;
ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur);
nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur);
ne[j] = ne_cur;
nb[j] = nb_cur;
}
struct ggml_v3_tensor * tensor = ggml_v3_new_tensor(*ctx_eval, (enum ggml_v3_type) type, GGML_V3_MAX_DIMS, ne);
tensor->op = (enum ggml_v3_op) op;
memcpy(tensor->name, ptr, GGML_V3_MAX_NAME); ptr += GGML_V3_MAX_NAME;
memcpy(tensor->op_params, ptr, GGML_V3_MAX_OP_PARAMS); ptr += GGML_V3_MAX_OP_PARAMS;
tensor->data = (void *) ptr;
for (int j = 0; j < GGML_V3_MAX_DIMS; ++j) {
tensor->nb[j] = nb[j];
}
result->leafs[i] = tensor;
ptr += ggml_v3_nbytes(tensor);
fprintf(stderr, "%s: loaded leaf %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_v3_nbytes(tensor));
}
}
ggml_v3_set_no_alloc(*ctx_eval, false);
// nodes
{
uint32_t type;
uint32_t op;
for (uint32_t i = 0; i < n_nodes; ++i) {
type = *(const uint32_t *) ptr; ptr += sizeof(type);
op = *(const uint32_t *) ptr; ptr += sizeof(op);
enum ggml_v3_op eop = (enum ggml_v3_op) op;
int64_t ne[GGML_V3_MAX_DIMS];
size_t nb[GGML_V3_MAX_DIMS];
for (int j = 0; j < GGML_V3_MAX_DIMS; ++j) {
uint64_t ne_cur;
uint64_t nb_cur;
ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur);
nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur);
ne[j] = ne_cur;
nb[j] = nb_cur;
}
const char * ptr_name = ptr; ptr += GGML_V3_MAX_NAME;
const char * ptr_op_params = ptr; ptr += GGML_V3_MAX_OP_PARAMS;
const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_V3_MAX_SRC*sizeof(int32_t);
struct ggml_v3_tensor * args[GGML_V3_MAX_SRC] = { NULL };
// parse args
for (int j = 0; j < GGML_V3_MAX_SRC; ++j) {
const int32_t arg_idx = ptr_arg_idx[j];
if (arg_idx == -1) {
continue;
}
if (arg_idx < result->n_leafs) {
args[j] = result->leafs[arg_idx];
} else {
args[j] = result->nodes[arg_idx - result->n_leafs];
}
}
// create the tensor
// "view" operations are handled differently
// TODO: handle inplace ops - currently a copy is always made
struct ggml_v3_tensor * tensor = NULL;
switch (eop) {
// TODO: implement other view ops
case GGML_V3_OP_RESHAPE:
{
tensor = ggml_v3_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
} break;
case GGML_V3_OP_VIEW:
{
tensor = ggml_v3_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
size_t offs;
memcpy(&offs, ptr_op_params, sizeof(offs));
tensor->data = ((char *) tensor->data) + offs;
} break;
case GGML_V3_OP_TRANSPOSE:
{
tensor = ggml_v3_transpose(*ctx_eval, args[0]);
} break;
case GGML_V3_OP_PERMUTE:
{
tensor = ggml_v3_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
} break;
default:
{
tensor = ggml_v3_new_tensor(*ctx_eval, (enum ggml_v3_type) type, GGML_V3_MAX_DIMS, ne);
tensor->op = eop;
} break;
}
memcpy(tensor->name, ptr_name, GGML_V3_MAX_NAME);
memcpy(tensor->op_params, ptr_op_params, GGML_V3_MAX_OP_PARAMS);
for (int j = 0; j < GGML_V3_MAX_DIMS; ++j) {
tensor->nb[j] = nb[j];
}
for (int j = 0; j < GGML_V3_MAX_SRC; ++j) {
tensor->src[j] = args[j];
}
result->nodes[i] = tensor;
fprintf(stderr, "%s: loaded node %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_v3_nbytes(tensor));
}
}
}
return result;
}
void ggml_v3_graph_print(const struct ggml_v3_cgraph * cgraph) {
int64_t perf_total_per_op_us[GGML_V3_OP_COUNT] = {0};
GGML_V3_PRINT("=== GRAPH ===\n");
GGML_V3_PRINT("n_nodes = %d\n", cgraph->n_nodes);
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_v3_tensor * node = cgraph->nodes[i];
perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
GGML_V3_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
i,
node->ne[0], node->ne[1], node->ne[2],
ggml_v3_op_name(node->op), node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
(double) node->perf_cycles / (double) ggml_v3_cycles_per_ms(),
(double) node->perf_cycles / (double) ggml_v3_cycles_per_ms() / (double) node->perf_runs,
(double) node->perf_time_us / 1000.0,
(double) node->perf_time_us / 1000.0 / node->perf_runs);
}
GGML_V3_PRINT("n_leafs = %d\n", cgraph->n_leafs);
for (int i = 0; i < cgraph->n_leafs; i++) {
struct ggml_v3_tensor * node = cgraph->leafs[i];
GGML_V3_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
i,
node->ne[0], node->ne[1],
ggml_v3_op_name(node->op),
ggml_v3_get_name(node));
}
for (int i = 0; i < GGML_V3_OP_COUNT; i++) {
if (perf_total_per_op_us[i] == 0) {
continue;
}
GGML_V3_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", ggml_v3_op_name(i), (double) perf_total_per_op_us[i] / 1000.0);
}
GGML_V3_PRINT("========================================\n");
}
// check if node is part of the graph
static bool ggml_v3_graph_find(const struct ggml_v3_cgraph * cgraph, const struct ggml_v3_tensor * node) {
if (cgraph == NULL) {
return true;
}
for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i] == node) {
return true;
}
}
return false;
}
static struct ggml_v3_tensor * ggml_v3_graph_get_parent(const struct ggml_v3_cgraph * cgraph, const struct ggml_v3_tensor * node) {
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_v3_tensor * parent = cgraph->nodes[i];
if (parent->grad == node) {
return parent;
}
}
return NULL;
}
static void ggml_v3_graph_dump_dot_node_edge(FILE * fp, const struct ggml_v3_cgraph * gb, struct ggml_v3_tensor * node, struct ggml_v3_tensor * parent, const char * label) {
struct ggml_v3_tensor * gparent = ggml_v3_graph_get_parent(gb, node);
struct ggml_v3_tensor * gparent0 = ggml_v3_graph_get_parent(gb, parent);
fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
gparent0 ? (void *) gparent0 : (void *) parent,
gparent0 ? "g" : "x",
gparent ? (void *) gparent : (void *) node,
gparent ? "g" : "x",
gparent ? "empty" : "vee",
gparent ? "dashed" : "solid",
label);
}
static void ggml_v3_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_v3_tensor * node, struct ggml_v3_tensor * parent, const char * label) {
fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n",
(void *) parent, "x",
(void *) node, "x",
label);
}
void ggml_v3_graph_dump_dot(const struct ggml_v3_cgraph * gb, const struct ggml_v3_cgraph * gf, const char * filename) {
char color[16];
FILE * fp = fopen(filename, "w");
GGML_V3_ASSERT(fp);
fprintf(fp, "digraph G {\n");
fprintf(fp, " newrank = true;\n");
fprintf(fp, " rankdir = LR;\n");
for (int i = 0; i < gb->n_nodes; i++) {
struct ggml_v3_tensor * node = gb->nodes[i];
if (ggml_v3_graph_get_parent(gb, node) != NULL) {
continue;
}
if (node->is_param) {
snprintf(color, sizeof(color), "yellow");
} else if (node->grad) {
if (ggml_v3_graph_find(gf, node)) {
snprintf(color, sizeof(color), "green");
} else {
snprintf(color, sizeof(color), "lightblue");
}
} else {
snprintf(color, sizeof(color), "white");
}
fprintf(fp, " \"%p\" [ "
"style = filled; fillcolor = %s; shape = record; "
"label=\"",
(void *) node, color);
if (strlen(node->name) > 0) {
fprintf(fp, "%s (%s)|", node->name, ggml_v3_type_name(node->type));
} else {
fprintf(fp, "(%s)|", ggml_v3_type_name(node->type));
}
if (ggml_v3_is_matrix(node)) {
fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], ggml_v3_op_symbol(node->op));
} else {
fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], ggml_v3_op_symbol(node->op));
}
if (node->grad) {
fprintf(fp, " | <g>%s\"; ]\n", ggml_v3_op_symbol(node->grad->op));
} else {
fprintf(fp, "\"; ]\n");
}
}
for (int i = 0; i < gb->n_leafs; i++) {
struct ggml_v3_tensor * node = gb->leafs[i];
snprintf(color, sizeof(color), "pink");
fprintf(fp, " \"%p\" [ "
"style = filled; fillcolor = %s; shape = record; "
"label=\"<x>",
(void *) node, color);
if (strlen(node->name) > 0) {
fprintf(fp, "%s (%s)|", node->name, ggml_v3_type_name(node->type));
} else {
fprintf(fp, "(%s)|", ggml_v3_type_name(node->type));
}
fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
if (ggml_v3_nelements(node) < 5) {
fprintf(fp, " | (");
for (int j = 0; j < ggml_v3_nelements(node); j++) {
if (node->type == GGML_V3_TYPE_I8 || node->type == GGML_V3_TYPE_I16 || node->type == GGML_V3_TYPE_I32) {
fprintf(fp, "%d", ggml_v3_get_i32_1d(node, j));
}
else if (node->type == GGML_V3_TYPE_F32 || node->type == GGML_V3_TYPE_F16) {
fprintf(fp, "%.1e", (double)ggml_v3_get_f32_1d(node, j));
}
else {
fprintf(fp, "#");
}
if (j < ggml_v3_nelements(node) - 1) {
fprintf(fp, ", ");
}
}
fprintf(fp, ")");
}
fprintf(fp, "\"; ]\n");
}
for (int i = 0; i < gb->n_nodes; i++) {
struct ggml_v3_tensor * node = gb->nodes[i];
for (int j = 0; j < GGML_V3_MAX_SRC; j++) {
if (node->src[j]) {
char label[16];
snprintf(label, sizeof(label), "src %d", j);
ggml_v3_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);
}
}
}
for (int i = 0; i < gb->n_leafs; i++) {
struct ggml_v3_tensor * node = gb->leafs[i];
for (int j = 0; j < GGML_V3_MAX_SRC; j++) {
if (node->src[j]) {
char label[16];
snprintf(label, sizeof(label), "src %d", j);
ggml_v3_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);
}
}
}
fprintf(fp, "}\n");
fclose(fp);
GGML_V3_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
}
////////////////////////////////////////////////////////////////////////////////
static void ggml_v3_opt_set_params(int np, struct ggml_v3_tensor * const ps[], const float * x) {
int i = 0;
for (int p = 0; p < np; ++p) {
const int64_t ne = ggml_v3_nelements(ps[p]);
// TODO: add function to set tensor from array
for (int64_t j = 0; j < ne; ++j) {
ggml_v3_set_f32_1d(ps[p], j, x[i++]);
}
}
}
static void ggml_v3_opt_get_params(int np, struct ggml_v3_tensor * const ps[], float * x) {
int i = 0;
for (int p = 0; p < np; ++p) {
const int64_t ne = ggml_v3_nelements(ps[p]);
// TODO: add function to get all elements at once
for (int64_t j = 0; j < ne; ++j) {
x[i++] = ggml_v3_get_f32_1d(ps[p], j);
}
}
}
static void ggml_v3_opt_get_grad(int np, struct ggml_v3_tensor * const ps[], float * g) {
int64_t i = 0;
for (int p = 0; p < np; ++p) {
const int64_t ne = ggml_v3_nelements(ps[p]);
// TODO: add function to get all elements at once
for (int64_t j = 0; j < ne; ++j) {
g[i++] = ggml_v3_get_f32_1d(ps[p]->grad, j);
}
}
}
static void ggml_v3_opt_acc_grad(int np, struct ggml_v3_tensor * const ps[], float * g, float scale) {
int64_t i = 0;
for (int p = 0; p < np; ++p) {
const int64_t ne = ggml_v3_nelements(ps[p]);
// TODO: add function to get all elements at once
for (int64_t j = 0; j < ne; ++j) {
g[i++] += ggml_v3_get_f32_1d(ps[p]->grad, j) * scale;
}
}
}
//
// Using AdamW - ref: https://arxiv.org/pdf/1711.05101v3.pdf
//
// (Original Adam - ref: https://arxiv.org/pdf/1412.6980.pdf)
//
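// gradients are accumulated over n_gradient_accumulation forward/backward passes and
// averaged (accum_norm), optionally clipped by global norm (gclip), and then applied
// with the decoupled weight decay update below
//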
static enum ggml_v3_opt_result ggml_v3_opt_adam(
struct ggml_v3_context * ctx,
struct ggml_v3_opt_context * opt,
struct ggml_v3_opt_params params,
struct ggml_v3_tensor * f,
struct ggml_v3_cgraph * gf,
struct ggml_v3_cgraph * gb,
ggml_v3_opt_callback callback,
void * callback_data) {
GGML_V3_ASSERT(ggml_v3_is_scalar(f));
// these will store the parameters we want to optimize
struct ggml_v3_tensor * ps[GGML_V3_MAX_PARAMS];
int np = 0;
int64_t nx = 0;
for (int i = 0; i < gf->n_nodes; ++i) {
if (gf->nodes[i]->is_param) {
GGML_V3_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
GGML_V3_ASSERT(np < GGML_V3_MAX_PARAMS);
ps[np++] = gf->nodes[i];
nx += ggml_v3_nelements(gf->nodes[i]);
}
}
if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
int iter = opt->iter;
ggml_v3_opt_init(opt->ctx, opt, params, nx);
opt->iter = iter;
}
// constants
float sched = params.adam.sched;
const float alpha = params.adam.alpha;
const float decay = params.adam.decay * alpha;
const float beta1 = params.adam.beta1;
const float beta2 = params.adam.beta2;
const float eps = params.adam.eps;
const float gclip = params.adam.gclip;
const int decay_min_ndim = params.adam.decay_min_ndim;
const int n_accum = MAX(1, params.n_gradient_accumulation);
const float accum_norm = 1.0f / (float) n_accum;
float * g = opt->adam.g->data; // gradients
float * m = opt->adam.m->data; // first moment
float * v = opt->adam.v->data; // second moment
float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
struct ggml_v3_cplan cplan = ggml_v3_graph_plan(gb, params.n_threads);
struct ggml_v3_object * obj = ggml_v3_new_object(ctx, GGML_V3_OBJECT_WORK_BUFFER, cplan.work_size);
cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
bool cancel = false;
// compute the function value
float fx = 0;
ggml_v3_set_zero(opt->adam.g);
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
if (callback) {
callback(callback_data, accum_step, &sched, &cancel);
if (cancel) {
return GGML_V3_OPT_CANCEL;
}
}
// ggml_v3_graph_reset (gf);
ggml_v3_set_f32 (f->grad, 1.0f);
ggml_v3_graph_compute(gb, &cplan);
ggml_v3_opt_acc_grad(np, ps, g, accum_norm);
fx += ggml_v3_get_f32_1d(f, 0);
}
fx *= accum_norm;
opt->adam.fx_prev = fx;
opt->adam.fx_best = opt->adam.fx_prev;
if (pf) {
pf[opt->iter % params.past] = opt->adam.fx_prev;
}
opt->loss_before = opt->adam.fx_prev;
opt->loss_after = opt->adam.fx_prev;
// initialize
if (opt->just_initialized) {
opt->adam.n_no_improvement = 0;
opt->just_initialized = false;
}
float * fx_best = &opt->adam.fx_best;
float * fx_prev = &opt->adam.fx_prev;
int * n_no_improvement = &opt->adam.n_no_improvement;
int iter0 = opt->iter;
// run the optimizer
for (int t = 0; t < params.adam.n_iter; ++t) {
opt->iter = iter0 + t + 1;
GGML_V3_PRINT_DEBUG ("=== iter %d ===\n", t);
GGML_V3_PRINT_DEBUG ("f = %10.6f\n", ggml_v3_get_f32_1d(f, 0));
GGML_V3_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_v3_get_f32_1d(ps[0]->grad, 0));
GGML_V3_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_v3_get_f32_1d(ps[1]->grad, 0));
for (int i = 0; i < np; ++i) {
GGML_V3_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i,
ggml_v3_get_f32_1d(ps[i], 0), ggml_v3_get_f32_1d(ps[i]->grad, 0));
}
const int64_t t_start_wall = ggml_v3_time_us();
const int64_t t_start_cpu = ggml_v3_cycles();
UNUSED(t_start_wall);
UNUSED(t_start_cpu);
{
float gnorm = 1.0f;
if (gclip > 0.0f) {
// gradient clipping
ggml_v3_float sum = 0.0;
for (int64_t i = 0; i < nx; ++i) {
sum += (ggml_v3_float)(g[i]*g[i]);
}
ggml_v3_float norm = sqrt(sum);
if (norm > (ggml_v3_float) gclip) {
gnorm = (float) ((ggml_v3_float) gclip / norm);
}
}
const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter));
const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter));
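// beta1h folds the (scheduled) learning rate and the bias correction for the first moment,
// beta2h is the bias correction for the second moment, so the update below is
// x <- x*(1 - p_decay) - alpha*sched * m_hat / (sqrt(v_hat) + eps)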
int64_t i = 0;
for (int p = 0; p < np; ++p) {
const int64_t ne = ggml_v3_nelements(ps[p]);
const float p_decay = ((ggml_v3_n_dims(ps[p]) >= decay_min_ndim) ? decay : 0.0f) * sched;
for (int64_t j = 0; j < ne; ++j) {
float x = ggml_v3_get_f32_1d(ps[p], j);
float g_ = g[i]*gnorm;
m[i] = m[i]*beta1 + g_*(1.0f - beta1);
v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
float mh = m[i]*beta1h;
float vh = v[i]*beta2h;
vh = sqrtf(vh) + eps;
x = x*(1.0f - p_decay) - mh/vh;
ggml_v3_set_f32_1d(ps[p], j, x);
++i;
}
}
}
fx = 0;
ggml_v3_set_zero(opt->adam.g);
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
if (callback) {
callback(callback_data, accum_step, &sched, &cancel);
if (cancel) {
return GGML_V3_OPT_CANCEL;
}
}
// ggml_v3_graph_reset (gf);
ggml_v3_set_f32 (f->grad, 1.0f);
ggml_v3_graph_compute(gb, &cplan);
ggml_v3_opt_acc_grad(np, ps, g, accum_norm);
fx += ggml_v3_get_f32_1d(f, 0);
}
fx *= accum_norm;
opt->loss_after = fx;
// check convergence
if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
GGML_V3_PRINT_DEBUG("converged\n");
return GGML_V3_OPT_OK;
}
// delta-based convergence test
if (pf != NULL) {
// need at least params.past iterations to start checking for convergence
if (params.past <= iter0 + t) {
const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
if (fabsf(rate) < params.delta) {
return GGML_V3_OPT_OK;
}
}
pf[(iter0 + t)%params.past] = fx;
}
// check for improvement
if (params.max_no_improvement > 0) {
if (fx_best[0] > fx) {
fx_best[0] = fx;
n_no_improvement[0] = 0;
} else {
++n_no_improvement[0];
if (n_no_improvement[0] >= params.max_no_improvement) {
return GGML_V3_OPT_OK;
}
}
}
fx_prev[0] = fx;
{
const int64_t t_end_cpu = ggml_v3_cycles();
GGML_V3_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC);
UNUSED(t_end_cpu);
const int64_t t_end_wall = ggml_v3_time_us();
GGML_V3_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6);
UNUSED(t_end_wall);
}
}
return GGML_V3_OPT_DID_NOT_CONVERGE;
}
//
// L-BFGS
//
// the L-BFGS implementation below is based on the following implementation:
//
// https://github.com/chokkan/liblbfgs
//
struct ggml_v3_lbfgs_iteration_data {
float alpha;
float ys;
float * s;
float * y;
};
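// backtracking line search along direction d starting from xp: the step is shrunk or grown
// until the Armijo condition (and, depending on params->lbfgs.linesearch, the regular or
// strong Wolfe curvature condition) holds; returns the number of evaluations on success,
// a negative GGML_V3_LINESEARCH_* code on failure, or GGML_V3_OPT_CANCEL if the callback
// requested cancellation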
static enum ggml_v3_opt_result linesearch_backtracking(
const struct ggml_v3_opt_params * params,
int nx,
float * x,
float * fx,
float * g,
float * d,
float * step,
const float * xp,
struct ggml_v3_tensor * f,
struct ggml_v3_cgraph * gb,
struct ggml_v3_cplan * cplan,
const int np,
struct ggml_v3_tensor * ps[],
bool * cancel,
ggml_v3_opt_callback callback,
void * callback_data) {
int count = 0;
float width = 0.0f;
float dg = 0.0f;
float finit = 0.0f;
float dginit = 0.0f;
float dgtest = 0.0f;
const float dec = 0.5f;
const float inc = 2.1f;
const int n_accum = MAX(1, params->n_gradient_accumulation);
const float accum_norm = 1.0f / (float) n_accum;
if (*step <= 0.f) {
return GGML_V3_LINESEARCH_INVALID_PARAMETERS;
}
// compute the initial gradient in the search direction
ggml_v3_vec_dot_f32(nx, &dginit, g, d);
// make sure that d points to a descent direction
if (0 < dginit) {
return GGML_V3_LINESEARCH_FAIL;
}
// initialize local variables
finit = *fx;
dgtest = params->lbfgs.ftol*dginit;
while (true) {
ggml_v3_vec_cpy_f32(nx, x, xp);
ggml_v3_vec_mad_f32(nx, x, d, *step);
// evaluate the function and gradient values
{
ggml_v3_opt_set_params(np, ps, x);
*fx = 0;
memset(g, 0, sizeof(float)*nx);
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
if (callback) {
// L-BFGS does not support a learning rate -> ignore the learning schedule
float sched = 0;
callback(callback_data, accum_step, &sched, cancel);
if (*cancel) {
return GGML_V3_OPT_CANCEL;
}
}
// ggml_v3_graph_reset (gf);
ggml_v3_set_f32 (f->grad, 1.0f);
ggml_v3_graph_compute(gb, cplan);
ggml_v3_opt_acc_grad(np, ps, g, accum_norm);
*fx += ggml_v3_get_f32_1d(f, 0);
}
*fx *= accum_norm;
}
++count;
if (*fx > finit + (*step)*dgtest) {
width = dec;
} else {
// Armijo condition is satisfied
if (params->lbfgs.linesearch == GGML_V3_LINESEARCH_BACKTRACKING_ARMIJO) {
return count;
}
ggml_v3_vec_dot_f32(nx, &dg, g, d);
// check the Wolfe condition
if (dg < params->lbfgs.wolfe * dginit) {
width = inc;
} else {
if(params->lbfgs.linesearch == GGML_V3_LINESEARCH_BACKTRACKING_WOLFE) {
// regular Wolfe conditions
return count;
}
if(dg > -params->lbfgs.wolfe*dginit) {
width = dec;
} else {
// strong Wolfe condition (GGML_V3_LINESEARCH_BACKTRACKING_STRONG_WOLFE)
return count;
}
}
}
if (*step < params->lbfgs.min_step) {
return GGML_V3_LINESEARCH_MINIMUM_STEP;
}
if (*step > params->lbfgs.max_step) {
return GGML_V3_LINESEARCH_MAXIMUM_STEP;
}
if (params->lbfgs.max_linesearch <= count) {
return GGML_V3_LINESEARCH_MAXIMUM_ITERATIONS;
}
(*step) *= width;
}
GGML_V3_UNREACHABLE();
}
static enum ggml_v3_opt_result ggml_v3_opt_lbfgs(
struct ggml_v3_context * ctx,
struct ggml_v3_opt_context * opt,
struct ggml_v3_opt_params params,
struct ggml_v3_tensor * f,
struct ggml_v3_cgraph * gf,
struct ggml_v3_cgraph * gb,
ggml_v3_opt_callback callback,
void * callback_data) {
if (params.lbfgs.linesearch == GGML_V3_LINESEARCH_BACKTRACKING_WOLFE ||
params.lbfgs.linesearch == GGML_V3_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) {
return GGML_V3_OPT_INVALID_WOLFE;
}
}
const int m = params.lbfgs.m;
// these will store the parameters we want to optimize
struct ggml_v3_tensor * ps[GGML_V3_MAX_PARAMS];
int np = 0;
int nx = 0;
for (int i = 0; i < gf->n_nodes; ++i) {
if (gf->nodes[i]->is_param) {
GGML_V3_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
GGML_V3_ASSERT(np < GGML_V3_MAX_PARAMS);
ps[np++] = gf->nodes[i];
nx += ggml_v3_nelements(gf->nodes[i]);
}
}
if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
int iter = opt->iter;
ggml_v3_opt_init(ctx, opt, params, nx);
opt->iter = iter;
}
struct ggml_v3_cplan cplan = ggml_v3_graph_plan(gb, params.n_threads);
struct ggml_v3_object * obj = ggml_v3_new_object(ctx, GGML_V3_OBJECT_WORK_BUFFER, cplan.work_size);
cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
float * x = opt->lbfgs.x->data; // current parameters
float * xp = opt->lbfgs.xp->data; // previous parameters
float * g = opt->lbfgs.g->data; // current gradient
float * gp = opt->lbfgs.gp->data; // previous gradient
float * d = opt->lbfgs.d->data; // search direction
float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
const int n_accum = MAX(1, params.n_gradient_accumulation);
const float accum_norm = 1.0f / (float) n_accum;
float fx = 0.0f; // cost function value
float xnorm = 0.0f; // ||x||
float gnorm = 0.0f; // ||g||
// initialize x from the graph nodes
ggml_v3_opt_get_params(np, ps, x);
// the L-BFGS memory
float * lm_alpha = opt->lbfgs.lmal->data;
float * lm_ys = opt->lbfgs.lmys->data;
float * lm_s = opt->lbfgs.lms->data;
float * lm_y = opt->lbfgs.lmy->data;
bool cancel = false;
// evaluate the function value and its gradient
{
ggml_v3_opt_set_params(np, ps, x);
fx = 0;
memset(g, 0, sizeof(float)*nx);
for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
if (callback) {
// L-BFGS does not support a learning rate -> ignore the learning schedule
float sched = 0;
callback(callback_data, accum_step, &sched, &cancel);
if (cancel) {
return GGML_V3_OPT_CANCEL;
}
}
// ggml_v3_graph_reset (gf);
ggml_v3_set_f32 (f->grad, 1.0f);
ggml_v3_graph_compute(gb, &cplan);
ggml_v3_opt_acc_grad(np, ps, g, accum_norm);
fx += ggml_v3_get_f32_1d(f, 0);
}
fx *= accum_norm;
opt->loss_before = fx;
opt->loss_after = fx;
}
// search direction = -gradient
ggml_v3_vec_neg_f32(nx, d, g);
// ||x||, ||g||
ggml_v3_vec_norm_f32(nx, &xnorm, x);
ggml_v3_vec_norm_f32(nx, &gnorm, g);
if (xnorm < 1.0f) {
xnorm = 1.0f;
}
// already optimized
if (gnorm/xnorm <= params.lbfgs.eps) {
return GGML_V3_OPT_OK;
}
if (opt->just_initialized) {
if (pf) {
pf[0] = fx;
}
opt->lbfgs.fx_best = fx;
// initial step
ggml_v3_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
opt->lbfgs.j = 0;
opt->lbfgs.k = 1;
opt->lbfgs.end = 0;
opt->lbfgs.n_no_improvement = 0;
opt->just_initialized = false;
}
float * fx_best = &opt->lbfgs.fx_best;
float * step = &opt->lbfgs.step;
int * j = &opt->lbfgs.j;
int * k = &opt->lbfgs.k;
int * end = &opt->lbfgs.end;
int * n_no_improvement = &opt->lbfgs.n_no_improvement;
int ls = 0;
int bound = 0;
float ys = 0.0f;
float yy = 0.0f;
float beta = 0.0f;
int it = 0;
while (true) {
// store the current position and gradient vectors
ggml_v3_vec_cpy_f32(nx, xp, x);
ggml_v3_vec_cpy_f32(nx, gp, g);
// TODO: instead of passing &cancel here, use the return code of the linesearch
// to determine if the optimization should be cancelled
//       this is a simple change, but it is not done for now, since I don't have a nice
//       way to test it and don't want to break something with so many changes lined up
ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
if (cancel) {
return GGML_V3_OPT_CANCEL;
}
if (ls < 0) {
// linesearch failed - go back to the previous point and return
ggml_v3_vec_cpy_f32(nx, x, xp);
ggml_v3_vec_cpy_f32(nx, g, gp);
return ls;
}
opt->loss_after = fx;
ggml_v3_vec_norm_f32(nx, &xnorm, x);
ggml_v3_vec_norm_f32(nx, &gnorm, g);
GGML_V3_PRINT_DEBUG("f = %10.6f\n", ggml_v3_get_f32_1d(f, 0));
if (xnorm < 1.0f) {
xnorm = 1.0f;
}
if (gnorm/xnorm <= params.lbfgs.eps) {
// converged
return GGML_V3_OPT_OK;
}
// delta-based convergence test
if (pf != NULL) {
// need at least params.past iterations to start checking for convergence
if (params.past <= k[0]) {
const float rate = (pf[k[0]%params.past] - fx)/fx;
if (fabsf(rate) < params.delta) {
return GGML_V3_OPT_OK;
}
}
pf[k[0]%params.past] = fx;
}
// check for improvement
if (params.max_no_improvement > 0) {
if (fx < fx_best[0]) {
fx_best[0] = fx;
n_no_improvement[0] = 0;
} else {
n_no_improvement[0]++;
if (n_no_improvement[0] >= params.max_no_improvement) {
return GGML_V3_OPT_OK;
}
}
}
if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
// reached the maximum number of iterations
return GGML_V3_OPT_DID_NOT_CONVERGE;
}
// update vectors s and y:
// s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
// y_{k+1} = g_{k+1} - g_{k}.
//
ggml_v3_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
ggml_v3_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
// compute scalars ys and yy:
// ys = y^t \cdot s -> 1 / \rho.
// yy = y^t \cdot y.
//
ggml_v3_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
ggml_v3_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
lm_ys[end[0]] = ys;
// find new search direction
// ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
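// two-loop recursion: the backward pass over the stored (s, y) pairs computes the alpha
// coefficients, the initial inverse Hessian is approximated by (ys/yy) * I, and the
// forward pass applies the correction terms to obtain the new direction d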
bound = (m <= k[0]) ? m : k[0];
k[0]++;
it++;
end[0] = (end[0] + 1)%m;
// initialize search direction with -g
ggml_v3_vec_neg_f32(nx, d, g);
j[0] = end[0];
for (int i = 0; i < bound; ++i) {
j[0] = (j[0] + m - 1) % m;
// \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
ggml_v3_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
lm_alpha[j[0]] /= lm_ys[j[0]];
// q_{i} = q_{i+1} - \alpha_{i} y_{i}
ggml_v3_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
}
ggml_v3_vec_scale_f32(nx, d, ys/yy);
for (int i = 0; i < bound; ++i) {
// \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
ggml_v3_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
beta /= lm_ys[j[0]];
// \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
ggml_v3_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
j[0] = (j[0] + 1)%m;
}
step[0] = 1.0;
}
GGML_V3_UNREACHABLE();
}
struct ggml_v3_opt_params ggml_v3_opt_default_params(enum ggml_v3_opt_type type) {
struct ggml_v3_opt_params result;
switch (type) {
case GGML_V3_OPT_ADAM:
{
result = (struct ggml_v3_opt_params) {
.type = GGML_V3_OPT_ADAM,
.graph_size = GGML_V3_DEFAULT_GRAPH_SIZE,
.n_threads = 1, // FIXME: GGML_V3_DEFAULT_N_THREADS ?
.past = 0,
.delta = 1e-5f,
.max_no_improvement = 100,
.print_forward_graph = true,
.print_backward_graph = true,
.n_gradient_accumulation = 1,
.adam = {
.n_iter = 10000,
.sched = 1.000f,
.decay = 0.0f,
.decay_min_ndim = 2,
.alpha = 0.001f,
.beta1 = 0.9f,
.beta2 = 0.999f,
.eps = 1e-8f,
.eps_f = 1e-5f,
.eps_g = 1e-3f,
.gclip = 0.0f,
},
};
} break;
case GGML_V3_OPT_LBFGS:
{
result = (struct ggml_v3_opt_params) {
.type = GGML_V3_OPT_LBFGS,
.graph_size = GGML_V3_DEFAULT_GRAPH_SIZE,
.n_threads = 1,
.past = 0,
.delta = 1e-5f,
.max_no_improvement = 0,
.print_forward_graph = true,
.print_backward_graph = true,
.n_gradient_accumulation = 1,
.lbfgs = {
.m = 6,
.n_iter = 100,
.max_linesearch = 20,
.eps = 1e-5f,
.ftol = 1e-4f,
.wolfe = 0.9f,
.min_step = 1e-20f,
.max_step = 1e+20f,
.linesearch = GGML_V3_LINESEARCH_DEFAULT,
},
};
} break;
}
return result;
}
GGML_V3_API void ggml_v3_opt_init(
struct ggml_v3_context * ctx,
struct ggml_v3_opt_context * opt,
struct ggml_v3_opt_params params,
int64_t nx) {
opt->ctx = ctx;
opt->params = params;
opt->iter = 0;
opt->nx = nx;
opt->just_initialized = true;
if (opt->ctx == NULL) {
struct ggml_v3_init_params ctx_opt_params;
if (opt->params.type == GGML_V3_OPT_ADAM) {
ctx_opt_params.mem_size = GGML_V3_MEM_ALIGN*3 + ggml_v3_tensor_overhead()*3 + ggml_v3_type_size(GGML_V3_TYPE_F32)*nx*3;
if (opt->params.past > 0) {
ctx_opt_params.mem_size += GGML_V3_MEM_ALIGN + ggml_v3_tensor_overhead() + ggml_v3_type_size(GGML_V3_TYPE_F32)*opt->params.past;
}
} else if (opt->params.type == GGML_V3_OPT_LBFGS) {
ctx_opt_params.mem_size = GGML_V3_MEM_ALIGN*9 + ggml_v3_tensor_overhead()*9 + ggml_v3_type_size(GGML_V3_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
if (opt->params.past > 0) {
ctx_opt_params.mem_size += GGML_V3_MEM_ALIGN + ggml_v3_tensor_overhead() + ggml_v3_type_size(GGML_V3_TYPE_F32)*opt->params.past;
}
}
ctx_opt_params.mem_buffer = NULL;
ctx_opt_params.no_alloc = false;
opt->ctx = ggml_v3_init(ctx_opt_params);
}
switch (opt->params.type) {
case GGML_V3_OPT_ADAM:
{
opt->adam.g = ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, nx);
opt->adam.m = ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, nx);
opt->adam.v = ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, nx);
opt->adam.pf = params.past > 0
? ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, params.past)
: NULL;
ggml_v3_set_zero(opt->adam.m);
ggml_v3_set_zero(opt->adam.v);
if (opt->adam.pf) {
ggml_v3_set_zero(opt->adam.pf);
}
} break;
case GGML_V3_OPT_LBFGS:
{
opt->lbfgs.x = ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, nx);
opt->lbfgs.xp = ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, nx);
opt->lbfgs.g = ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, nx);
opt->lbfgs.gp = ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, nx);
opt->lbfgs.d = ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, nx);
opt->lbfgs.pf = params.past > 0
? ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, params.past)
: NULL;
opt->lbfgs.lmal = ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, params.lbfgs.m);
opt->lbfgs.lmys = ggml_v3_new_tensor_1d(opt->ctx, GGML_V3_TYPE_F32, params.lbfgs.m);
opt->lbfgs.lms = ggml_v3_new_tensor_2d(opt->ctx, GGML_V3_TYPE_F32, nx, params.lbfgs.m);
opt->lbfgs.lmy = ggml_v3_new_tensor_2d(opt->ctx, GGML_V3_TYPE_F32, nx, params.lbfgs.m);
ggml_v3_set_zero(opt->lbfgs.x);
ggml_v3_set_zero(opt->lbfgs.xp);
ggml_v3_set_zero(opt->lbfgs.g);
ggml_v3_set_zero(opt->lbfgs.gp);
ggml_v3_set_zero(opt->lbfgs.d);
if (opt->lbfgs.pf) {
ggml_v3_set_zero(opt->lbfgs.pf);
}
ggml_v3_set_zero(opt->lbfgs.lmal);
ggml_v3_set_zero(opt->lbfgs.lmys);
ggml_v3_set_zero(opt->lbfgs.lms);
ggml_v3_set_zero(opt->lbfgs.lmy);
} break;
}
}
enum ggml_v3_opt_result ggml_v3_opt(
struct ggml_v3_context * ctx,
struct ggml_v3_opt_params params,
struct ggml_v3_tensor * f) {
bool free_ctx = false;
if (ctx == NULL) {
struct ggml_v3_init_params params_ctx = {
.mem_size = 16*1024*1024,
.mem_buffer = NULL,
.no_alloc = false,
};
ctx = ggml_v3_init(params_ctx);
if (ctx == NULL) {
return GGML_V3_OPT_NO_CONTEXT;
}
free_ctx = true;
}
enum ggml_v3_opt_result result = GGML_V3_OPT_OK;
struct ggml_v3_opt_context * opt = (struct ggml_v3_opt_context *) alloca(sizeof(struct ggml_v3_opt_context));
ggml_v3_opt_init(ctx, opt, params, 0);
result = ggml_v3_opt_resume(ctx, opt, f);
if (free_ctx) {
ggml_v3_free(ctx);
}
return result;
}
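// Example usage (sketch, not taken verbatim from any caller): minimize a
// scalar loss tensor `f` that was built in `ctx`; the iteration override is
// illustrative, only the entry points below come from this file.
//
//     struct ggml_v3_opt_params params = ggml_v3_opt_default_params(GGML_V3_OPT_ADAM);
//     params.adam.n_iter = 500;                    // hypothetical override
//     enum ggml_v3_opt_result res = ggml_v3_opt(ctx, params, f);
//     if (res != GGML_V3_OPT_OK) {
//         // handle failure (e.g. GGML_V3_OPT_NO_CONTEXT)
//     }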
enum ggml_v3_opt_result ggml_v3_opt_resume(
struct ggml_v3_context * ctx,
struct ggml_v3_opt_context * opt,
struct ggml_v3_tensor * f) {
// build forward + backward compute graphs
struct ggml_v3_cgraph * gf = ggml_v3_new_graph_custom(ctx, opt->params.graph_size, true);
ggml_v3_build_forward_expand(gf, f);
struct ggml_v3_cgraph * gb = ggml_v3_graph_dup(ctx, gf);
ggml_v3_build_backward_expand(ctx, gf, gb, true);
return ggml_v3_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
}
enum ggml_v3_opt_result ggml_v3_opt_resume_g(
struct ggml_v3_context * ctx,
struct ggml_v3_opt_context * opt,
struct ggml_v3_tensor * f,
struct ggml_v3_cgraph * gf,
struct ggml_v3_cgraph * gb,
ggml_v3_opt_callback callback,
void * callback_data) {
// the forward and backward compute graphs are provided by the caller
enum ggml_v3_opt_result result = GGML_V3_OPT_OK;
switch (opt->params.type) {
case GGML_V3_OPT_ADAM:
{
result = ggml_v3_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
} break;
case GGML_V3_OPT_LBFGS:
{
result = ggml_v3_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
} break;
}
if (opt->params.print_forward_graph) {
ggml_v3_graph_print (gf);
ggml_v3_graph_dump_dot(gf, NULL, "opt-forward.dot");
}
if (opt->params.print_backward_graph) {
ggml_v3_graph_print (gb);
ggml_v3_graph_dump_dot(gb, gf, "opt-backward.dot");
}
return result;
}
////////////////////////////////////////////////////////////////////////////////
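// The ggml_v3_quantize_qX_Y helpers below quantize n floats from src into dst
// in chunks of k elements (each chunk goes through the corresponding
// quantize_row_*_reference routine) and additionally accumulate a 16-bin
// histogram of the produced quantized values in hist. The return value is the
// number of bytes written to dst.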
size_t ggml_v3_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK4_0 == 0);
const int nb = k / QK4_0;
for (int b = 0; b < n; b += k) {
block_q4_0 * restrict y = (block_q4_0 *) dst + b/QK4_0;
quantize_row_q4_0_reference(src + b, y, k);
for (int i = 0; i < nb; i++) {
for (int j = 0; j < QK4_0; j += 2) {
const uint8_t vi0 = y[i].qs[j/2] & 0x0F;
const uint8_t vi1 = y[i].qs[j/2] >> 4;
hist[vi0]++;
hist[vi1]++;
}
}
}
return (n/QK4_0*sizeof(block_q4_0));
}
size_t ggml_v3_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK4_1 == 0);
const int nb = k / QK4_1;
for (int b = 0; b < n; b += k) {
block_q4_1 * restrict y = (block_q4_1 *) dst + b/QK4_1;
quantize_row_q4_1_reference(src + b, y, k);
for (int i = 0; i < nb; i++) {
for (int j = 0; j < QK4_1; j += 2) {
const uint8_t vi0 = y[i].qs[j/2] & 0x0F;
const uint8_t vi1 = y[i].qs[j/2] >> 4;
hist[vi0]++;
hist[vi1]++;
}
}
}
return (n/QK4_1*sizeof(block_q4_1));
}
size_t ggml_v3_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK5_0 == 0);
const int nb = k / QK5_0;
for (int b = 0; b < n; b += k) {
block_q5_0 * restrict y = (block_q5_0 *)dst + b/QK5_0;
quantize_row_q5_0_reference(src + b, y, k);
for (int i = 0; i < nb; i++) {
uint32_t qh;
memcpy(&qh, &y[i].qh, sizeof(qh));
for (int j = 0; j < QK5_0; j += 2) {
const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
// map the 5-bit values to 16 histogram bins
const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2;
hist[vi0]++;
hist[vi1]++;
}
}
}
return (n/QK5_0*sizeof(block_q5_0));
}
size_t ggml_v3_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK5_1 == 0);
const int nb = k / QK5_1;
for (int b = 0; b < n; b += k) {
block_q5_1 * restrict y = (block_q5_1 *)dst + b/QK5_1;
quantize_row_q5_1_reference(src + b, y, k);
for (int i = 0; i < nb; i++) {
uint32_t qh;
memcpy(&qh, &y[i].qh, sizeof(qh));
for (int j = 0; j < QK5_1; j += 2) {
const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
// map the 5-bit values to 16 histogram bins
const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2;
hist[vi0]++;
hist[vi1]++;
}
}
}
return (n/QK5_1*sizeof(block_q5_1));
}
size_t ggml_v3_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
for (int b = 0; b < n; b += k) {
block_q8_0 * restrict y = (block_q8_0 *)dst + b/QK8_0;
quantize_row_q8_0_reference(src + b, y, k);
for (int i = 0; i < nb; i++) {
for (int j = 0; j < QK8_0; ++j) {
const int8_t vi = y[i].qs[j];
hist[vi/16 + 8]++;
}
}
}
return (n/QK8_0*sizeof(block_q8_0));
}
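// ggml_v3_quantize_chunk dispatches on the destination type: start and n are
// element offsets/counts into src, start must be a multiple of the block size
// of the chosen type, and the return value is the number of bytes produced.
// F16 and F32 are handled by plain conversion/copy rather than block
// quantization.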
size_t ggml_v3_quantize_chunk(enum ggml_v3_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
size_t result = 0;
switch (type) {
case GGML_V3_TYPE_Q4_0:
{
GGML_V3_ASSERT(start % QK4_0 == 0);
block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
result = ggml_v3_quantize_q4_0(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_Q4_1:
{
GGML_V3_ASSERT(start % QK4_1 == 0);
block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
result = ggml_v3_quantize_q4_1(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_Q5_0:
{
GGML_V3_ASSERT(start % QK5_0 == 0);
block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
result = ggml_v3_quantize_q5_0(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_Q5_1:
{
GGML_V3_ASSERT(start % QK5_1 == 0);
block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
result = ggml_v3_quantize_q5_1(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_Q8_0:
{
GGML_V3_ASSERT(start % QK8_0 == 0);
block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
result = ggml_v3_quantize_q8_0(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_Q2_K:
{
GGML_V3_ASSERT(start % QK_K == 0);
block_q2_K * block = (block_q2_K*)dst + start / QK_K;
result = ggml_v3_quantize_q2_K(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_Q3_K:
{
GGML_V3_ASSERT(start % QK_K == 0);
block_q3_K * block = (block_q3_K*)dst + start / QK_K;
result = ggml_v3_quantize_q3_K(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_Q4_K:
{
GGML_V3_ASSERT(start % QK_K == 0);
block_q4_K * block = (block_q4_K*)dst + start / QK_K;
result = ggml_v3_quantize_q4_K(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_Q5_K:
{
GGML_V3_ASSERT(start % QK_K == 0);
block_q5_K * block = (block_q5_K*)dst + start / QK_K;
result = ggml_v3_quantize_q5_K(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_Q6_K:
{
GGML_V3_ASSERT(start % QK_K == 0);
block_q6_K * block = (block_q6_K*)dst + start / QK_K;
result = ggml_v3_quantize_q6_K(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_IQ2_XXS:
{
GGML_V3_ASSERT(start % QK_K == 0);
block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
result = ggml_v3_quantize_iq2_xxs(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_IQ2_XS:
{
GGML_V3_ASSERT(start % QK_K == 0);
block_iq2_xs * block = (block_iq2_xs*)dst + start / QK_K;
result = ggml_v3_quantize_iq2_xs(src + start, block, n, n, hist);
} break;
case GGML_V3_TYPE_F16:
{
int elemsize = sizeof(ggml_v3_fp16_t);
ggml_v3_fp32_to_fp16_row(src + start, (ggml_v3_fp16_t *)dst + start, n);
result = n * elemsize;
} break;
case GGML_V3_TYPE_F32:
{
int elemsize = sizeof(float);
result = n * elemsize;
memcpy((uint8_t *)dst + start * elemsize, src + start, result);
} break;
default:
assert(false);
}
return result;
}
////////////////////////////////////////////////////////////////////////////////
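// GGUF container handling. The on-disk layout, as read and written by the
// routines in this section, is: a fixed header (magic, version, n_tensors,
// n_kv), followed by n_kv key/value pairs, then n_tensors tensor-info records,
// padding up to `alignment`, and finally the raw tensor data blob.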
struct gguf_v3_str {
uint64_t n; // GGUFv2
char * data;
};
static const size_t GGUF_V3_TYPE_SIZE[GGUF_V3_TYPE_COUNT] = {
[GGUF_V3_TYPE_UINT8] = sizeof(uint8_t),
[GGUF_V3_TYPE_INT8] = sizeof(int8_t),
[GGUF_V3_TYPE_UINT16] = sizeof(uint16_t),
[GGUF_V3_TYPE_INT16] = sizeof(int16_t),
[GGUF_V3_TYPE_UINT32] = sizeof(uint32_t),
[GGUF_V3_TYPE_INT32] = sizeof(int32_t),
[GGUF_V3_TYPE_FLOAT32] = sizeof(float),
[GGUF_V3_TYPE_BOOL] = sizeof(bool),
[GGUF_V3_TYPE_STRING] = sizeof(struct gguf_v3_str),
[GGUF_V3_TYPE_UINT64] = sizeof(uint64_t),
[GGUF_V3_TYPE_INT64] = sizeof(int64_t),
[GGUF_V3_TYPE_FLOAT64] = sizeof(double),
[GGUF_V3_TYPE_ARRAY] = 0, // undefined
};
static_assert(GGUF_V3_TYPE_COUNT == 13, "GGUF_V3_TYPE_COUNT != 13");
static const char * GGUF_V3_TYPE_NAME[GGUF_V3_TYPE_COUNT] = {
[GGUF_V3_TYPE_UINT8] = "u8",
[GGUF_V3_TYPE_INT8] = "i8",
[GGUF_V3_TYPE_UINT16] = "u16",
[GGUF_V3_TYPE_INT16] = "i16",
[GGUF_V3_TYPE_UINT32] = "u32",
[GGUF_V3_TYPE_INT32] = "i32",
[GGUF_V3_TYPE_FLOAT32] = "f32",
[GGUF_V3_TYPE_BOOL] = "bool",
[GGUF_V3_TYPE_STRING] = "str",
[GGUF_V3_TYPE_ARRAY] = "arr",
[GGUF_V3_TYPE_UINT64] = "u64",
[GGUF_V3_TYPE_INT64] = "i64",
[GGUF_V3_TYPE_FLOAT64] = "f64",
};
static_assert(GGUF_V3_TYPE_COUNT == 13, "GGUF_V3_TYPE_COUNT != 13");
union gguf_v3_value {
uint8_t uint8;
int8_t int8;
uint16_t uint16;
int16_t int16;
uint32_t uint32;
int32_t int32;
float float32;
uint64_t uint64;
int64_t int64;
double float64;
bool bool_;
struct gguf_v3_str str;
struct {
enum gguf_v3_type type;
uint64_t n; // GGUFv2
void * data;
} arr;
};
struct gguf_v3_kv {
struct gguf_v3_str key;
enum gguf_v3_type type;
union gguf_v3_value value;
};
struct gguf_v3_header {
char magic[4];
uint32_t version;
uint64_t n_tensors; // GGUFv2
uint64_t n_kv; // GGUFv2
};
struct gguf_v3_tensor_info {
struct gguf_v3_str name;
uint32_t n_dims;
uint64_t ne[GGML_V3_MAX_DIMS];
enum ggml_v3_type type;
uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT`
// for writing API
const void * data;
size_t size;
};
struct gguf_v3_context {
struct gguf_v3_header header;
struct gguf_v3_kv * kv;
struct gguf_v3_tensor_info * infos;
size_t alignment;
size_t offset; // offset of `data` from beginning of file
size_t size; // size of `data` in bytes
//uint8_t * padding;
void * data;
};
static bool gguf_v3_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
const size_t n = fread(dst, 1, size, file);
*offset += n;
return n == size;
}
// NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
static bool gguf_v3_fread_str_cur(FILE * file, struct gguf_v3_str * p, size_t * offset) {
p->n = 0;
p->data = NULL;
bool ok = true;
ok = ok && gguf_v3_fread_el(file, &p->n, sizeof(p->n), offset); p->data = calloc(p->n + 1, 1);
ok = ok && gguf_v3_fread_el(file, p->data, p->n, offset);
return ok;
}
static bool gguf_v3_fread_str_v1(FILE * file, struct gguf_v3_str * p, size_t * offset) {
p->n = 0;
p->data = NULL;
bool ok = true;
uint32_t n = 0;
ok = ok && gguf_v3_fread_el(file, &n, sizeof(n), offset); p->data = calloc(n + 1, 1); p->n = n;
ok = ok && gguf_v3_fread_el(file, p->data, p->n, offset);
return ok;
}
struct gguf_v3_context * gguf_v3_init_empty(void) {
struct gguf_v3_context * ctx = GGML_V3_ALIGNED_MALLOC(sizeof(struct gguf_v3_context));
memcpy(ctx->header.magic, GGUF_V3_MAGIC, sizeof(ctx->header.magic));
ctx->header.version = GGUF_V3_VERSION;
ctx->header.n_tensors = 0;
ctx->header.n_kv = 0;
ctx->kv = NULL;
ctx->infos = NULL;
ctx->alignment = GGUF_V3_DEFAULT_ALIGNMENT;
ctx->offset = 0;
ctx->size = 0;
ctx->data = NULL;
return ctx;
}
struct gguf_v3_context * gguf_v3_init_from_file(const char * fname, struct gguf_v3_init_params params) {
FILE * file = fopen(fname, "rb");
if (!file) {
return NULL;
}
// offset from start of file
size_t offset = 0;
char magic[4];
// check the magic before making allocations
{
gguf_v3_fread_el(file, &magic, sizeof(magic), &offset);
for (uint32_t i = 0; i < sizeof(magic); i++) {
if (magic[i] != GGUF_V3_MAGIC[i]) {
fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
fclose(file);
return NULL;
}
}
}
bool ok = true;
struct gguf_v3_context * ctx = GGML_V3_ALIGNED_MALLOC(sizeof(struct gguf_v3_context));
// read the header
{
strncpy(ctx->header.magic, magic, 4);
ctx->kv = NULL;
ctx->infos = NULL;
ctx->data = NULL;
ok = ok && gguf_v3_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset);
if (ctx->header.version == 1) {
// NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
uint32_t n_tensors = 0;
uint32_t n_kv = 0;
ok = ok && gguf_v3_fread_el(file, &n_tensors, sizeof(n_tensors), &offset);
ok = ok && gguf_v3_fread_el(file, &n_kv, sizeof(n_kv), &offset);
ctx->header.n_tensors = n_tensors;
ctx->header.n_kv = n_kv;
} else {
ok = ok && gguf_v3_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
ok = ok && gguf_v3_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset);
}
if (ctx->header.version == 1) {
fprintf(stderr, "%s: GGUFv1 is deprecated. please update if possible.\n", __func__);
}
// sanity-checks to prevent from integer/buffer overflows
ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct gguf_v3_tensor_info));
ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/ggml_v3_tensor_overhead());
ok = ok && (ctx->header.n_kv < (SIZE_MAX/2)/sizeof(struct gguf_v3_kv));
if (!ok) {
fprintf(stderr, "%s: failed to read header\n", __func__);
fclose(file);
gguf_v3_free(ctx);
return NULL;
}
}
// NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
bool (* gguf_v3_fread_str)(FILE *, struct gguf_v3_str *, size_t *) = gguf_v3_fread_str_cur;
if (ctx->header.version == 1) {
gguf_v3_fread_str = gguf_v3_fread_str_v1;
}
// read the kv pairs
{
ctx->kv = malloc(ctx->header.n_kv * sizeof(struct gguf_v3_kv));
for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
struct gguf_v3_kv * kv = &ctx->kv[i];
//fprintf(stderr, "%s: reading kv %d\n", __func__, i);
ok = ok && gguf_v3_fread_str(file, &kv->key, &offset);
ok = ok && gguf_v3_fread_el (file, &kv->type, sizeof(kv->type), &offset);
//fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data);
switch (kv->type) {
case GGUF_V3_TYPE_UINT8: ok = ok && gguf_v3_fread_el (file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); break;
case GGUF_V3_TYPE_INT8: ok = ok && gguf_v3_fread_el (file, &kv->value.int8, sizeof(kv->value.int8), &offset); break;
case GGUF_V3_TYPE_UINT16: ok = ok && gguf_v3_fread_el (file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); break;
case GGUF_V3_TYPE_INT16: ok = ok && gguf_v3_fread_el (file, &kv->value.int16, sizeof(kv->value.int16), &offset); break;
case GGUF_V3_TYPE_UINT32: ok = ok && gguf_v3_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break;
case GGUF_V3_TYPE_INT32: ok = ok && gguf_v3_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break;
case GGUF_V3_TYPE_FLOAT32: ok = ok && gguf_v3_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
case GGUF_V3_TYPE_UINT64: ok = ok && gguf_v3_fread_el (file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); break;
case GGUF_V3_TYPE_INT64: ok = ok && gguf_v3_fread_el (file, &kv->value.int64, sizeof(kv->value.int64), &offset); break;
case GGUF_V3_TYPE_FLOAT64: ok = ok && gguf_v3_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
case GGUF_V3_TYPE_BOOL: ok = ok && gguf_v3_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break;
case GGUF_V3_TYPE_STRING: ok = ok && gguf_v3_fread_str(file, &kv->value.str, &offset); break;
case GGUF_V3_TYPE_ARRAY:
{
ok = ok && gguf_v3_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
if (ctx->header.version == 1) {
// NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
uint32_t n = 0;
ok = ok && gguf_v3_fread_el(file, &n, sizeof(n), &offset);
kv->value.arr.n = n;
} else {
ok = ok && gguf_v3_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset);
}
switch (kv->value.arr.type) {
case GGUF_V3_TYPE_UINT8:
case GGUF_V3_TYPE_INT8:
case GGUF_V3_TYPE_UINT16:
case GGUF_V3_TYPE_INT16:
case GGUF_V3_TYPE_UINT32:
case GGUF_V3_TYPE_INT32:
case GGUF_V3_TYPE_FLOAT32:
case GGUF_V3_TYPE_UINT64:
case GGUF_V3_TYPE_INT64:
case GGUF_V3_TYPE_FLOAT64:
case GGUF_V3_TYPE_BOOL:
{
kv->value.arr.data = malloc(kv->value.arr.n * GGUF_V3_TYPE_SIZE[kv->value.arr.type]);
ok = ok && gguf_v3_fread_el(file, kv->value.arr.data, kv->value.arr.n * GGUF_V3_TYPE_SIZE[kv->value.arr.type], &offset);
} break;
case GGUF_V3_TYPE_STRING:
{
kv->value.arr.data = malloc(kv->value.arr.n * sizeof(struct gguf_v3_str));
for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
ok = ok && gguf_v3_fread_str(file, &((struct gguf_v3_str *) kv->value.arr.data)[j], &offset);
}
} break;
case GGUF_V3_TYPE_ARRAY:
case GGUF_V3_TYPE_COUNT: GGML_V3_ASSERT(false && "invalid type"); break;
}
} break;
case GGUF_V3_TYPE_COUNT: GGML_V3_ASSERT(false && "invalid type");
}
if (!ok) {
break;
}
}
if (!ok) {
fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
fclose(file);
gguf_v3_free(ctx);
return NULL;
}
}
// read the tensor infos
{
ctx->infos = malloc(ctx->header.n_tensors * sizeof(struct gguf_v3_tensor_info));
for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
struct gguf_v3_tensor_info * info = &ctx->infos[i];
for (int j = 0; j < GGML_V3_MAX_DIMS; ++j) {
info->ne[j] = 1;
}
ok = ok && gguf_v3_fread_str(file, &info->name, &offset);
ok = ok && gguf_v3_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset);
for (uint32_t j = 0; j < info->n_dims; ++j) {
if (ctx->header.version == 1) {
// NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
uint32_t t = 0;
ok = ok && gguf_v3_fread_el(file, &t, sizeof(t), &offset);
info->ne[j] = t;
} else {
ok = ok && gguf_v3_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
}
}
ok = ok && gguf_v3_fread_el (file, &info->type, sizeof(info->type), &offset);
ok = ok && gguf_v3_fread_el (file, &info->offset, sizeof(info->offset), &offset);
if (!ok) {
fprintf(stderr, "%s: failed to read tensor info\n", __func__);
fclose(file);
gguf_v3_free(ctx);
return NULL;
}
}
}
ctx->alignment = GGUF_V3_DEFAULT_ALIGNMENT;
int alignment_idx = gguf_v3_find_key(ctx, "general.alignment");
if (alignment_idx != -1) {
ctx->alignment = gguf_v3_get_val_u32(ctx, alignment_idx);
}
// we require the data section to be aligned, so take into account any padding
{
const size_t offset_pad = offset % ctx->alignment;
if (offset_pad != 0) {
offset += ctx->alignment - offset_pad;
fseek(file, offset, SEEK_SET);
}
}
// store the current file offset - this is where the data section starts
ctx->offset = offset;
// compute the total size of the data section, taking into account the alignment
{
ctx->size = 0;
for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
struct gguf_v3_tensor_info * info = &ctx->infos[i];
const int64_t ne =
(int64_t) info->ne[0] *
(int64_t) info->ne[1] *
(int64_t) info->ne[2] *
(int64_t) info->ne[3];
if (ne % ggml_v3_blck_size(info->type) != 0) {
fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n",
__func__, info->name.data, (int)info->type, ggml_v3_type_name(info->type), ne, ggml_v3_blck_size(info->type));
fclose(file);
gguf_v3_free(ctx);
return NULL;
}
const size_t size_cur = ggml_v3_row_size(info->type, ne);
ctx->size += GGML_V3_PAD(size_cur, ctx->alignment);
}
}
// load the tensor data only if requested
if (params.ctx != NULL) {
// if params.no_alloc is set, then we create "empty" tensors and do not read the binary blob
// otherwise, we load the binary blob into the created ggml_v3_context as well, and point the "data" members of
// the ggml_v3_tensor structs to the appropriate locations in the binary blob
// compute the exact size needed for the new ggml_v3_context
const size_t mem_size =
params.no_alloc ?
(ctx->header.n_tensors )*ggml_v3_tensor_overhead() :
(ctx->header.n_tensors + 1)*ggml_v3_tensor_overhead() + ctx->size;
struct ggml_v3_init_params pdata = {
.mem_size = mem_size,
.mem_buffer = NULL,
.no_alloc = params.no_alloc,
};
*params.ctx = ggml_v3_init(pdata);
struct ggml_v3_context * ctx_data = *params.ctx;
struct ggml_v3_tensor * data = NULL;
if (!params.no_alloc) {
data = ggml_v3_new_tensor_1d(ctx_data, GGML_V3_TYPE_I8, ctx->size);
ok = ok && data != NULL;
// read the binary blob with the tensor data
ok = ok && gguf_v3_fread_el(file, data->data, ctx->size, &offset);
if (!ok) {
fprintf(stderr, "%s: failed to read tensor data\n", __func__);
fclose(file);
ggml_v3_free(ctx_data);
gguf_v3_free(ctx);
return NULL;
}
ctx->data = data->data;
}
ggml_v3_set_no_alloc(ctx_data, true);
// create the tensors
for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
const int64_t ne[GGML_V3_MAX_DIMS] = {
ctx->infos[i].ne[0],
ctx->infos[i].ne[1],
ctx->infos[i].ne[2],
ctx->infos[i].ne[3],
};
struct ggml_v3_tensor * cur = ggml_v3_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne);
ok = ok && cur != NULL;
ggml_v3_set_name(cur, ctx->infos[i].name.data);
if (!ok) {
break;
}
// point the data member to the appropriate location in the binary blob using the tensor infos
if (!params.no_alloc) {
//cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data
}
}
if (!ok) {
fprintf(stderr, "%s: failed to read the tensor data\n", __func__);
fclose(file);
ggml_v3_free(ctx_data);
gguf_v3_free(ctx);
return NULL;
}
ggml_v3_set_no_alloc(ctx_data, params.no_alloc);
}
fclose(file);
return ctx;
}
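// Example usage (sketch): load a GGUF file together with its tensor data into
// a fresh ggml context; the file name is a placeholder.
//
//     struct ggml_v3_context * ctx_data = NULL;
//     struct gguf_v3_init_params iparams = {
//         .no_alloc = false,
//         .ctx      = &ctx_data,
//     };
//     struct gguf_v3_context * gctx = gguf_v3_init_from_file("model.gguf", iparams);
//     if (gctx) {
//         // tensors are now available in ctx_data, metadata via gguf_v3_get_*()
//         gguf_v3_free(gctx);
//         ggml_v3_free(ctx_data);
//     }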
void gguf_v3_free(struct gguf_v3_context * ctx) {
if (ctx == NULL) {
return;
}
if (ctx->kv) {
// free string memory - not great..
for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
struct gguf_v3_kv * kv = &ctx->kv[i];
if (kv->key.data) {
free(kv->key.data);
}
if (kv->type == GGUF_V3_TYPE_STRING) {
if (kv->value.str.data) {
free(kv->value.str.data);
}
}
if (kv->type == GGUF_V3_TYPE_ARRAY) {
if (kv->value.arr.data) {
if (kv->value.arr.type == GGUF_V3_TYPE_STRING) {
for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
struct gguf_v3_str * str = &((struct gguf_v3_str *) kv->value.arr.data)[j];
if (str->data) {
free(str->data);
}
}
}
free(kv->value.arr.data);
}
}
}
free(ctx->kv);
}
if (ctx->infos) {
for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
struct gguf_v3_tensor_info * info = &ctx->infos[i];
if (info->name.data) {
free(info->name.data);
}
}
free(ctx->infos);
}
GGML_V3_ALIGNED_FREE(ctx);
}
const char * gguf_v3_type_name(enum gguf_v3_type type) {
return GGUF_V3_TYPE_NAME[type];
}
int gguf_v3_get_version(const struct gguf_v3_context * ctx) {
return ctx->header.version;
}
size_t gguf_v3_get_alignment(const struct gguf_v3_context * ctx) {
return ctx->alignment;
}
size_t gguf_v3_get_data_offset(const struct gguf_v3_context * ctx) {
return ctx->offset;
}
void * gguf_v3_get_data(const struct gguf_v3_context * ctx) {
return ctx->data;
}
int gguf_v3_get_n_kv(const struct gguf_v3_context * ctx) {
return ctx->header.n_kv;
}
int gguf_v3_find_key(const struct gguf_v3_context * ctx, const char * key) {
// return -1 if key not found
int keyfound = -1;
const int n_kv = gguf_v3_get_n_kv(ctx);
for (int i = 0; i < n_kv; ++i) {
if (strcmp(key, gguf_v3_get_key(ctx, i)) == 0) {
keyfound = i;
break;
}
}
return keyfound;
}
const char * gguf_v3_get_key(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
return ctx->kv[key_id].key.data;
}
enum gguf_v3_type gguf_v3_get_kv_type(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
return ctx->kv[key_id].type;
}
enum gguf_v3_type gguf_v3_get_arr_type(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_ARRAY);
return ctx->kv[key_id].value.arr.type;
}
const void * gguf_v3_get_arr_data(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_ARRAY);
return ctx->kv[key_id].value.arr.data;
}
const char * gguf_v3_get_arr_str(const struct gguf_v3_context * ctx, int key_id, int i) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_ARRAY);
struct gguf_v3_kv * kv = &ctx->kv[key_id];
struct gguf_v3_str * str = &((struct gguf_v3_str *) kv->value.arr.data)[i];
return str->data;
}
int gguf_v3_get_arr_n(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_ARRAY);
return ctx->kv[key_id].value.arr.n;
}
uint8_t gguf_v3_get_val_u8(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_UINT8);
return ctx->kv[key_id].value.uint8;
}
int8_t gguf_v3_get_val_i8(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_INT8);
return ctx->kv[key_id].value.int8;
}
uint16_t gguf_v3_get_val_u16(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_UINT16);
return ctx->kv[key_id].value.uint16;
}
int16_t gguf_v3_get_val_i16(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_INT16);
return ctx->kv[key_id].value.int16;
}
uint32_t gguf_v3_get_val_u32(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_UINT32);
return ctx->kv[key_id].value.uint32;
}
int32_t gguf_v3_get_val_i32(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_INT32);
return ctx->kv[key_id].value.int32;
}
float gguf_v3_get_val_f32(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_FLOAT32);
return ctx->kv[key_id].value.float32;
}
uint64_t gguf_v3_get_val_u64(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_UINT64);
return ctx->kv[key_id].value.uint64;
}
int64_t gguf_v3_get_val_i64(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_INT64);
return ctx->kv[key_id].value.int64;
}
double gguf_v3_get_val_f64(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_FLOAT64);
return ctx->kv[key_id].value.float64;
}
bool gguf_v3_get_val_bool(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_BOOL);
return ctx->kv[key_id].value.bool_;
}
const char * gguf_v3_get_val_str(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type == GGUF_V3_TYPE_STRING);
return ctx->kv[key_id].value.str.data;
}
const void * gguf_v3_get_val_data(const struct gguf_v3_context * ctx, int key_id) {
GGML_V3_ASSERT(key_id >= 0 && key_id < gguf_v3_get_n_kv(ctx));
GGML_V3_ASSERT(ctx->kv[key_id].type != GGUF_V3_TYPE_ARRAY);
GGML_V3_ASSERT(ctx->kv[key_id].type != GGUF_V3_TYPE_STRING);
return &ctx->kv[key_id].value;
}
int gguf_v3_get_n_tensors(const struct gguf_v3_context * ctx) {
return ctx->header.n_tensors;
}
int gguf_v3_find_tensor(const struct gguf_v3_context * ctx, const char * name) {
// return -1 if tensor not found
int tensorfound = -1;
const int n_tensors = gguf_v3_get_n_tensors(ctx);
for (int i = 0; i < n_tensors; ++i) {
if (strcmp(name, gguf_v3_get_tensor_name(ctx, i)) == 0) {
tensorfound = i;
break;
}
}
return tensorfound;
}
size_t gguf_v3_get_tensor_offset(const struct gguf_v3_context * ctx, int i) {
return ctx->infos[i].offset;
}
char * gguf_v3_get_tensor_name(const struct gguf_v3_context * ctx, int i) {
return ctx->infos[i].name.data;
}
enum ggml_v3_type gguf_v3_get_tensor_type(const struct gguf_v3_context * ctx, int i) {
return ctx->infos[i].type;
}
// returns the index of the key, adding a new entry if the key is not present
static int gguf_v3_get_or_add_key(struct gguf_v3_context * ctx, const char * key) {
const int idx = gguf_v3_find_key(ctx, key);
if (idx >= 0) {
return idx;
}
const int n_kv = gguf_v3_get_n_kv(ctx);
ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_v3_kv));
ctx->kv[n_kv].key.n = strlen(key);
ctx->kv[n_kv].key.data = strdup(key);
ctx->header.n_kv++;
return n_kv;
}
void gguf_v3_set_val_u8(struct gguf_v3_context * ctx, const char * key, uint8_t val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_UINT8;
ctx->kv[idx].value.uint8 = val;
}
void gguf_v3_set_val_i8(struct gguf_v3_context * ctx, const char * key, int8_t val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_INT8;
ctx->kv[idx].value.int8 = val;
}
void gguf_v3_set_val_u16(struct gguf_v3_context * ctx, const char * key, uint16_t val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_UINT16;
ctx->kv[idx].value.uint16 = val;
}
void gguf_v3_set_val_i16(struct gguf_v3_context * ctx, const char * key, int16_t val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_INT16;
ctx->kv[idx].value.int16 = val;
}
void gguf_v3_set_val_u32(struct gguf_v3_context * ctx, const char * key, uint32_t val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_UINT32;
ctx->kv[idx].value.uint32 = val;
}
void gguf_v3_set_val_i32(struct gguf_v3_context * ctx, const char * key, int32_t val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_INT32;
ctx->kv[idx].value.int32 = val;
}
void gguf_v3_set_val_f32(struct gguf_v3_context * ctx, const char * key, float val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_FLOAT32;
ctx->kv[idx].value.float32 = val;
}
void gguf_v3_set_val_u64(struct gguf_v3_context * ctx, const char * key, uint64_t val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_UINT64;
ctx->kv[idx].value.uint64 = val;
}
void gguf_v3_set_val_i64(struct gguf_v3_context * ctx, const char * key, int64_t val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_INT64;
ctx->kv[idx].value.int64 = val;
}
void gguf_v3_set_val_f64(struct gguf_v3_context * ctx, const char * key, double val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_FLOAT64;
ctx->kv[idx].value.float64 = val;
}
void gguf_v3_set_val_bool(struct gguf_v3_context * ctx, const char * key, bool val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_BOOL;
ctx->kv[idx].value.bool_ = val;
}
void gguf_v3_set_val_str(struct gguf_v3_context * ctx, const char * key, const char * val) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_STRING;
ctx->kv[idx].value.str.n = strlen(val);
ctx->kv[idx].value.str.data = strdup(val);
}
void gguf_v3_set_arr_data(struct gguf_v3_context * ctx, const char * key, enum gguf_v3_type type, const void * data, int n) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_ARRAY;
ctx->kv[idx].value.arr.type = type;
ctx->kv[idx].value.arr.n = n;
ctx->kv[idx].value.arr.data = malloc(n*GGUF_V3_TYPE_SIZE[type]);
memcpy(ctx->kv[idx].value.arr.data, data, n*GGUF_V3_TYPE_SIZE[type]);
}
void gguf_v3_set_arr_str(struct gguf_v3_context * ctx, const char * key, const char ** data, int n) {
const int idx = gguf_v3_get_or_add_key(ctx, key);
ctx->kv[idx].type = GGUF_V3_TYPE_ARRAY;
ctx->kv[idx].value.arr.type = GGUF_V3_TYPE_STRING;
ctx->kv[idx].value.arr.n = n;
ctx->kv[idx].value.arr.data = malloc(n*sizeof(struct gguf_v3_str));
for (int i = 0; i < n; i++) {
struct gguf_v3_str * str = &((struct gguf_v3_str *)ctx->kv[idx].value.arr.data)[i];
str->n = strlen(data[i]);
str->data = strdup(data[i]);
}
}
// set or add KV pairs from another context
void gguf_v3_set_kv(struct gguf_v3_context * ctx, struct gguf_v3_context * src) {
for (uint32_t i = 0; i < src->header.n_kv; i++) {
switch (src->kv[i].type) {
case GGUF_V3_TYPE_UINT8: gguf_v3_set_val_u8 (ctx, src->kv[i].key.data, src->kv[i].value.uint8); break;
case GGUF_V3_TYPE_INT8: gguf_v3_set_val_i8 (ctx, src->kv[i].key.data, src->kv[i].value.int8); break;
case GGUF_V3_TYPE_UINT16: gguf_v3_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16); break;
case GGUF_V3_TYPE_INT16: gguf_v3_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16); break;
case GGUF_V3_TYPE_UINT32: gguf_v3_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break;
case GGUF_V3_TYPE_INT32: gguf_v3_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break;
case GGUF_V3_TYPE_FLOAT32: gguf_v3_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break;
case GGUF_V3_TYPE_UINT64: gguf_v3_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64); break;
case GGUF_V3_TYPE_INT64: gguf_v3_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64); break;
case GGUF_V3_TYPE_FLOAT64: gguf_v3_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64); break;
case GGUF_V3_TYPE_BOOL: gguf_v3_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break;
case GGUF_V3_TYPE_STRING: gguf_v3_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
case GGUF_V3_TYPE_ARRAY:
{
if (src->kv[i].value.arr.type == GGUF_V3_TYPE_STRING) {
const char ** data = malloc(src->kv[i].value.arr.n*sizeof(char *));
for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
data[j] = ((struct gguf_v3_str *)src->kv[i].value.arr.data)[j].data;
}
gguf_v3_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
free((void *)data);
} else if (src->kv[i].value.arr.type == GGUF_V3_TYPE_ARRAY) {
GGML_V3_ASSERT(false && "nested arrays not supported");
} else {
gguf_v3_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n);
}
} break;
case GGUF_V3_TYPE_COUNT: GGML_V3_ASSERT(false && "invalid type"); break;
}
}
}
void gguf_v3_add_tensor(
struct gguf_v3_context * ctx,
const struct ggml_v3_tensor * tensor) {
const int idx = ctx->header.n_tensors;
ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_v3_tensor_info));
ctx->infos[idx].name.n = strlen(tensor->name);
ctx->infos[idx].name.data = strdup(tensor->name);
for (int i = 0; i < GGML_V3_MAX_DIMS; ++i) {
ctx->infos[idx].ne[i] = 1;
}
ctx->infos[idx].n_dims = ggml_v3_n_dims(tensor);
for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) {
ctx->infos[idx].ne[i] = tensor->ne[i];
}
ctx->infos[idx].type = tensor->type;
ctx->infos[idx].offset = 0;
ctx->infos[idx].data = tensor->data;
ctx->infos[idx].size = ggml_v3_nbytes(tensor);
if (ctx->header.n_tensors > 0) {
ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_V3_PAD(ctx->infos[idx - 1].size, ctx->alignment);
}
ctx->header.n_tensors++;
}
void gguf_v3_set_tensor_type(struct gguf_v3_context * ctx, const char * name, enum ggml_v3_type type) {
const int idx = gguf_v3_find_tensor(ctx, name);
if (idx < 0) {
GGML_V3_ASSERT(false && "tensor not found");
}
ctx->infos[idx].type = type;
}
void gguf_v3_set_tensor_data(struct gguf_v3_context * ctx, const char * name, const void * data, size_t size) {
const int idx = gguf_v3_find_tensor(ctx, name);
if (idx < 0) {
GGML_V3_ASSERT(false && "tensor not found");
}
ctx->infos[idx].data = data;
ctx->infos[idx].size = size;
// update offsets
for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) {
ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_V3_PAD(ctx->infos[i - 1].size, ctx->alignment);
}
}
//static void gguf_v3_fwrite_str(FILE * file, const struct gguf_v3_str * val) {
// fwrite(&val->n, sizeof(val->n), 1, file);
// fwrite(val->data, sizeof(char), val->n, file);
//}
//
//static void gguf_v3_fwrite_el(FILE * file, const void * val, size_t size) {
// fwrite(val, sizeof(char), size, file);
//}
struct gguf_v3_buf {
void * data;
size_t size;
size_t offset;
};
static struct gguf_v3_buf gguf_v3_buf_init(size_t size) {
struct gguf_v3_buf buf = {
/*buf.data =*/ size == 0 ? NULL : malloc(size),
/*buf.size =*/ size,
/*buf.offset =*/ 0,
};
return buf;
}
static void gguf_v3_buf_free(struct gguf_v3_buf buf) {
if (buf.data) {
free(buf.data);
}
}
static void gguf_v3_buf_grow(struct gguf_v3_buf * buf, size_t size) {
if (buf->offset + size > buf->size) {
buf->size = 1.5*(buf->offset + size);
if (buf->data) {
buf->data = realloc(buf->data, buf->size);
}
}
}
static void gguf_v3_bwrite_str(struct gguf_v3_buf * buf, const struct gguf_v3_str * val) {
gguf_v3_buf_grow(buf, sizeof(val->n) + val->n);
if (buf->data) {
memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n));
}
buf->offset += sizeof(val->n);
if (buf->data) {
memcpy((char *) buf->data + buf->offset, val->data, val->n);
}
buf->offset += val->n;
}
static void gguf_v3_bwrite_el(struct gguf_v3_buf * buf, const void * val, size_t el_size) {
gguf_v3_buf_grow(buf, el_size);
if (buf->data) {
memcpy((char *) buf->data + buf->offset, val, el_size);
}
buf->offset += el_size;
}
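// Note: when buf->data is NULL (a zero-sized buffer from gguf_v3_buf_init(0)),
// the bwrite helpers only advance buf->offset without copying anything; this
// is how gguf_v3_get_meta_size() measures the metadata size without allocating.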
static void gguf_v3_write_to_buf(const struct gguf_v3_context * ctx, struct gguf_v3_buf * buf, bool only_meta) {
// write header
gguf_v3_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic));
gguf_v3_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
gguf_v3_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors));
gguf_v3_bwrite_el(buf, &ctx->header.n_kv, sizeof(ctx->header.n_kv));
// write key-value pairs
for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
struct gguf_v3_kv * kv = &ctx->kv[i];
gguf_v3_bwrite_str(buf, &kv->key);
gguf_v3_bwrite_el (buf, &kv->type, sizeof(kv->type));
switch (kv->type) {
case GGUF_V3_TYPE_UINT8: gguf_v3_bwrite_el (buf, &kv->value.uint8, sizeof(kv->value.uint8) ); break;
case GGUF_V3_TYPE_INT8: gguf_v3_bwrite_el (buf, &kv->value.int8, sizeof(kv->value.int8) ); break;
case GGUF_V3_TYPE_UINT16: gguf_v3_bwrite_el (buf, &kv->value.uint16, sizeof(kv->value.uint16) ); break;
case GGUF_V3_TYPE_INT16: gguf_v3_bwrite_el (buf, &kv->value.int16, sizeof(kv->value.int16) ); break;
case GGUF_V3_TYPE_UINT32: gguf_v3_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break;
case GGUF_V3_TYPE_INT32: gguf_v3_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break;
case GGUF_V3_TYPE_FLOAT32: gguf_v3_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
case GGUF_V3_TYPE_UINT64: gguf_v3_bwrite_el (buf, &kv->value.uint64, sizeof(kv->value.uint64) ); break;
case GGUF_V3_TYPE_INT64: gguf_v3_bwrite_el (buf, &kv->value.int64, sizeof(kv->value.int64) ); break;
case GGUF_V3_TYPE_FLOAT64: gguf_v3_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
case GGUF_V3_TYPE_BOOL: gguf_v3_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break;
case GGUF_V3_TYPE_STRING: gguf_v3_bwrite_str(buf, &kv->value.str ); break;
case GGUF_V3_TYPE_ARRAY:
{
gguf_v3_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type));
gguf_v3_bwrite_el(buf, &kv->value.arr.n, sizeof(kv->value.arr.n) );
switch (kv->value.arr.type) {
case GGUF_V3_TYPE_UINT8:
case GGUF_V3_TYPE_INT8:
case GGUF_V3_TYPE_UINT16:
case GGUF_V3_TYPE_INT16:
case GGUF_V3_TYPE_UINT32:
case GGUF_V3_TYPE_INT32:
case GGUF_V3_TYPE_FLOAT32:
case GGUF_V3_TYPE_UINT64:
case GGUF_V3_TYPE_INT64:
case GGUF_V3_TYPE_FLOAT64:
case GGUF_V3_TYPE_BOOL:
{
gguf_v3_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * GGUF_V3_TYPE_SIZE[kv->value.arr.type]);
} break;
case GGUF_V3_TYPE_STRING:
{
for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
gguf_v3_bwrite_str(buf, &((struct gguf_v3_str *) kv->value.arr.data)[j]);
}
} break;
case GGUF_V3_TYPE_ARRAY:
case GGUF_V3_TYPE_COUNT: GGML_V3_ASSERT(false && "invalid type"); break;
}
} break;
case GGUF_V3_TYPE_COUNT: GGML_V3_ASSERT(false && "invalid type");
}
}
// write tensor infos
for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
struct gguf_v3_tensor_info * info = &ctx->infos[i];
gguf_v3_bwrite_str(buf, &info->name);
gguf_v3_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims));
for (uint32_t j = 0; j < info->n_dims; ++j) {
gguf_v3_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j]));
}
gguf_v3_bwrite_el(buf, &info->type, sizeof(info->type));
gguf_v3_bwrite_el(buf, &info->offset, sizeof(info->offset));
}
// we require the data section to be aligned, so take into account any padding
{
const size_t offset = buf->offset;
const size_t offset_pad = GGML_V3_PAD(offset, ctx->alignment);
if (offset_pad != offset) {
uint8_t pad = 0;
for (size_t i = 0; i < offset_pad - offset; ++i) {
gguf_v3_bwrite_el(buf, &pad, sizeof(pad));
}
}
}
if (only_meta) {
return;
}
size_t offset = 0;
// write tensor data
for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
struct gguf_v3_tensor_info * info = &ctx->infos[i];
const size_t size = info->size;
const size_t size_pad = GGML_V3_PAD(size, ctx->alignment);
gguf_v3_bwrite_el(buf, info->data, size);
if (size_pad != size) {
uint8_t pad = 0;
for (size_t j = 0; j < size_pad - size; ++j) {
gguf_v3_bwrite_el(buf, &pad, sizeof(pad));
}
}
GGML_V3_ASSERT(offset == info->offset);
offset += size_pad;
}
}
void gguf_v3_write_to_file(const struct gguf_v3_context * ctx, const char * fname, bool only_meta) {
FILE * file = fopen(fname, "wb");
if (!file) {
GGML_V3_ASSERT(false && "failed to open file for writing");
}
struct gguf_v3_buf buf = gguf_v3_buf_init(16*1024);
gguf_v3_write_to_buf(ctx, &buf, only_meta);
fwrite(buf.data, 1, buf.offset, file);
gguf_v3_buf_free(buf);
fclose(file);
}
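// Example usage (sketch): build a GGUF file in memory and write it out. The
// key name, file name and tensor pointer t are placeholders.
//
//     struct gguf_v3_context * gctx = gguf_v3_init_empty();
//     gguf_v3_set_val_str(gctx, "general.name", "example");
//     gguf_v3_add_tensor(gctx, t);                 // t: struct ggml_v3_tensor *
//     gguf_v3_write_to_file(gctx, "out.gguf", /*only_meta =*/ false);
//     gguf_v3_free(gctx);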
size_t gguf_v3_get_meta_size(const struct gguf_v3_context * ctx) {
// no allocs - only compute size
struct gguf_v3_buf buf = gguf_v3_buf_init(0);
gguf_v3_write_to_buf(ctx, &buf, true);
return buf.offset;
}
void gguf_v3_get_meta_data(const struct gguf_v3_context * ctx, void * data) {
struct gguf_v3_buf buf = gguf_v3_buf_init(16*1024);
gguf_v3_write_to_buf(ctx, &buf, true);
memcpy(data, buf.data, buf.offset);
gguf_v3_buf_free(buf);
}
////////////////////////////////////////////////////////////////////////////////
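// The ggml_v3_cpu_has_* helpers below report compile-time feature availability
// (the preprocessor defines selected when this file was built), not a runtime
// CPUID/HWCAP query.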
int ggml_v3_cpu_has_avx(void) {
#if defined(__AVX__)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_avx_vnni(void) {
#if defined(__AVXVNNI__)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_avx2(void) {
#if defined(__AVX2__)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_avx512(void) {
#if defined(__AVX512F__)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_avx512_vbmi(void) {
#if defined(__AVX512VBMI__)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_avx512_vnni(void) {
#if defined(__AVX512VNNI__)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_fma(void) {
#if defined(__FMA__)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_neon(void) {
#if defined(__ARM_NEON)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_arm_fma(void) {
#if defined(__ARM_FEATURE_FMA)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_metal(void) {
#if defined(GGML_USE_METAL)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_f16c(void) {
#if defined(__F16C__)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_fp16_va(void) {
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_wasm_simd(void) {
#if defined(__wasm_simd128__)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_blas(void) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_CLBLAST)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_cublas(void) {
#if defined(GGML_USE_CUDA)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_clblast(void) {
#if defined(GGML_USE_CLBLAST)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_gpublas(void) {
return ggml_v3_cpu_has_cublas() || ggml_v3_cpu_has_clblast();
}
int ggml_v3_cpu_has_sse3(void) {
#if defined(__SSE3__)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_ssse3(void) {
#if defined(__SSSE3__)
return 1;
#else
return 0;
#endif
}
int ggml_v3_cpu_has_vsx(void) {
#if defined(__POWER9_VECTOR__)
return 1;
#else
return 0;
#endif
}
////////////////////////////////////////////////////////////////////////////////
// formerly ggml-quants.c
#include <math.h>
#include <string.h>
#include <assert.h>
#include <float.h>
#ifdef __ARM_NEON
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
//
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
//
#include <arm_neon.h>
#else
#ifdef __wasm_simd128__
#include <wasm_simd128.h>
#else
#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
#include <altivec.h>
#undef bool
#define bool _Bool
#else
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
#else
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
#if !defined(__riscv)
#include <immintrin.h>
#endif
#endif
#endif
#endif
#endif
#endif
#ifdef __riscv_v_intrinsic
#include <riscv_vector.h>
#endif
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#ifndef MM256_SET_M128I
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
#endif
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
// multiply int8_t, add results pairwise twice
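// _mm_maddubs_epi16 requires an unsigned first operand, so the sign of x is
// folded into y first, relying on x*y == |x| * (sign(x)*y).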
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
// Get absolute values of x vectors
const __m128i ax = _mm_sign_epi8(x, x);
// Sign the values of the y vectors
const __m128i sy = _mm_sign_epi8(y, x);
// Perform multiplication and create 16-bit values
const __m128i dot = _mm_maddubs_epi16(ax, sy);
const __m128i ones = _mm_set1_epi16(1);
return _mm_madd_epi16(ones, dot);
}
#if __AVX__ || __AVX2__ || __AVX512F__
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
__m128 res = _mm256_extractf128_ps(x, 1);
res = _mm_add_ps(res, _mm256_castps256_ps128(x));
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
res = _mm_add_ss(res, _mm_movehdup_ps(res));
return _mm_cvtss_f32(res);
}
// horizontally add 8 int32_t
static inline int hsum_i32_8(const __m256i a) {
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
const __m128i sum64 = _mm_add_epi32(hi64, sum128);
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
// horizontally add 4 int32_t
static inline int hsum_i32_4(const __m128i a) {
const __m128i hi64 = _mm_unpackhi_epi64(a, a);
const __m128i sum64 = _mm_add_epi32(hi64, a);
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
#if defined(__AVX2__) || defined(__AVX512F__)
// spread 32 bits to 32 bytes { 0x00, 0xFF }
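// The shuffle replicates each source byte 8 times, the OR sets every bit
// except the one being tested, and the compare against -1 turns "bit set"
// into 0xFF and "bit clear" into 0x00.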
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
memcpy(&x32, x, sizeof(uint32_t));
const __m256i shuf_mask = _mm256_set_epi64x(
0x0303030303030303, 0x0202020202020202,
0x0101010101010101, 0x0000000000000000);
__m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
bytes = _mm256_or_si256(bytes, bit_mask);
return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
}
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
const __m256i lowMask = _mm256_set1_epi8( 0xF );
return _mm256_and_si256(lowMask, bytes);
}
// add int16_t pairwise and return as float vector
static inline __m256 sum_i16_pairs_float(const __m256i x) {
const __m256i ones = _mm256_set1_epi16(1);
const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
return _mm256_cvtepi32_ps(summed_pairs);
}
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
#if __AVXVNNI__
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
return _mm256_cvtepi32_ps(summed_pairs);
#else
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
return sum_i16_pairs_float(dot);
#endif
}
// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
#if __AVXVNNIINT8__
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
return _mm256_cvtepi32_ps(summed_pairs);
#else
// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8(x, x);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(y, x);
return mul_sum_us8_pairs_float(ax, sy);
#endif
}
static inline __m128i packNibbles( __m256i bytes )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
#if __AVX512F__
const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
#else
const __m256i lowByte = _mm256_set1_epi16( 0xFF );
__m256i high = _mm256_andnot_si256( lowByte, bytes );
__m256i low = _mm256_and_si256( lowByte, bytes );
high = _mm256_srli_epi16( high, 4 );
bytes = _mm256_or_si256( low, high );
// Compress uint16_t lanes into bytes
__m128i r0 = _mm256_castsi256_si128( bytes );
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
return _mm_packus_epi16( r0, r1 );
#endif
}
#elif defined(__AVX__)
// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
memcpy(&x32, x, sizeof(uint32_t));
const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
__m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
__m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
bytesl = _mm_or_si128(bytesl, bit_mask);
bytesh = _mm_or_si128(bytesh, bit_mask);
bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
return MM256_SET_M128I(bytesh, bytesl);
}
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
// Load 16 bytes from memory
__m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
__m128i tmph = _mm_srli_epi16(tmpl, 4);
const __m128i lowMask = _mm_set1_epi8(0xF);
tmpl = _mm_and_si128(lowMask, tmpl);
tmph = _mm_and_si128(lowMask, tmph);
return MM256_SET_M128I(tmph, tmpl);
}
// add int16_t pairwise and return as float vector
static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
const __m128i ones = _mm_set1_epi16(1);
const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
return _mm256_cvtepi32_ps(summed_pairs);
}
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
const __m128i axl = _mm256_castsi256_si128(ax);
const __m128i axh = _mm256_extractf128_si256(ax, 1);
const __m128i syl = _mm256_castsi256_si128(sy);
const __m128i syh = _mm256_extractf128_si256(sy, 1);
// Perform multiplication and create 16-bit values
const __m128i dotl = _mm_maddubs_epi16(axl, syl);
const __m128i doth = _mm_maddubs_epi16(axh, syh);
return sum_i16_pairs_float(doth, dotl);
}
// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
const __m128i xl = _mm256_castsi256_si128(x);
const __m128i xh = _mm256_extractf128_si256(x, 1);
const __m128i yl = _mm256_castsi256_si128(y);
const __m128i yh = _mm256_extractf128_si256(y, 1);
// Get absolute values of x vectors
const __m128i axl = _mm_sign_epi8(xl, xl);
const __m128i axh = _mm_sign_epi8(xh, xh);
// Sign the values of the y vectors
const __m128i syl = _mm_sign_epi8(yl, xl);
const __m128i syh = _mm_sign_epi8(yh, xh);
// Perform multiplication and create 16-bit values
const __m128i dotl = _mm_maddubs_epi16(axl, syl);
const __m128i doth = _mm_maddubs_epi16(axh, syh);
return sum_i16_pairs_float(doth, dotl);
}
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
const __m128i lowByte = _mm_set1_epi16( 0xFF );
__m128i high = _mm_andnot_si128( lowByte, bytes1 );
__m128i low = _mm_and_si128( lowByte, bytes1 );
high = _mm_srli_epi16( high, 4 );
bytes1 = _mm_or_si128( low, high );
high = _mm_andnot_si128( lowByte, bytes2 );
low = _mm_and_si128( lowByte, bytes2 );
high = _mm_srli_epi16( high, 4 );
bytes2 = _mm_or_si128( low, high );
return _mm_packus_epi16( bytes1, bytes2);
}
#endif
#elif defined(__SSSE3__)
// horizontally add 4x4 floats
static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
__m128 res_0 =_mm_hadd_ps(a, b);
__m128 res_1 =_mm_hadd_ps(c, d);
__m128 res =_mm_hadd_ps(res_0, res_1);
res =_mm_hadd_ps(res, res);
res =_mm_hadd_ps(res, res);
return _mm_cvtss_f32(res);
}
#endif // __AVX__ || __AVX2__ || __AVX512F__
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
#if defined(__ARM_NEON)
#if !defined(__aarch64__)
// compatibility shims for intrinsics that are only available on aarch64:
// vaddvq_s16
// vpaddq_s16
// vaddvq_s32
// vaddvq_f32
// vmaxvq_f32
// vcvtnq_s32_f32
inline static int32_t vaddvq_s16(int16x8_t v) {
return
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
}
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
return vcombine_s16(a0, b0);
}
inline static int32_t vaddvq_s32(int32x4_t v) {
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
}
inline static float vmaxvq_f32(float32x4_t v) {
return
MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
}
inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
int32x4_t res;
res[0] = roundf(vgetq_lane_f32(v, 0));
res[1] = roundf(vgetq_lane_f32(v, 1));
res[2] = roundf(vgetq_lane_f32(v, 2));
res[3] = roundf(vgetq_lane_f32(v, 3));
return res;
}
// vld1q_s16_x2
// vld1q_u8_x2
// vld1q_u8_x4
// vld1q_s8_x2
// vld1q_s8_x4
// TODO: double-check these work correctly
typedef struct ggml_v3_int16x8x2_t {
int16x8_t val[2];
} ggml_v3_int16x8x2_t;
inline static ggml_v3_int16x8x2_t ggml_v3_vld1q_s16_x2(const int16_t * ptr) {
ggml_v3_int16x8x2_t res;
res.val[0] = vld1q_s16(ptr + 0);
res.val[1] = vld1q_s16(ptr + 8);
return res;
}
typedef struct ggml_v3_uint8x16x2_t {
uint8x16_t val[2];
} ggml_v3_uint8x16x2_t;
inline static ggml_v3_uint8x16x2_t ggml_v3_vld1q_u8_x2(const uint8_t * ptr) {
ggml_v3_uint8x16x2_t res;
res.val[0] = vld1q_u8(ptr + 0);
res.val[1] = vld1q_u8(ptr + 16);
return res;
}
typedef struct ggml_v3_uint8x16x4_t {
uint8x16_t val[4];
} ggml_v3_uint8x16x4_t;
inline static ggml_v3_uint8x16x4_t ggml_v3_vld1q_u8_x4(const uint8_t * ptr) {
ggml_v3_uint8x16x4_t res;
res.val[0] = vld1q_u8(ptr + 0);
res.val[1] = vld1q_u8(ptr + 16);
res.val[2] = vld1q_u8(ptr + 32);
res.val[3] = vld1q_u8(ptr + 48);
return res;
}
typedef struct ggml_v3_int8x16x2_t {
int8x16_t val[2];
} ggml_v3_int8x16x2_t;
inline static ggml_v3_int8x16x2_t ggml_v3_vld1q_s8_x2(const int8_t * ptr) {
ggml_v3_int8x16x2_t res;
res.val[0] = vld1q_s8(ptr + 0);
res.val[1] = vld1q_s8(ptr + 16);
return res;
}
typedef struct ggml_v3_int8x16x4_t {
int8x16_t val[4];
} ggml_v3_int8x16x4_t;
inline static ggml_v3_int8x16x4_t ggml_v3_vld1q_s8_x4(const int8_t * ptr) {
ggml_v3_int8x16x4_t res;
res.val[0] = vld1q_s8(ptr + 0);
res.val[1] = vld1q_s8(ptr + 16);
res.val[2] = vld1q_s8(ptr + 32);
res.val[3] = vld1q_s8(ptr + 48);
return res;
}
#else
#define ggml_v3_int16x8x2_t int16x8x2_t
#define ggml_v3_uint8x16x2_t uint8x16x2_t
#define ggml_v3_uint8x16x4_t uint8x16x4_t
#define ggml_v3_int8x16x2_t int8x16x2_t
#define ggml_v3_int8x16x4_t int8x16x4_t
#define ggml_v3_vld1q_s16_x2 vld1q_s16_x2
#define ggml_v3_vld1q_u8_x2 vld1q_u8_x2
#define ggml_v3_vld1q_u8_x4 vld1q_u8_x4
#define ggml_v3_vld1q_s8_x2 vld1q_s8_x2
#define ggml_v3_vld1q_s8_x4 vld1q_s8_x4
#endif
#if !defined(__ARM_FEATURE_DOTPROD)
inline static int32x4_t ggml_v3_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
}
#else
#define ggml_v3_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
#endif
#endif
#if defined(__ARM_NEON) || defined(__wasm_simd128__)
#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
#define B8(c,s ) B7(c,s, c), B7(c,s, s)
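// How the B1..B8 macros work, as read from the expansion above: they enumerate all 256 values
// of an 8-bit index, and for index i, byte j of the generated 64-bit constant is the first
// macro argument when bit j of i is clear and the second when it is set. With the arguments
// used below, table_b2b_0[i] has byte j equal to ((i >> j) & 1) << 4 and table_b2b_1 holds the
// complement, which is handy for expanding packed high-bit/sign masks into per-byte values.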
// precomputed tables for expanding 8 bits to 8 bytes:
static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#endif
// reference implementation for deterministic creation of model files
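// Q4_0 layout, as implemented below: blocks of QK4_0 (32) floats. The scale d is chosen as
// max/-8, where max is the value with the largest magnitude, so that this extreme value maps
// exactly to quant -8. Each value is stored as an unsigned 4-bit quant with an implicit +8
// offset, two quants per byte (low nibble = first half of the block, high nibble = second half).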
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
static const int qk = QK4_0;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < qk; j++) {
const float v = x[i*qk + j];
if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
}
const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f;
y[i].d = GGML_V3_FP32_TO_FP16(d);
for (int j = 0; j < qk/2; ++j) {
const float x0 = x[i*qk + 0 + j]*id;
const float x1 = x[i*qk + qk/2 + j]*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
y[i].qs[j] = xi0;
y[i].qs[j] |= xi1 << 4;
}
}
}
static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
quantize_row_q4_0_reference(x, y, k);
}
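// Q4_1, as implemented below: like Q4_0 but asymmetric. Each block stores a scale
// d = (max - min)/15 and the block minimum m, and the quants are unsigned values in [0, 15]
// taken relative to that minimum.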
static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
const int qk = QK4_1;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
for (int j = 0; j < qk; j++) {
const float v = x[i*qk + j];
if (v < min) min = v;
if (v > max) max = v;
}
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = GGML_V3_FP32_TO_FP16(d);
y[i].m = GGML_V3_FP32_TO_FP16(min);
for (int j = 0; j < qk/2; ++j) {
const float x0 = (x[i*qk + 0 + j] - min)*id;
const float x1 = (x[i*qk + qk/2 + j] - min)*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
y[i].qs[j] = xi0;
y[i].qs[j] |= xi1 << 4;
}
}
}
static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
quantize_row_q4_1_reference(x, y, k);
}
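// Q5_0, as implemented below: like Q4_0 but with 5-bit quants. The scale is max/-16, the low
// 4 bits of each quant go into the packed nibbles of qs, and the 32 fifth (high) bits of the
// block are collected into the 32-bit qh field.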
static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
static const int qk = QK5_0;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < qk; j++) {
const float v = x[i*qk + j];
if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
}
const float d = max / -16;
const float id = d ? 1.0f/d : 0.0f;
y[i].d = GGML_V3_FP32_TO_FP16(d);
uint32_t qh = 0;
for (int j = 0; j < qk/2; ++j) {
const float x0 = x[i*qk + 0 + j]*id;
const float x1 = x[i*qk + qk/2 + j]*id;
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
// get the 5th bit and store it in qh at the right position
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
}
memcpy(&y[i].qh, &qh, sizeof(qh));
}
}
static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
quantize_row_q5_0_reference(x, y, k);
}
static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
const int qk = QK5_1;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
for (int j = 0; j < qk; j++) {
const float v = x[i*qk + j];
if (v < min) min = v;
if (v > max) max = v;
}
const float d = (max - min) / ((1 << 5) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = GGML_V3_FP32_TO_FP16(d);
y[i].m = GGML_V3_FP32_TO_FP16(min);
uint32_t qh = 0;
for (int j = 0; j < qk/2; ++j) {
const float x0 = (x[i*qk + 0 + j] - min)*id;
const float x1 = (x[i*qk + qk/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
// get the 5th bit and store it in qh at the right position
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
}
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
}
}
static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
quantize_row_q5_1_reference(x, y, k);
}
// reference implementation for deterministic creation of model files
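// Q8_0, as implemented below: blocks of QK8_0 (32) floats stored as signed 8-bit quants with a
// single fp16 scale d = amax/127, where amax is the largest absolute value in the block.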
static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = x[i*QK8_0 + j];
amax = MAX(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = GGML_V3_FP32_TO_FP16(d);
for (int j = 0; j < QK8_0; ++j) {
const float x0 = x[i*QK8_0 + j]*id;
y[i].qs[j] = roundf(x0);
}
}
}
static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
block_q8_0 * restrict y = vy;
#if defined(__ARM_NEON)
for (int i = 0; i < nb; i++) {
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
const float amax = vmaxvq_f32(amaxv[0]);
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = GGML_V3_FP32_TO_FP16(d);
for (int j = 0; j < 8; j++) {
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
}
#elif defined(__wasm_simd128__)
for (int i = 0; i < nb; i++) {
v128_t srcv [8];
v128_t asrcv[8];
v128_t amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
wasm_f32x4_extract_lane(amaxv[0], 1)),
MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
wasm_f32x4_extract_lane(amaxv[0], 3)));
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = GGML_V3_FP32_TO_FP16(d);
for (int j = 0; j < 8; j++) {
const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
}
}
#elif defined(__AVX2__) || defined(__AVX__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
__m256 v2 = _mm256_loadu_ps( x + 16 );
__m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32;
// Compute max(abs(e)) for the block
const __m256 signBit = _mm256_set1_ps( -0.0f );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 );
// Quantize these floats
const float d = maxScalar / 127.f;
y[i].d = GGML_V3_FP32_TO_FP16(d);
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
// Apply the multiplier
v0 = _mm256_mul_ps( v0, mul );
v1 = _mm256_mul_ps( v1, mul );
v2 = _mm256_mul_ps( v2, mul );
v3 = _mm256_mul_ps( v3, mul );
// Round to nearest integer
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
// Convert floats to integers
__m256i i0 = _mm256_cvtps_epi32( v0 );
__m256i i1 = _mm256_cvtps_epi32( v1 );
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
#if defined(__AVX2__)
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
// Convert int16 to int8
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
// We got our precious signed bytes, but the order is now wrong
// These AVX2 pack instructions process 16-byte pieces independently
// The following permute fixes the order
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
// Since AVX lacks some of the integer instructions we need,
// we split the registers in half and use the SSE equivalents
__m128i ni0 = _mm256_castsi256_si128( i0 );
__m128i ni1 = _mm256_extractf128_si256( i0, 1);
__m128i ni2 = _mm256_castsi256_si128( i1 );
__m128i ni3 = _mm256_extractf128_si256( i1, 1);
__m128i ni4 = _mm256_castsi256_si128( i2 );
__m128i ni5 = _mm256_extractf128_si256( i2, 1);
__m128i ni6 = _mm256_castsi256_si128( i3 );
__m128i ni7 = _mm256_extractf128_si256( i3, 1);
// Convert int32 to int16
ni0 = _mm_packs_epi32( ni0, ni1 );
ni2 = _mm_packs_epi32( ni2, ni3 );
ni4 = _mm_packs_epi32( ni4, ni5 );
ni6 = _mm_packs_epi32( ni6, ni7 );
// Convert int16 to int8
ni0 = _mm_packs_epi16( ni0, ni2 );
ni4 = _mm_packs_epi16( ni4, ni6 );
_mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
#endif
}
#elif defined(__riscv_v_intrinsic)
size_t vl = __riscv_vsetvl_e32m4(QK8_0);
for (int i = 0; i < nb; i++) {
// load elements
vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = GGML_V3_FP32_TO_FP16(d);
vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
// convert to integer
vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
// store result
__riscv_vse8_v_i8m1(y[i].qs , vs, vl);
}
#else
GGML_V3_UNUSED(nb);
// scalar
quantize_row_q8_0_reference(x, y, k);
#endif
}
// reference implementation for deterministic creation of model files
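// Q8_1, as implemented below: like Q8_0, but the scale d is kept as a full float and the block
// additionally stores s = d * sum(quants). Keeping this precomputed sum is meant to let dot
// products against the offset-based formats (q4_1/q5_1) fold in the block minimum without
// re-summing the quants.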
static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
assert(QK8_1 == 32);
assert(k % QK8_1 == 0);
const int nb = k / QK8_1;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_1; j++) {
const float v = x[i*QK8_1 + j];
amax = MAX(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = d;
int sum = 0;
for (int j = 0; j < QK8_1/2; ++j) {
const float v0 = x[i*QK8_1 + j]*id;
const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
y[i].qs[ j] = roundf(v0);
y[i].qs[QK8_1/2 + j] = roundf(v1);
sum += y[i].qs[ j];
sum += y[i].qs[QK8_1/2 + j];
}
y[i].s = sum*d;
}
}
static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
assert(k % QK8_1 == 0);
const int nb = k / QK8_1;
block_q8_1 * restrict y = vy;
#if defined(__ARM_NEON)
for (int i = 0; i < nb; i++) {
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
const float amax = vmaxvq_f32(amaxv[0]);
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = d;
int32x4_t accv = vdupq_n_s32(0);
for (int j = 0; j < 8; j++) {
const float32x4_t v = vmulq_n_f32(srcv[j], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
accv = vaddq_s32(accv, vi);
}
y[i].s = d * vaddvq_s32(accv);
}
#elif defined(__wasm_simd128__)
for (int i = 0; i < nb; i++) {
v128_t srcv [8];
v128_t asrcv[8];
v128_t amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
wasm_f32x4_extract_lane(amaxv[0], 1)),
MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
wasm_f32x4_extract_lane(amaxv[0], 3)));
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = d;
v128_t accv = wasm_i32x4_splat(0);
for (int j = 0; j < 8; j++) {
const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
accv = wasm_i32x4_add(accv, vi);
}
y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
wasm_i32x4_extract_lane(accv, 1) +
wasm_i32x4_extract_lane(accv, 2) +
wasm_i32x4_extract_lane(accv, 3));
}
#elif defined(__AVX2__) || defined(__AVX__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
__m256 v2 = _mm256_loadu_ps( x + 16 );
__m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32;
// Compute max(abs(e)) for the block
const __m256 signBit = _mm256_set1_ps( -0.0f );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 );
// Quantize these floats
const float d = maxScalar / 127.f;
y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
// Apply the multiplier
v0 = _mm256_mul_ps( v0, mul );
v1 = _mm256_mul_ps( v1, mul );
v2 = _mm256_mul_ps( v2, mul );
v3 = _mm256_mul_ps( v3, mul );
// Round to nearest integer
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
// Convert floats to integers
__m256i i0 = _mm256_cvtps_epi32( v0 );
__m256i i1 = _mm256_cvtps_epi32( v1 );
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
#if defined(__AVX2__)
// Compute the sum of the quants and set y[i].s
y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
// Convert int16 to int8
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
// We got our precious signed bytes, but the order is now wrong
// These AVX2 pack instructions process 16-byte pieces independently
// The following permute fixes the order
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
// Since AVX lacks some of the integer instructions we need,
// we split the registers in half and use the SSE equivalents
__m128i ni0 = _mm256_castsi256_si128( i0 );
__m128i ni1 = _mm256_extractf128_si256( i0, 1);
__m128i ni2 = _mm256_castsi256_si128( i1 );
__m128i ni3 = _mm256_extractf128_si256( i1, 1);
__m128i ni4 = _mm256_castsi256_si128( i2 );
__m128i ni5 = _mm256_extractf128_si256( i2, 1);
__m128i ni6 = _mm256_castsi256_si128( i3 );
__m128i ni7 = _mm256_extractf128_si256( i3, 1);
// Compute the sum of the quants and set y[i].s
const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
// Convert int32 to int16
ni0 = _mm_packs_epi32( ni0, ni1 );
ni2 = _mm_packs_epi32( ni2, ni3 );
ni4 = _mm_packs_epi32( ni4, ni5 );
ni6 = _mm_packs_epi32( ni6, ni7 );
// Convert int16 to int8
ni0 = _mm_packs_epi16( ni0, ni2 );
ni4 = _mm_packs_epi16( ni4, ni6 );
_mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
#endif
}
#elif defined(__riscv_v_intrinsic)
size_t vl = __riscv_vsetvl_e32m4(QK8_1);
for (int i = 0; i < nb; i++) {
// load elements
vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = d;
vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
// convert to integer
vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
// store result
__riscv_vse8_v_i8m1(y[i].qs , vs, vl);
// compute sum for y[i].s
vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
// set y[i].s
int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
y[i].s = sum*d;
}
#else
GGML_V3_UNUSED(nb);
// scalar
quantize_row_q8_1_reference(x, y, k);
#endif
}
static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
static const int qk = QK4_0;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d);
for (int j = 0; j < qk/2; ++j) {
const int x0 = (x[i].qs[j] & 0x0F) - 8;
const int x1 = (x[i].qs[j] >> 4) - 8;
y[i*qk + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d;
}
}
}
static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
static const int qk = QK4_1;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d);
const float m = GGML_V3_FP16_TO_FP32(x[i].m);
for (int j = 0; j < qk/2; ++j) {
const int x0 = (x[i].qs[j] & 0x0F);
const int x1 = (x[i].qs[j] >> 4);
y[i*qk + j + 0 ] = x0*d + m;
y[i*qk + j + qk/2] = x1*d + m;
}
}
}
static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
static const int qk = QK5_0;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d);
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
for (int j = 0; j < qk/2; ++j) {
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
y[i*qk + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d;
}
}
}
static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
static const int qk = QK5_1;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d);
const float m = GGML_V3_FP16_TO_FP32(x[i].m);
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
for (int j = 0; j < qk/2; ++j) {
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
const int x1 = (x[i].qs[j] >> 4) | xh_1;
y[i*qk + j + 0 ] = x0*d + m;
y[i*qk + j + qk/2] = x1*d + m;
}
}
}
static void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) {
static const int qk = QK8_0;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d);
for (int j = 0; j < qk; ++j) {
y[i*qk + j] = x[i].qs[j]*d;
}
}
}
//
// 2-6 bit quantization in super-blocks
//
//
// ===================== Helper functions
//
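// nearest_int() rounds via the float representation: adding 12582912.0f (1.5 * 2^23) pushes
// any |fval| <= 2^22 into the range [2^23, 2^24), where the spacing between consecutive floats
// is exactly 1, so the addition itself rounds to the nearest integer. The rounded value then
// sits in the low 23 mantissa bits with a bias of 0x00400000, which the final subtraction
// removes. For example, 2.7f + 12582912.0f is stored as 12582915.0f, whose mantissa field is
// 0x400003, giving 3.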
static inline int nearest_int(float fval) {
assert(fval <= 4194303.f);
float val = fval + 12582912.f;
int i; memcpy(&i, &val, sizeof(int));
return (i & 0x007fffff) - 0x00400000;
}
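// make_qx_quants(), as read from the code below: quantizes n values to integers in
// [-nmax, nmax-1], stored in L with a +nmax offset, and returns the corresponding scale. The
// initial scale maps the largest-magnitude value to -nmax; for rmse_type != 0 a small grid
// search over nearby scales keeps the candidate with the best (optionally x^2-weighted)
// least-squares fit. A negative rmse_type returns after the first pass, skipping the search.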
static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
float max = 0;
float amax = 0;
for (int i = 0; i < n; ++i) {
float ax = fabsf(x[i]);
if (ax > amax) { amax = ax; max = x[i]; }
}
if (amax < 1e-30f) { // all zero
for (int i = 0; i < n; ++i) {
L[i] = 0;
}
return 0.f;
}
float iscale = -nmax / max;
if (rmse_type == 0) {
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale * x[i]);
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
}
return 1/iscale;
}
bool return_early = false;
if (rmse_type < 0) {
rmse_type = -rmse_type;
return_early = true;
}
int weight_type = rmse_type%2;
float sumlx = 0;
float suml2 = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale * x[i]);
l = MAX(-nmax, MIN(nmax-1, l));
L[i] = l + nmax;
float w = weight_type == 1 ? x[i] * x[i] : 1;
sumlx += w*x[i]*l;
suml2 += w*l*l;
}
float scale = sumlx/suml2;
if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
float best = scale * sumlx;
for (int is = -9; is <= 9; ++is) {
if (is == 0) {
continue;
}
iscale = -(nmax + 0.1f*is) / max;
sumlx = suml2 = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale * x[i]);
l = MAX(-nmax, MIN(nmax-1, l));
float w = weight_type == 1 ? x[i] * x[i] : 1;
sumlx += w*x[i]*l;
suml2 += w*l*l;
}
if (suml2 > 0 && sumlx*sumlx > best*suml2) {
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale * x[i]);
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
}
scale = sumlx/suml2; best = scale*sumlx;
}
}
return scale;
}
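// make_q3_quants(): same idea as make_qx_quants, except that when do_rmse is set it
// iteratively re-optimizes individual quants (up to 5 passes) against the x^2-weighted
// objective before returning the least-squares scale sumlx/suml2.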
static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
float max = 0;
float amax = 0;
for (int i = 0; i < n; ++i) {
float ax = fabsf(x[i]);
if (ax > amax) { amax = ax; max = x[i]; }
}
if (!amax) { // all zero
for (int i = 0; i < n; ++i) { L[i] = 0; }
return 0.f;
}
float iscale = -nmax / max;
if (do_rmse) {
float sumlx = 0;
float suml2 = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale * x[i]);
l = MAX(-nmax, MIN(nmax-1, l));
L[i] = l;
float w = x[i]*x[i];
sumlx += w*x[i]*l;
suml2 += w*l*l;
}
for (int itry = 0; itry < 5; ++itry) {
int n_changed = 0;
for (int i = 0; i < n; ++i) {
float w = x[i]*x[i];
float slx = sumlx - w*x[i]*L[i];
if (slx > 0) {
float sl2 = suml2 - w*L[i]*L[i];
int new_l = nearest_int(x[i] * sl2 / slx);
new_l = MAX(-nmax, MIN(nmax-1, new_l));
if (new_l != L[i]) {
slx += w*x[i]*new_l;
sl2 += w*new_l*new_l;
if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
L[i] = new_l; sumlx = slx; suml2 = sl2;
++n_changed;
}
}
}
}
if (!n_changed) {
break;
}
}
for (int i = 0; i < n; ++i) {
L[i] += nmax;
}
return sumlx / suml2;
}
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale * x[i]);
l = MAX(-nmax, MIN(nmax-1, l));
L[i] = l + nmax;
}
return 1/iscale;
}
static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
int ntry, float alpha) {
float min = x[0];
float max = x[0];
for (int i = 1; i < n; ++i) {
if (x[i] < min) min = x[i];
if (x[i] > max) max = x[i];
}
if (max == min) {
for (int i = 0; i < n; ++i) L[i] = 0;
*the_min = 0;
return 0.f;
}
if (min > 0) min = 0;
float iscale = nmax/(max - min);
float scale = 1/iscale;
for (int itry = 0; itry < ntry; ++itry) {
float sumlx = 0; int suml2 = 0;
bool did_change = false;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale*(x[i] - min));
l = MAX(0, MIN(nmax, l));
if (l != L[i]) {
L[i] = l;
did_change = true;
}
sumlx += (x[i] - min)*l;
suml2 += l*l;
}
scale = sumlx/suml2;
float sum = 0;
for (int i = 0; i < n; ++i) {
sum += x[i] - scale*L[i];
}
min = alpha*min + (1 - alpha)*sum/n;
if (min > 0) min = 0;
iscale = 1/scale;
if (!did_change) break;
}
*the_min = -min;
return scale;
}
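// make_qkx2_quants(): the asymmetric fit used by the k-quants. It searches over nstep
// candidate inverse scales and, for each, solves the weighted least-squares system for
// (scale, min) of the model scale*L + min ~= x, scoring candidates by weighted absolute
// (use_mad) or squared error. The minimum is clamped to be <= 0 and returned via the_min
// as -min.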
static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
float rmin, float rdelta, int nstep, bool use_mad) {
float min = x[0];
float max = x[0];
float sum_w = weights[0];
float sum_x = sum_w * x[0];
#ifdef HAVE_BUGGY_APPLE_LINKER
// use 'volatile' to prevent unrolling and work around a bug in Apple ld64 1015.7
for (volatile int i = 1; i < n; ++i) {
#else
for (int i = 1; i < n; ++i) {
#endif
if (x[i] < min) min = x[i];
if (x[i] > max) max = x[i];
float w = weights[i];
sum_w += w;
sum_x += w * x[i];
}
if (min > 0) min = 0;
if (max == min) {
for (int i = 0; i < n; ++i) L[i] = 0;
*the_min = -min;
return 0.f;
}
float iscale = nmax/(max - min);
float scale = 1/iscale;
float best_mad = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale*(x[i] - min));
L[i] = MAX(0, MIN(nmax, l));
float diff = scale * L[i] + min - x[i];
diff = use_mad ? fabsf(diff) : diff * diff;
float w = weights[i];
best_mad += w * diff;
}
if (nstep < 1) {
*the_min = -min;
return scale;
}
for (int is = 0; is <= nstep; ++is) {
iscale = (rmin + rdelta*is + nmax)/(max - min);
float sum_l = 0, sum_l2 = 0, sum_xl = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale*(x[i] - min));
l = MAX(0, MIN(nmax, l));
Laux[i] = l;
float w = weights[i];
sum_l += w*l;
sum_l2 += w*l*l;
sum_xl += w*l*x[i];
}
float D = sum_w * sum_l2 - sum_l * sum_l;
if (D > 0) {
float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
if (this_min > 0) {
this_min = 0;
this_scale = sum_xl / sum_l2;
}
float mad = 0;
for (int i = 0; i < n; ++i) {
float diff = this_scale * Laux[i] + this_min - x[i];
diff = use_mad ? fabsf(diff) : diff * diff;
float w = weights[i];
mad += w * diff;
}
if (mad < best_mad) {
for (int i = 0; i < n; ++i) {
L[i] = Laux[i];
}
best_mad = mad;
scale = this_scale;
min = this_min;
}
}
}
*the_min = -min;
return scale;
}
#if QK_K == 256
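// For QK_K == 256, the 8 six-bit scales and 8 six-bit mins of a super-block are packed into
// 12 bytes: entries 0-3 keep their low 6 bits in bytes 0-3 (scales) and 4-7 (mins), while
// entries 4-7 store their low 4 bits in bytes 8-11 (scale = low nibble, min = high nibble)
// and their top 2 bits in the upper 2 bits of bytes 0-7. get_scale_min_k4() unpacks one
// (scale, min) pair from this layout.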
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
if (j < 4) {
*d = q[j] & 63; *m = q[j + 4] & 63;
} else {
*d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
#endif
//========================= 2-bit (de)-quantization
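// Q2_K (with QK_K == 256), as implemented below: a super-block holds 16 groups of 16 values.
// Each group gets a 4-bit scale and a 4-bit min packed into scales[16], the 2-bit quants are
// packed four per byte in qs, and two fp16 super-block factors d and dmin scale the per-group
// scales and mins.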
static void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
uint8_t L[QK_K];
uint8_t Laux[16];
float weights[16];
float mins[QK_K/16];
float scales[QK_K/16];
const float q4scale = 15.f;
for (int i = 0; i < nb; i++) {
float max_scale = 0; // as we are subtracting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/16; ++j) {
for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
float scale = scales[j];
if (scale > max_scale) {
max_scale = scale;
}
float min = mins[j];
if (min > max_min) {
max_min = min;
}
}
if (max_scale > 0) {
float iscale = q4scale/max_scale;
for (int j = 0; j < QK_K/16; ++j) {
int l = nearest_int(iscale*scales[j]);
y[i].scales[j] = l;
}
y[i].d = GGML_V3_FP32_TO_FP16(max_scale/q4scale);
} else {
for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
y[i].d = GGML_V3_FP32_TO_FP16(0.f);
}
if (max_min > 0) {
float iscale = q4scale/max_min;
for (int j = 0; j < QK_K/16; ++j) {
int l = nearest_int(iscale*mins[j]);
y[i].scales[j] |= (l << 4);
}
y[i].dmin = GGML_V3_FP32_TO_FP16(max_min/q4scale);
} else {
y[i].dmin = GGML_V3_FP32_TO_FP16(0.f);
}
for (int j = 0; j < QK_K/16; ++j) {
const float d = GGML_V3_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF);
if (!d) continue;
const float dm = GGML_V3_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4);
for (int ii = 0; ii < 16; ++ii) {
int l = nearest_int((x[16*j + ii] + dm)/d);
l = MAX(0, MIN(3, l));
L[16*j + ii] = l;
}
}
#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
#else
for (int l = 0; l < 16; ++l) {
y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
}
#endif
x += QK_K;
}
}
static void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d);
const float min = GGML_V3_FP16_TO_FP32(x[i].dmin);
const uint8_t * q = x[i].qs;
#if QK_K == 256
int is = 0;
float dl, ml;
for (int n = 0; n < QK_K; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
uint8_t sc = x[i].scales[is++];
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
sc = x[i].scales[is++];
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
shift += 2;
}
q += 32;
}
#else
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
for (int l = 0; l < 16; ++l) {
y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
}
y += QK_K;
#endif
}
}
static void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
quantize_row_q2_K_reference(x, vy, k);
}
size_t ggml_v3_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
(void)hist; // TODO: collect histograms
for (int j = 0; j < n; j += k) {
block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
quantize_row_q2_K_reference(src + j, y, k);
}
return (n/QK_K*sizeof(block_q2_K));
}
//========================= 3-bit (de)-quantization
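// Q3_K (with QK_K == 256), as implemented below: 16 groups of 16 values with 3-bit quants.
// The low 2 bits of each quant are packed four per byte in qs, the high bits go into the
// hmask bit-planes, and the 16 six-bit group scales are packed into the 12-byte scales array,
// all under a single fp16 d.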
static void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
int8_t L[QK_K];
float scales[QK_K / 16];
for (int i = 0; i < nb; i++) {
float max_scale = 0;
float amax = 0;
for (int j = 0; j < QK_K/16; ++j) {
scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
float scale = fabsf(scales[j]);
if (scale > amax) {
amax = scale; max_scale = scales[j];
}
}
#if QK_K == 256
memset(y[i].scales, 0, 12);
if (max_scale) {
float iscale = -32.f/max_scale;
for (int j = 0; j < QK_K/16; ++j) {
int8_t l = nearest_int(iscale*scales[j]);
l = MAX(-32, MIN(31, l)) + 32;
if (j < 8) {
y[i].scales[j] = l & 0xF;
} else {
y[i].scales[j-8] |= ((l & 0xF) << 4);
}
l >>= 4;
y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
}
y[i].d = GGML_V3_FP32_TO_FP16(1/iscale);
} else {
y[i].d = GGML_V3_FP32_TO_FP16(0.f);
}
int8_t sc;
for (int j = 0; j < QK_K/16; ++j) {
sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
float d = GGML_V3_FP16_TO_FP32(y[i].d) * sc;
if (!d) {
continue;
}
for (int ii = 0; ii < 16; ++ii) {
int l = nearest_int(x[16*j + ii]/d);
l = MAX(-4, MIN(3, l));
L[16*j + ii] = l + 4;
}
}
#else
if (max_scale) {
float iscale = -8.f/max_scale;
for (int j = 0; j < QK_K/16; j+=2) {
int l1 = nearest_int(iscale*scales[j]);
l1 = 8 + MAX(-8, MIN(7, l1));
int l2 = nearest_int(iscale*scales[j+1]);
l2 = 8 + MAX(-8, MIN(7, l2));
y[i].scales[j/2] = l1 | (l2 << 4);
}
y[i].d = GGML_V3_FP32_TO_FP16(1/iscale);
} else {
for (int j = 0; j < QK_K/16; j+=2) {
y[i].scales[j/2] = 0;
}
y[i].d = GGML_V3_FP32_TO_FP16(0.f);
}
for (int j = 0; j < QK_K/16; ++j) {
int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
float d = GGML_V3_FP16_TO_FP32(y[i].d) * (s - 8);
if (!d) {
continue;
}
for (int ii = 0; ii < 16; ++ii) {
int l = nearest_int(x[16*j + ii]/d);
l = MAX(-4, MIN(3, l));
L[16*j + ii] = l + 4;
}
}
#endif
memset(y[i].hmask, 0, QK_K/8);
// We put the high bit of the first 8 quants into bit 0, of the next 8 into bit 1, etc.
int m = 0;
uint8_t hm = 1;
for (int j = 0; j < QK_K; ++j) {
if (L[j] > 3) {
y[i].hmask[m] |= hm;
L[j] -= 4;
}
if (++m == QK_K/8) {
m = 0; hm <<= 1;
}
}
#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
#else
for (int l = 0; l < 16; ++l) {
y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
}
#endif
x += QK_K;
}
}
#if QK_K == 256
static void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
uint32_t aux[4];
const int8_t * scales = (const int8_t*)aux;
for (int i = 0; i < nb; i++) {
const float d_all = GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q = x[i].qs;
const uint8_t * restrict hm = x[i].hmask;
uint8_t m = 1;
memcpy(aux, x[i].scales, 12);
uint32_t tmp = aux[2];
aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
int is = 0;
float dl;
for (int n = 0; n < QK_K; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
}
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
}
shift += 2;
m <<= 1;
}
q += 32;
}
}
}
#else
static void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
assert(QK_K == 64);
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const float d_all = GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q = x[i].qs;
const uint8_t * restrict hm = x[i].hmask;
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
for (int l=0; l<8; ++l) {
uint8_t h = hm[l];
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
}
y += QK_K;
}
}
#endif
static void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
quantize_row_q3_K_reference(x, vy, k);
}
size_t ggml_v3_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
(void)hist; // TODO: collect histograms
for (int j = 0; j < n; j += k) {
block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
quantize_row_q3_K_reference(src + j, y, k);
}
return (n/QK_K*sizeof(block_q3_K));
}
// ====================== 4-bit (de)-quantization
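// Q4_K (with QK_K == 256), as implemented below: 8 groups of 32 values with 4-bit quants.
// Each group has a 6-bit scale and a 6-bit min packed into the 12-byte scales array (see
// get_scale_min_k4), scaled by the super-block fp16 d and dmin. The per-value fit weights
// are sqrt(mean(x^2)) + |x|.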
static void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
uint8_t L[QK_K];
uint8_t Laux[32];
float weights[32];
float mins[QK_K/32];
float scales[QK_K/32];
for (int i = 0; i < nb; i++) {
float max_scale = 0; // as we are subtracting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
//scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
float sum_x2 = 0;
for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
float av_x = sqrtf(sum_x2/32);
for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
float scale = scales[j];
if (scale > max_scale) {
max_scale = scale;
}
float min = mins[j];
if (min > max_min) {
max_min = min;
}
}
#if QK_K == 256
float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
for (int j = 0; j < QK_K/32; ++j) {
uint8_t ls = nearest_int(inv_scale*scales[j]);
uint8_t lm = nearest_int(inv_min*mins[j]);
ls = MIN(63, ls);
lm = MIN(63, lm);
if (j < 4) {
y[i].scales[j] = ls;
y[i].scales[j+4] = lm;
} else {
y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
y[i].scales[j-4] |= ((ls >> 4) << 6);
y[i].scales[j-0] |= ((lm >> 4) << 6);
}
}
y[i].d = GGML_V3_FP32_TO_FP16(max_scale/63.f);
y[i].dmin = GGML_V3_FP32_TO_FP16(max_min/63.f);
uint8_t sc, m;
for (int j = 0; j < QK_K/32; ++j) {
get_scale_min_k4(j, y[i].scales, &sc, &m);
const float d = GGML_V3_FP16_TO_FP32(y[i].d) * sc;
if (!d) continue;
const float dm = GGML_V3_FP16_TO_FP32(y[i].dmin) * m;
for (int ii = 0; ii < 32; ++ii) {
int l = nearest_int((x[32*j + ii] + dm)/d);
l = MAX(0, MIN(15, l));
L[32*j + ii] = l;
}
}
#else
const float s_factor = 15.f;
float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
float inv_min = max_min > 0 ? s_factor/max_min : 0.f;
int d1 = nearest_int(inv_scale*scales[0]);
int m1 = nearest_int(inv_min*mins[0]);
int d2 = nearest_int(inv_scale*scales[1]);
int m2 = nearest_int(inv_min*mins[1]);
y[i].scales[0] = d1 | (m1 << 4);
y[i].scales[1] = d2 | (m2 << 4);
y[i].d[0] = GGML_V3_FP32_TO_FP16(max_scale/s_factor);
y[i].d[1] = GGML_V3_FP32_TO_FP16(max_min/s_factor);
float sumlx = 0;
int suml2 = 0;
for (int j = 0; j < QK_K/32; ++j) {
const uint8_t sd = y[i].scales[j] & 0xF;
const uint8_t sm = y[i].scales[j] >> 4;
const float d = GGML_V3_FP16_TO_FP32(y[i].d[0]) * sd;
if (!d) continue;
const float m = GGML_V3_FP16_TO_FP32(y[i].d[1]) * sm;
for (int ii = 0; ii < 32; ++ii) {
int l = nearest_int((x[32*j + ii] + m)/d);
l = MAX(0, MIN(15, l));
L[32*j + ii] = l;
sumlx += (x[32*j + ii] + m)*l*sd;
suml2 += l*l*sd*sd;
}
}
if (suml2) {
y[i].d[0] = GGML_V3_FP32_TO_FP16(sumlx/suml2);
}
#endif
uint8_t * q = y[i].qs;
for (int j = 0; j < QK_K; j += 64) {
for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
q += 32;
}
x += QK_K;
}
}
static void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const uint8_t * q = x[i].qs;
#if QK_K == 256
const float d = GGML_V3_FP16_TO_FP32(x[i].d);
const float min = GGML_V3_FP16_TO_FP32(x[i].dmin);
int is = 0;
uint8_t sc, m;
for (int j = 0; j < QK_K; j += 64) {
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
#else
const float dall = GGML_V3_FP16_TO_FP32(x[i].d[0]);
const float mall = GGML_V3_FP16_TO_FP32(x[i].d[1]);
const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
for (int l = 0; l < 32; ++l) {
y[l+ 0] = d1 * (q[l] & 0xF) - m1;
y[l+32] = d2 * (q[l] >> 4) - m2;
}
y += QK_K;
#endif
}
}
static void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
assert(k % QK_K == 0);
block_q4_K * restrict y = vy;
quantize_row_q4_K_reference(x, y, k);
}
size_t ggml_v3_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
assert(k % QK_K == 0);
(void)hist; // TODO: collect histograms
for (int j = 0; j < n; j += k) {
block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
quantize_row_q4_K_reference(src + j, y, k);
}
return (n/QK_K*sizeof(block_q4_K));
}
// ====================== 5-bit (de)-quantization
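// Q5_K (with QK_K == 256), as implemented below: same scale/min layout as Q4_K, except the
// quants are 5-bit: the low 4 bits are packed two per byte in qs and the high bits are
// collected into the qh bit-planes.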
static void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
#if QK_K == 256
uint8_t L[QK_K];
float mins[QK_K/32];
float scales[QK_K/32];
float weights[32];
uint8_t Laux[32];
#else
int8_t L[QK_K];
float scales[QK_K/16];
#endif
for (int i = 0; i < nb; i++) {
#if QK_K == 256
float max_scale = 0; // as we are subtracting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
//scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
float sum_x2 = 0;
for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
float av_x = sqrtf(sum_x2/32);
for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
float scale = scales[j];
if (scale > max_scale) {
max_scale = scale;
}
float min = mins[j];
if (min > max_min) {
max_min = min;
}
}
float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
for (int j = 0; j < QK_K/32; ++j) {
uint8_t ls = nearest_int(inv_scale*scales[j]);
uint8_t lm = nearest_int(inv_min*mins[j]);
ls = MIN(63, ls);
lm = MIN(63, lm);
if (j < 4) {
y[i].scales[j] = ls;
y[i].scales[j+4] = lm;
} else {
y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
y[i].scales[j-4] |= ((ls >> 4) << 6);
y[i].scales[j-0] |= ((lm >> 4) << 6);
}
}
y[i].d = GGML_V3_FP32_TO_FP16(max_scale/63.f);
y[i].dmin = GGML_V3_FP32_TO_FP16(max_min/63.f);
uint8_t sc, m;
for (int j = 0; j < QK_K/32; ++j) {
get_scale_min_k4(j, y[i].scales, &sc, &m);
const float d = GGML_V3_FP16_TO_FP32(y[i].d) * sc;
if (!d) continue;
const float dm = GGML_V3_FP16_TO_FP32(y[i].dmin) * m;
for (int ii = 0; ii < 32; ++ii) {
int l = nearest_int((x[32*j + ii] + dm)/d);
l = MAX(0, MIN(31, l));
L[32*j + ii] = l;
}
}
uint8_t * restrict qh = y[i].qh;
uint8_t * restrict ql = y[i].qs;
memset(qh, 0, QK_K/8);
uint8_t m1 = 1, m2 = 2;
for (int n = 0; n < QK_K; n += 64) {
for (int j = 0; j < 32; ++j) {
int l1 = L[n + j];
if (l1 > 15) {
l1 -= 16; qh[j] |= m1;
}
int l2 = L[n + j + 32];
if (l2 > 15) {
l2 -= 16; qh[j] |= m2;
}
ql[j] = l1 | (l2 << 4);
}
m1 <<= 2; m2 <<= 2;
ql += 32;
}
#else
float max_scale = 0, amax = 0;
for (int j = 0; j < QK_K/16; ++j) {
scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
float abs_scale = fabsf(scales[j]);
if (abs_scale > amax) {
amax = abs_scale;
max_scale = scales[j];
}
}
float iscale = -128.f/max_scale;
for (int j = 0; j < QK_K/16; ++j) {
int l = nearest_int(iscale*scales[j]);
y[i].scales[j] = MAX(-128, MIN(127, l));
}
y[i].d = GGML_V3_FP32_TO_FP16(1/iscale);
for (int j = 0; j < QK_K/16; ++j) {
const float d = GGML_V3_FP16_TO_FP32(y[i].d) * y[i].scales[j];
if (!d) continue;
for (int ii = 0; ii < 16; ++ii) {
int l = nearest_int(x[16*j + ii]/d);
l = MAX(-16, MIN(15, l));
L[16*j + ii] = l + 16;
}
}
uint8_t * restrict qh = y[i].qh;
uint8_t * restrict ql = y[i].qs;
memset(qh, 0, QK_K/8);
for (int j = 0; j < 32; ++j) {
int jm = j%8;
int is = j/8;
int l1 = L[j];
if (l1 > 15) {
l1 -= 16; qh[jm] |= (1 << is);
}
int l2 = L[j + 32];
if (l2 > 15) {
l2 -= 16; qh[jm] |= (1 << (4 + is));
}
ql[j] = l1 | (l2 << 4);
}
#endif
x += QK_K;
}
}
static void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const uint8_t * ql = x[i].qs;
const uint8_t * qh = x[i].qh;
#if QK_K == 256
const float d = GGML_V3_FP16_TO_FP32(x[i].d);
const float min = GGML_V3_FP16_TO_FP32(x[i].dmin);
int is = 0;
uint8_t sc, m;
uint8_t u1 = 1, u2 = 2;
for (int j = 0; j < QK_K; j += 64) {
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
ql += 32; is += 2;
u1 <<= 2; u2 <<= 2;
}
#else
float d = GGML_V3_FP16_TO_FP32(x[i].d);
const int8_t * restrict s = x[i].scales;
for (int l = 0; l < 8; ++l) {
y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
}
y += QK_K;
#endif
}
}
static void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
assert(k % QK_K == 0);
block_q5_K * restrict y = vy;
quantize_row_q5_K_reference(x, y, k);
}
size_t ggml_v3_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
assert(k % QK_K == 0);
(void)hist; // TODO: collect histograms
for (int j = 0; j < n; j += k) {
block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
quantize_row_q5_K_reference(src + j, y, k);
}
return (n/QK_K*sizeof(block_q5_K));
}
// ====================== 6-bit (de)-quantization
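// Q6_K, as implemented below: 16 groups of 16 values with 6-bit quants and signed 8-bit group
// scales under a single fp16 d. The low 4 bits of each quant go into ql and the top 2 bits
// are packed four per byte into qh.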
static void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
int8_t L[QK_K];
float scales[QK_K/16];
for (int i = 0; i < nb; i++) {
float max_scale = 0;
float max_abs_scale = 0;
for (int ib = 0; ib < QK_K/16; ++ib) {
const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
scales[ib] = scale;
const float abs_scale = fabsf(scale);
if (abs_scale > max_abs_scale) {
max_abs_scale = abs_scale;
max_scale = scale;
}
}
if (!max_abs_scale) {
memset(&y[i], 0, sizeof(block_q6_K));
y[i].d = GGML_V3_FP32_TO_FP16(0.f);
x += QK_K;
continue;
}
float iscale = -128.f/max_scale;
y[i].d = GGML_V3_FP32_TO_FP16(1/iscale);
for (int ib = 0; ib < QK_K/16; ++ib) {
y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
}
for (int j = 0; j < QK_K/16; ++j) {
float d = GGML_V3_FP16_TO_FP32(y[i].d) * y[i].scales[j];
if (!d) {
continue;
}
for (int ii = 0; ii < 16; ++ii) {
int l = nearest_int(x[16*j + ii]/d);
l = MAX(-32, MIN(31, l));
L[16*j + ii] = l + 32;
}
}
uint8_t * restrict ql = y[i].ql;
uint8_t * restrict qh = y[i].qh;
#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
const uint8_t q1 = L[j + l + 0] & 0xF;
const uint8_t q2 = L[j + l + 32] & 0xF;
const uint8_t q3 = L[j + l + 64] & 0xF;
const uint8_t q4 = L[j + l + 96] & 0xF;
ql[l+ 0] = q1 | (q3 << 4);
ql[l+32] = q2 | (q4 << 4);
qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
}
ql += 64;
qh += 32;
}
#else
for (int l = 0; l < 32; ++l) {
const uint8_t q1 = L[l + 0] & 0xF;
const uint8_t q2 = L[l + 32] & 0xF;
ql[l] = q1 | (q2 << 4);
}
for (int l = 0; l < 16; ++l) {
qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
}
#endif
x += QK_K;
}
}
static void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict ql = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict sc = x[i].scales;
#if QK_K == 256
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
y[l + 0] = d * sc[is + 0] * q1;
y[l + 32] = d * sc[is + 2] * q2;
y[l + 64] = d * sc[is + 4] * q3;
y[l + 96] = d * sc[is + 6] * q4;
}
y += 128;
ql += 64;
qh += 32;
sc += 8;
}
#else
for (int l = 0; l < 16; ++l) {
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
y[l+ 0] = d * sc[0] * q1;
y[l+16] = d * sc[1] * q2;
y[l+32] = d * sc[2] * q3;
y[l+48] = d * sc[3] * q4;
}
y += 64;
#endif
}
}
static void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
assert(k % QK_K == 0);
block_q6_K * restrict y = vy;
quantize_row_q6_K_reference(x, y, k);
}
size_t ggml_v3_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK_K == 0);
(void)hist; // TODO: collect histograms
for (int j = 0; j < n; j += k) {
block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
quantize_row_q6_K_reference(src + j, y, k);
}
return (n/QK_K*sizeof(block_q6_K));
}
// ====================== "True" 2-bit (de)-quantization
static const uint64_t iq2xxs_grid[256] = {
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
static const uint64_t iq2xs_grid[512] = {
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
static const uint8_t ksigns_iq2xs[128] = {
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
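// The iq2xxs_grid / iq2xs_grid tables above pack eight unsigned byte magnitudes per
// 64-bit entry (every byte is one of 0x08, 0x19 or 0x2b). ksigns_iq2xs maps a 7-bit
// sign index to an 8-bit sign mask whose high bit keeps the total number of set
// (negative) sign bits even, and kmask_iq2xs selects individual bits from that mask.
// The dequantizers below combine one grid entry, one sign mask and the block scale to
// reconstruct eight float values at a time.
// Note: in this legacy v3 copy only dequantization is implemented; the
// quantize_row_iq2_*_reference functions below are intentional no-op stubs.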
static void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
(void)x;
(void)y;
(void)k;
assert(k % QK_K == 0);
//fprintf(stderr, "=========================== %s: not implemented\n", __func__);
}
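// iq2_xxs layout, as consumed below: each 32-value sub-block owns 8 bytes of x[i].qs.
// The first 32-bit word holds four grid indices (one byte per group of 8 values); the
// second packs four 7-bit sign indices in bits 0..27 and a 4-bit scale in bits 28..31.
// Each value is reconstructed as
//   y = d * (0.5f + scale) * 0.25f * grid[j] * (sign bit j set ? -1.0f : 1.0f)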
static void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
uint32_t aux32[2];
const uint8_t * aux8 = (const uint8_t *)aux32;
for (int i = 0; i < nb; i++) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d);
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
for (int j = 0; j < 8; ++j) {
y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
y += 8;
}
}
}
}
static void quantize_row_iq2_xxs(const float * restrict x, void * restrict vy, int k) {
assert(k % QK_K == 0);
block_iq2_xxs * restrict y = vy;
quantize_row_iq2_xxs_reference(x, y, k);
}
size_t ggml_v3_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK_K == 0);
(void)hist; // TODO: collect histograms
for (int j = 0; j < n; j += k) {
block_iq2_xxs * restrict y = (block_iq2_xxs *)dst + j/QK_K;
quantize_row_iq2_xxs_reference(src + j, y, k);
}
return (n/QK_K*sizeof(block_iq2_xxs));
}
// ====================== 2.3125 bpw (de)-quantization
static void quantize_row_iq2_xs_reference(const float * restrict x, block_iq2_xs * restrict y, int k) {
(void)x;
(void)y;
(void)k;
assert(k % QK_K == 0);
//fprintf(stderr, "=========================== %s: not implemented\n", __func__);
}
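// iq2_xs layout, as consumed below: each 32-value sub-block owns four 16-bit entries of
// x[i].qs, where the low 9 bits index iq2xs_grid and the high 7 bits index ksigns_iq2xs,
// plus one byte of x[i].scales whose two nibbles scale the first and second half of the
// sub-block respectively.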
static void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
float db[2];
for (int i = 0; i < nb; i++) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d);
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511));
const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9];
for (int j = 0; j < 8; ++j) {
y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
y += 8;
}
}
}
}
static void quantize_row_iq2_xs(const float * restrict x, void * restrict vy, int k) {
assert(k % QK_K == 0);
block_iq2_xs * restrict y = vy;
quantize_row_iq2_xs_reference(x, y, k);
}
size_t ggml_v3_quantize_iq2_xs(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK_K == 0);
(void)hist; // TODO: collect histograms
for (int j = 0; j < n; j += k) {
block_iq2_xs * restrict y = (block_iq2_xs *)dst + j/QK_K;
quantize_row_iq2_xs_reference(src + j, y, k);
}
return (n/QK_K*sizeof(block_iq2_xs));
}
//===================================== Q8_K ==============================================
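// block_q8_K keeps one float scale d, QK_K signed 8-bit quants, and per-16-value group
// sums in bsums[]; the k-quant dot products below use those group sums to fold in the
// per-block minima. The scale is derived from the element with the largest magnitude
// (iscale = -127/max), so the stored quants stay within [-127, 127].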
static void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
float max = 0;
float amax = 0;
for (int j = 0; j < QK_K; ++j) {
float ax = fabsf(x[j]);
if (ax > amax) {
amax = ax; max = x[j];
}
}
if (!amax) {
y[i].d = 0;
memset(y[i].qs, 0, QK_K);
x += QK_K;
continue;
}
//const float iscale = -128.f/max;
// We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
const float iscale = -127.f/max;
for (int j = 0; j < QK_K; ++j) {
int v = nearest_int(iscale*x[j]);
y[i].qs[j] = MIN(127, v);
}
for (int j = 0; j < QK_K/16; ++j) {
int sum = 0;
for (int ii = 0; ii < 16; ++ii) {
sum += y[i].qs[j*16 + ii];
}
y[i].bsums[j] = sum;
}
y[i].d = 1/iscale;
x += QK_K;
}
}
static void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
for (int j = 0; j < QK_K; ++j) {
*y++ = x[i].d * x[i].qs[j];
}
}
}
static void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
quantize_row_q8_K_reference(x, y, k);
}
//===================================== Dot products =================================
//
// Helper functions
//
#if __AVX__ || __AVX2__ || __AVX512F__
// shuffles to pick the required scales in dot products
static inline __m256i get_scale_shuffle_q3k(int i) {
static const uint8_t k_shuffle[128] = {
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
};
return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
}
static inline __m256i get_scale_shuffle_k4(int i) {
static const uint8_t k_shuffle[256] = {
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
};
return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
}
static inline __m128i get_scale_shuffle(int i) {
static const uint8_t k_shuffle[128] = {
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
};
return _mm_loadu_si128((const __m128i*)k_shuffle + i);
}
#endif
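// q4_0 x q8_0 dot product. Every SIMD branch below computes the same result as the
// scalar fallback: per 32-element block, the 4-bit quants are widened to signed bytes
// (low nibbles, then high nibbles), offset by -8, multiplied with the q8_0 quants, and
// the integer sum is scaled by d_x * d_y.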
static void ggml_v3_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
const block_q4_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
assert(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q4_0 * restrict x0 = &x[i + 0];
const block_q4_0 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0x0F);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
// load y
const int8x16_t v1_0l = vld1q_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
// dot product into int32x4_t
const int32x4_t p_0 = ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
const int32x4_t p_1 = ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_V3_FP16_TO_FP32(x0->d)*GGML_V3_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_V3_FP16_TO_FP32(x1->d)*GGML_V3_FP16_TO_FP32(y1->d));
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; ++i) {
/* Compute combined scale for the block */
const __m256 d = _mm256_set1_ps( GGML_V3_FP16_TO_FP32(x[i].d) * GGML_V3_FP16_TO_FP32(y[i].d) );
__m256i bx = bytes_from_nibbles_32(x[i].qs);
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 );
bx = _mm256_sub_epi8( bx, off );
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
/* Multiply q with scale and accumulate */
acc = _mm256_fmadd_ps( d, q, acc );
}
*s = hsum_float_8(acc);
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; ++i) {
// Compute combined scale for the block
const __m256 d = _mm256_set1_ps( GGML_V3_FP16_TO_FP32(x[i].d) * GGML_V3_FP16_TO_FP32(y[i].d) );
const __m128i lowMask = _mm_set1_epi8(0xF);
const __m128i off = _mm_set1_epi8(8);
const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
__m128i bx = _mm_and_si128(lowMask, tmp);
__m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
bx = _mm_sub_epi8(bx, off);
const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
bx = _mm_sub_epi8(bx, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
// Apply the scale, and accumulate
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
}
*s = hsum_float_8(acc);
#elif defined(__SSSE3__)
// set constants
const __m128i lowMask = _mm_set1_epi8(0xF);
const __m128i off = _mm_set1_epi8(8);
// Initialize accumulator with zeros
__m128 acc_0 = _mm_setzero_ps();
__m128 acc_1 = _mm_setzero_ps();
__m128 acc_2 = _mm_setzero_ps();
__m128 acc_3 = _mm_setzero_ps();
// First round without accumulation
{
_mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
_mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
// Compute combined scale for sub-blocks 0 and 1 (from x[0] and y[0])
const __m128 d_0_1 = _mm_set1_ps( GGML_V3_FP16_TO_FP32(x[0].d) * GGML_V3_FP16_TO_FP32(y[0].d) );
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
__m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
bx_0 = _mm_sub_epi8(bx_0, off);
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
__m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
__m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
bx_1 = _mm_sub_epi8(bx_1, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
_mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
_mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
// Compute combined scale for sub-blocks 2 and 3 (from x[1] and y[1])
const __m128 d_2_3 = _mm_set1_ps( GGML_V3_FP16_TO_FP32(x[1].d) * GGML_V3_FP16_TO_FP32(y[1].d) );
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
__m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
__m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
bx_2 = _mm_sub_epi8(bx_2, off);
const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
__m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
__m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
bx_3 = _mm_sub_epi8(bx_3, off);
const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
// Convert int32_t to float
__m128 p0 = _mm_cvtepi32_ps(i32_0);
__m128 p1 = _mm_cvtepi32_ps(i32_1);
__m128 p2 = _mm_cvtepi32_ps(i32_2);
__m128 p3 = _mm_cvtepi32_ps(i32_3);
// Apply the scale
acc_0 = _mm_mul_ps( d_0_1, p0 );
acc_1 = _mm_mul_ps( d_0_1, p1 );
acc_2 = _mm_mul_ps( d_2_3, p2 );
acc_3 = _mm_mul_ps( d_2_3, p3 );
}
assert(nb % 2 == 0); // TODO: handle odd nb
// Main loop
for (int i = 2; i < nb; i+=2) {
_mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
_mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
// Compute combined scale for sub-blocks 0 and 1 (block i)
const __m128 d_0_1 = _mm_set1_ps( GGML_V3_FP16_TO_FP32(x[i].d) * GGML_V3_FP16_TO_FP32(y[i].d) );
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
__m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
bx_0 = _mm_sub_epi8(bx_0, off);
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
__m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
__m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
bx_1 = _mm_sub_epi8(bx_1, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
_mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
_mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
// Compute combined scale for sub-blocks 2 and 3 (block i + 1)
const __m128 d_2_3 = _mm_set1_ps( GGML_V3_FP16_TO_FP32(x[i + 1].d) * GGML_V3_FP16_TO_FP32(y[i + 1].d) );
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
__m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
__m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
bx_2 = _mm_sub_epi8(bx_2, off);
const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
__m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
__m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
bx_3 = _mm_sub_epi8(bx_3, off);
const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
// Convert int32_t to float
__m128 p0 = _mm_cvtepi32_ps(i32_0);
__m128 p1 = _mm_cvtepi32_ps(i32_1);
__m128 p2 = _mm_cvtepi32_ps(i32_2);
__m128 p3 = _mm_cvtepi32_ps(i32_3);
// Apply the scale
__m128 p0_d = _mm_mul_ps( d_0_1, p0 );
__m128 p1_d = _mm_mul_ps( d_0_1, p1 );
__m128 p2_d = _mm_mul_ps( d_2_3, p2 );
__m128 p3_d = _mm_mul_ps( d_2_3, p3 );
// Accumulate
acc_0 = _mm_add_ps(p0_d, acc_0);
acc_1 = _mm_add_ps(p1_d, acc_1);
acc_2 = _mm_add_ps(p2_d, acc_2);
acc_3 = _mm_add_ps(p3_d, acc_3);
}
*s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
#elif defined(__riscv_v_intrinsic)
float sumf = 0.0;
size_t vl = __riscv_vsetvl_e8m1(qk/2);
for (int i = 0; i < nb; i++) {
// load elements
vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
// mask and store lower part of x, and then upper part
vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
// subtract offset
vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
sumf += sumi*GGML_V3_FP16_TO_FP32(x[i].d)*GGML_V3_FP16_TO_FP32(y[i].d);
}
*s = sumf;
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
int sumi = 0;
for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0x0F) - 8;
const int v1 = (x[i].qs[j] >> 4) - 8;
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
}
sumf += sumi*GGML_V3_FP16_TO_FP32(x[i].d)*GGML_V3_FP16_TO_FP32(y[i].d);
}
*s = sumf;
#endif
}
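// q4_1 x q8_1 dot product. q4_1 blocks carry a scale d and a minimum m, so each block
// contributes d_x * d_y * sum(q4 * q8) + m_x * y.s, where y.s is the sum term stored
// with every q8_1 block (see the scalar fallback).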
static void ggml_v3_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int qk = QK8_1;
const int nb = n / qk;
assert(n % qk == 0);
const block_q4_1 * restrict x = vx;
const block_q8_1 * restrict y = vy;
// TODO: add WASM SIMD
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
float summs = 0;
assert(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q4_1 * restrict x0 = &x[i + 0];
const block_q4_1 * restrict x1 = &x[i + 1];
const block_q8_1 * restrict y0 = &y[i + 0];
const block_q8_1 * restrict y1 = &y[i + 1];
summs += GGML_V3_FP16_TO_FP32(x0->m) * y0->s + GGML_V3_FP16_TO_FP32(x1->m) * y1->s;
const uint8x16_t m4b = vdupq_n_u8(0x0F);
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// load y
const int8x16_t v1_0l = vld1q_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
// dot product into int32x4_t
const int32x4_t p_0 = ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
const int32x4_t p_1 = ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_V3_FP16_TO_FP32(x0->d)*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_V3_FP16_TO_FP32(x1->d)*y1->d);
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
#elif defined(__AVX2__) || defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
float summs = 0;
// Main loop
for (int i = 0; i < nb; ++i) {
const float d0 = GGML_V3_FP16_TO_FP32(x[i].d);
const float d1 = y[i].d;
summs += GGML_V3_FP16_TO_FP32(x[i].m) * y[i].s;
const __m256 d0v = _mm256_set1_ps( d0 );
const __m256 d1v = _mm256_set1_ps( d1 );
// Compute combined scales
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
const __m256 xy = mul_sum_us8_pairs_float(bx, by);
// Accumulate d0*d1*x*y
#if defined(__AVX2__)
acc = _mm256_fmadd_ps( d0d1, xy, acc );
#else
acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
#endif
}
*s = hsum_float_8(acc) + summs;
#elif defined(__riscv_v_intrinsic)
float sumf = 0.0;
size_t vl = __riscv_vsetvl_e8m1(qk/2);
for (int i = 0; i < nb; i++) {
// load elements
vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
// mask and store lower part of x, and then upper part
vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
sumf += (GGML_V3_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_V3_FP16_TO_FP32(x[i].m)*y[i].s;
}
*s = sumf;
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
int sumi = 0;
for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0x0F);
const int v1 = (x[i].qs[j] >> 4);
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
}
sumf += (GGML_V3_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_V3_FP16_TO_FP32(x[i].m)*y[i].s;
}
*s = sumf;
#endif
}
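// q5_0 x q8_0 dot product. The fifth bit of every quant lives in x[i].qh; the SIMD
// branches expand it back to byte lanes (table_b2b lookups on NEON/WASM,
// bytes_from_bits_32 on AVX) and merge it onto the 4-bit nibbles; the scalar fallback
// shifts it out of qh directly. Either way the reconstructed quants are offset by -16
// and the integer dot product is scaled by d_x * d_y.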
static void ggml_v3_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
assert(qk == QK5_0);
const block_q5_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
uint32_t qh0;
uint32_t qh1;
uint64_t tmp0[4];
uint64_t tmp1[4];
assert(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q5_0 * restrict x0 = &x[i];
const block_q5_0 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i];
const block_q8_0 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0x0F);
// extract the 5th bit via lookup table ((!b) << 4)
memcpy(&qh0, x0->qh, sizeof(qh0));
memcpy(&qh1, x1->qh, sizeof(qh1));
tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
tmp0[3] = table_b2b_1[(qh0 >> 24) ];
tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
tmp1[3] = table_b2b_1[(qh1 >> 24) ];
const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
// 4-bit -> 8-bit
int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
// load y
const int8x16_t v1_0l = vld1q_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_V3_FP16_TO_FP32(x0->d)*GGML_V3_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_V3_FP16_TO_FP32(x1->d)*GGML_V3_FP16_TO_FP32(y1->d));
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__wasm_simd128__)
v128_t sumv = wasm_f32x4_splat(0.0f);
uint32_t qh;
uint64_t tmp[4];
// TODO: check if unrolling this is better
for (int i = 0; i < nb; ++i) {
const block_q5_0 * restrict x0 = &x[i];
const block_q8_0 * restrict y0 = &y[i];
const v128_t m4b = wasm_i8x16_splat(0x0F);
// extract the 5th bit
memcpy(&qh, x0->qh, sizeof(qh));
tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
tmp[3] = table_b2b_1[(qh >> 24) ];
const v128_t qhl = wasm_v128_load(tmp + 0);
const v128_t qhh = wasm_v128_load(tmp + 2);
const v128_t v0 = wasm_v128_load(x0->qs);
// 4-bit -> 8-bit
const v128_t v0l = wasm_v128_and (v0, m4b);
const v128_t v0h = wasm_u8x16_shr(v0, 4);
// add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
// load y
const v128_t v1l = wasm_v128_load(y0->qs);
const v128_t v1h = wasm_v128_load(y0->qs + 16);
// int8x16 -> int16x8
const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
// dot product
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
wasm_i32x4_add(
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
wasm_f32x4_splat(GGML_V3_FP16_TO_FP32(x0->d) * GGML_V3_FP16_TO_FP32(y0->d))));
}
*s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; i++) {
/* Compute combined scale for the block */
const __m256 d = _mm256_set1_ps(GGML_V3_FP16_TO_FP32(x[i].d) * GGML_V3_FP16_TO_FP32(y[i].d));
__m256i bx = bytes_from_nibbles_32(x[i].qs);
__m256i bxhi = bytes_from_bits_32(x[i].qh);
bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
bx = _mm256_or_si256(bx, bxhi);
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
/* Multiply q with scale and accumulate */
acc = _mm256_fmadd_ps(d, q, acc);
}
*s = hsum_float_8(acc);
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
__m128i mask = _mm_set1_epi8((char)0xF0);
// Main loop
for (int i = 0; i < nb; i++) {
/* Compute combined scale for the block */
const __m256 d = _mm256_set1_ps(GGML_V3_FP16_TO_FP32(x[i].d) * GGML_V3_FP16_TO_FP32(y[i].d));
__m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_andnot_si128(bxhil, mask);
bxhih = _mm_andnot_si128(bxhih, mask);
__m128i bxl = _mm256_castsi256_si128(bx);
__m128i bxh = _mm256_extractf128_si256(bx, 1);
bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih);
bx = MM256_SET_M128I(bxh, bxl);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
/* Multiply q with scale and accumulate */
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
}
*s = hsum_float_8(acc);
#elif defined(__riscv_v_intrinsic)
float sumf = 0.0;
uint32_t qh;
size_t vl = __riscv_vsetvl_e8m1(qk/2);
// These temporary registers are for masking and shift operations
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
for (int i = 0; i < nb; i++) {
memcpy(&qh, x[i].qh, sizeof(uint32_t));
// ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
// ((qh & (1u << (j + 16))) >> (j + 12));
vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
// narrowing
vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
// load
vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
sumf += (GGML_V3_FP16_TO_FP32(x[i].d)*GGML_V3_FP16_TO_FP32(y[i].d)) * sumi;
}
*s = sumf;
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
int sumi = 0;
for (int j = 0; j < qk/2; ++j) {
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
}
sumf += (GGML_V3_FP16_TO_FP32(x[i].d)*GGML_V3_FP16_TO_FP32(y[i].d)) * sumi;
}
*s = sumf;
#endif
}
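// q5_1 x q8_1 dot product. Same fifth-bit reconstruction as q5_0, but without the -16
// offset: q5_1 blocks carry a scale d and a minimum m, and the minimum is folded in
// through the q8_1 sum term y.s, exactly as in the q4_1 path.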
static void ggml_v3_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int qk = QK8_1;
const int nb = n / qk;
assert(n % qk == 0);
assert(qk == QK5_1);
const block_q5_1 * restrict x = vx;
const block_q8_1 * restrict y = vy;
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
float summs0 = 0.0f;
float summs1 = 0.0f;
uint32_t qh0;
uint32_t qh1;
uint64_t tmp0[4];
uint64_t tmp1[4];
assert(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q5_1 * restrict x0 = &x[i];
const block_q5_1 * restrict x1 = &x[i + 1];
const block_q8_1 * restrict y0 = &y[i];
const block_q8_1 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0x0F);
summs0 += GGML_V3_FP16_TO_FP32(x0->m) * y0->s;
summs1 += GGML_V3_FP16_TO_FP32(x1->m) * y1->s;
// extract the 5th bit via lookup table ((b) << 4)
memcpy(&qh0, x0->qh, sizeof(qh0));
memcpy(&qh1, x1->qh, sizeof(qh1));
tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
tmp0[3] = table_b2b_0[(qh0 >> 24) ];
tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
tmp1[3] = table_b2b_0[(qh1 >> 24) ];
const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// add high bit
const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
// load y
const int8x16_t v1_0l = vld1q_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_V3_FP16_TO_FP32(x0->d)*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
ggml_v3_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_V3_FP16_TO_FP32(x1->d)*y1->d);
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
#elif defined(__wasm_simd128__)
v128_t sumv = wasm_f32x4_splat(0.0f);
float summs = 0.0f;
uint32_t qh;
uint64_t tmp[4];
// TODO: check if unrolling this is better
for (int i = 0; i < nb; ++i) {
const block_q5_1 * restrict x0 = &x[i];
const block_q8_1 * restrict y0 = &y[i];
summs += GGML_V3_FP16_TO_FP32(x0->m) * y0->s;
const v128_t m4b = wasm_i8x16_splat(0x0F);
// extract the 5th bit
memcpy(&qh, x0->qh, sizeof(qh));
tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
tmp[3] = table_b2b_0[(qh >> 24) ];
const v128_t qhl = wasm_v128_load(tmp + 0);
const v128_t qhh = wasm_v128_load(tmp + 2);
const v128_t v0 = wasm_v128_load(x0->qs);
// 4-bit -> 8-bit
const v128_t v0l = wasm_v128_and (v0, m4b);
const v128_t v0h = wasm_u8x16_shr(v0, 4);
// add high bit
const v128_t v0lf = wasm_v128_or(v0l, qhl);
const v128_t v0hf = wasm_v128_or(v0h, qhh);
// load y
const v128_t v1l = wasm_v128_load(y0->qs);
const v128_t v1h = wasm_v128_load(y0->qs + 16);
// int8x16 -> int16x8
const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
// dot product
sumv = wasm_f32x4_add(sumv,
wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
wasm_f32x4_splat(GGML_V3_FP16_TO_FP32(x0->d) * y0->d)));
}
*s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
float summs = 0.0f;
// Main loop
for (int i = 0; i < nb; i++) {
const __m256 dx = _mm256_set1_ps(GGML_V3_FP16_TO_FP32(x[i].d));
summs += GGML_V3_FP16_TO_FP32(x[i].m) * y[i].s;
__m256i bx = bytes_from_nibbles_32(x[i].qs);
__m256i bxhi = bytes_from_bits_32(x[i].qh);
bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
bx = _mm256_or_si256(bx, bxhi);
const __m256 dy = _mm256_set1_ps(y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_us8_pairs_float(bx, by);
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
}
*s = hsum_float_8(acc) + summs;
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
__m128i mask = _mm_set1_epi8(0x10);
float summs = 0.0f;
// Main loop
for (int i = 0; i < nb; i++) {
const __m256 dx = _mm256_set1_ps(GGML_V3_FP16_TO_FP32(x[i].d));
summs += GGML_V3_FP16_TO_FP32(x[i].m) * y[i].s;
__m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_and_si128(bxhil, mask);
bxhih = _mm_and_si128(bxhih, mask);
__m128i bxl = _mm256_castsi256_si128(bx);
__m128i bxh = _mm256_extractf128_si256(bx, 1);
bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih);
bx = MM256_SET_M128I(bxh, bxl);
const __m256 dy = _mm256_set1_ps(y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_us8_pairs_float(bx, by);
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
}
*s = hsum_float_8(acc) + summs;
#elif defined(__riscv_v_intrinsic)
float sumf = 0.0;
uint32_t qh;
size_t vl = __riscv_vsetvl_e8m1(qk/2);
// temporary registers for shift operations
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
for (int i = 0; i < nb; i++) {
memcpy(&qh, x[i].qh, sizeof(uint32_t));
// load qh
vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
// ((qh >> (j + 0)) << 4) & 0x10;
vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
// ((qh >> (j + 12)) ) & 0x10;
vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
// narrowing
vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
// load
vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
sumf += (GGML_V3_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_V3_FP16_TO_FP32(x[i].m)*y[i].s;
}
*s = sumf;
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
int sumi = 0;
for (int j = 0; j < qk/2; ++j) {
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;
sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
}
sumf += (GGML_V3_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_V3_FP16_TO_FP32(x[i].m)*y[i].s;
}
*s = sumf;
#endif
}
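// q8_0 x q8_0 dot product: a plain signed 8-bit dot product per 32-element block,
// scaled by d_x * d_y.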
static void ggml_v3_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
const block_q8_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
assert(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q8_0 * restrict x0 = &x[i + 0];
const block_q8_0 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
const int8x16_t x0_0 = vld1q_s8(x0->qs);
const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
const int8x16_t x1_0 = vld1q_s8(x1->qs);
const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
// load y
const int8x16_t y0_0 = vld1q_s8(y0->qs);
const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
const int8x16_t y1_0 = vld1q_s8(y1->qs);
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
ggml_v3_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
ggml_v3_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_V3_FP16_TO_FP32(x0->d)*GGML_V3_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
ggml_v3_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
ggml_v3_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_V3_FP16_TO_FP32(x1->d)*GGML_V3_FP16_TO_FP32(y1->d));
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX2__) || defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; ++i) {
// Compute combined scale for the block
const __m256 d = _mm256_set1_ps(GGML_V3_FP16_TO_FP32(x[i].d) * GGML_V3_FP16_TO_FP32(y[i].d));
__m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
// Multiply q with scale and accumulate
#if defined(__AVX2__)
acc = _mm256_fmadd_ps( d, q, acc );
#else
acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
#endif
}
*s = hsum_float_8(acc);
#elif defined(__riscv_v_intrinsic)
float sumf = 0.0;
size_t vl = __riscv_vsetvl_e8m1(qk);
for (int i = 0; i < nb; i++) {
// load elements
vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
sumf += sumi*(GGML_V3_FP16_TO_FP32(x[i].d)*GGML_V3_FP16_TO_FP32(y[i].d));
}
*s = sumf;
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
int sumi = 0;
for (int j = 0; j < qk; j++) {
sumi += x[i].qs[j]*y[i].qs[j];
}
sumf += sumi*(GGML_V3_FP16_TO_FP32(x[i].d)*GGML_V3_FP16_TO_FP32(y[i].d));
}
*s = sumf;
#endif
}
#if QK_K == 256
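// q2_K x q8_K dot product (QK_K == 256 variant). Each q2_K block stores 16 packed
// scale/min nibbles: the mins are folded in via y[i].bsums weighted by -d_y * x.dmin,
// while the 2-bit quants are unpacked two bits at a time, multiplied with the q8
// quants and weighted by the per-group scales before the final d_y * x.d scaling.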
static void ggml_v3_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const block_q2_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#ifdef __ARM_NEON
const uint8x16_t m3 = vdupq_n_u8(0x3);
const uint8x16_t m4 = vdupq_n_u8(0xF);
const int32x4_t vzero = vdupq_n_s32(0);
ggml_v3_int8x16x2_t q2bytes;
uint8_t aux[16];
float sum = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
const uint8_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint8_t * restrict sc = x[i].scales;
const uint8x16_t mins_and_scales = vld1q_u8(sc);
const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
vst1q_u8(aux, scales);
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
const ggml_v3_int16x8x2_t q8sums = ggml_v3_vld1q_s16_x2(y[i].bsums);
const ggml_v3_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
int isum = 0;
int is = 0;
// We use this macro instead of a function call because for some reason
// the code runs 2-3% slower, even if the function is declared inline
#define MULTIPLY_ACCUM_WITH_SCALE(index)\
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
q8bytes = ggml_v3_vld1q_s8_x2(q8); q8 += 32;\
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
MULTIPLY_ACCUM_WITH_SCALE((index));
for (int j = 0; j < QK_K/128; ++j) {
const ggml_v3_uint8x16x2_t q2bits = ggml_v3_vld1q_u8_x2(q2); q2 += 32;
ggml_v3_int8x16x2_t q8bytes = ggml_v3_vld1q_s8_x2(q8); q8 += 32;
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
MULTIPLY_ACCUM_WITH_SCALE(0);
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
is += 8;
}
sum += d * isum;
}
*s = sum;
#elif defined __AVX2__
const __m256i m3 = _mm256_set1_epi8(3);
const __m128i m4 = _mm_set1_epi8(0xF);
__m256 acc = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
const uint8_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
const __m256i mins = _mm256_cvtepi8_epi16(mins8);
const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
__m256i sumi = _mm256_setzero_si256();
for (int j = 0; j < QK_K/128; ++j) {
const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
__m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
__m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
__m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
__m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
p0 = _mm256_add_epi32(p0, p1);
p2 = _mm256_add_epi32(p2, p3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
}
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
*s = hsum_float_8(acc);
#elif defined __AVX__
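// AVX (no AVX2) path: same scheme as the AVX2 branch above, but every step is done on two
// 128-bit halves and the partial sums are only combined into a 256-bit register at the end
// of each super-block.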
const __m128i m3 = _mm_set1_epi8(0x3);
const __m128i m4 = _mm_set1_epi8(0xF);
const __m128i m2 = _mm_set1_epi8(0x2);
__m256 acc = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float dall = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
const uint8_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
// load mins and scales from block_q2_K.scales[QK_K/16]
const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
// summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
// sumf += -dmin * summs in 32bits*8
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
const __m128i scales[2] = { scales_0, scales_1 };
__m128i sumi_0 = _mm_setzero_si128();
__m128i sumi_1 = _mm_setzero_si128();
for (int j = 0; j < QK_K/128; ++j) {
// load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
// load 2bits*16*8 from block_q2_K.qs[QK_K/4]
__m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
const __m128i q2_0 = _mm_and_si128(q2bits, m3);
const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
const __m128i q2_1 = _mm_and_si128(q2bits, m3);
const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
// isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
__m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
__m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
__m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
__m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
__m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
__m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
__m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
__m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
// isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
__m128i shuffle = _mm_set1_epi16(0x0100);
p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
shuffle = _mm_add_epi16(shuffle, m2);
p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
shuffle = _mm_add_epi16(shuffle, m2);
p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
shuffle = _mm_add_epi16(shuffle, m2);
p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
shuffle = _mm_add_epi16(shuffle, m2);
p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
shuffle = _mm_add_epi16(shuffle, m2);
p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
shuffle = _mm_add_epi16(shuffle, m2);
p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
shuffle = _mm_add_epi16(shuffle, m2);
p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
p0 = _mm_add_epi32(p0, p1);
p2 = _mm_add_epi32(p2, p3);
p4 = _mm_add_epi32(p4, p5);
p6 = _mm_add_epi32(p6, p7);
// isum in 32bits*4*2
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
}
// sumf += dall * isum - dmin * summs in 32bits
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
}
*s = hsum_float_8(acc);
#elif defined __riscv_v_intrinsic
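// RISC-V vector path: temp_01 (16 zeros followed by 16 ones) is used with vrgather to
// broadcast each pair of consecutive sub-block scales across the 32 lanes of a chunk.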
float sumf = 0;
uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
for (int i = 0; i < nb; ++i) {
const uint8_t * q2 = x[i].qs;
const int8_t * q8 = y[i].qs;
const uint8_t * sc = x[i].scales;
const float dall = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
size_t vl = 16;
vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
vl = 32;
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
uint8_t is=0;
int isum=0;
for (int j = 0; j < QK_K/128; ++j) {
// load Q2
vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl);
vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl);
vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl);
// duplicate scale elements for product
vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl);
vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl);
vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl);
vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
// load Q8
vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl);
vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
isum += __riscv_vmv_x_s_i32m1_i32(isum1);
q2+=32; q8+=128; is=8;
}
sumf += dall * isum;
}
*s = sumf;
#else
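// Scalar reference: each 16-value sub-block j of a q2_K super-block has a 4-bit scale
// (sc[j] & 0xF) and a 4-bit min (sc[j] >> 4), so block i contributes
//   y[i].d*x[i].d * sum_j scale_j * dot(q2_j, q8_j) - y[i].d*x[i].dmin * sum_j min_j * bsums[j]
// which is the dall*isum - dmin*summs accumulation below.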
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * q2 = x[i].qs;
const int8_t * q8 = y[i].qs;
const uint8_t * sc = x[i].scales;
int summs = 0;
for (int j = 0; j < 16; ++j) {
summs += y[i].bsums[j] * (sc[j] >> 4);
}
const float dall = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
int isum = 0;
int is = 0;
int d;
for (int k = 0; k < QK_K/128; ++k) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
d = sc[is++] & 0xF;
int isuml = 0;
for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
isum += d * isuml;
d = sc[is++] & 0xF;
isuml = 0;
for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
isum += d * isuml;
shift += 2;
q8 += 32;
}
q2 += 32;
}
sumf += dall * isum - dmin * summs;
}
*s = sumf;
#endif
}
#else
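// QK_K == 64 variant: a super-block has only four 16-value sub-blocks, so the four scale/min
// nibble pairs fit in one 32-bit word of x[i].scales and the 64 2-bit quants in a single
// 16-byte vector.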
static void ggml_v3_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const block_q2_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#ifdef __ARM_NEON
const uint8x16_t m3 = vdupq_n_u8(0x3);
const int32x4_t vzero = vdupq_n_s32(0);
ggml_v3_int8x16x4_t q2bytes;
uint32_t aux32[2];
const uint8_t * scales = (const uint8_t *)aux32;
float sum = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * (float)x[i].d;
const float dmin = -y[i].d * (float)x[i].dmin;
const uint8_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
aux32[0] = sc[0] & 0x0f0f0f0f;
aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
int isum1 = 0, isum2 = 0;
const uint8x16_t q2bits = vld1q_u8(q2);
const ggml_v3_int8x16x4_t q8bytes = ggml_v3_vld1q_s8_x4(q8);
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
isum1 += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
isum2 += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
isum1 += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
isum2 += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
sum += d * (isum1 + isum2);
}
*s = sum;
#elif defined __AVX2__
const __m256i m3 = _mm256_set1_epi8(3);
__m256 acc = _mm256_setzero_ps();
uint32_t ud, um;
const uint8_t * restrict db = (const uint8_t *)&ud;
const uint8_t * restrict mb = (const uint8_t *)&um;
float summs = 0;
// TODO: optimize this
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
const uint8_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
ud = (sc[0] >> 0) & 0x0f0f0f0f;
um = (sc[0] >> 4) & 0x0f0f0f0f;
int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
summs += dmin * smin;
const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3);
const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0));
const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1));
const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0));
const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1));
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc);
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc);
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc);
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc);
}
*s = hsum_float_8(acc) + summs;
#elif defined __AVX__
const __m128i m3 = _mm_set1_epi8(3);
__m256 acc = _mm256_setzero_ps();
uint32_t ud, um;
const uint8_t * restrict db = (const uint8_t *)&ud;
const uint8_t * restrict mb = (const uint8_t *)&um;
float summs = 0;
// TODO: optimize this
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
const uint8_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
ud = (sc[0] >> 0) & 0x0f0f0f0f;
um = (sc[0] >> 4) & 0x0f0f0f0f;
int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
summs += dmin * smin;
const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
const __m128i q2_0 = _mm_and_si128(q2bits, m3);
const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
}
*s = hsum_float_8(acc) + summs;
#elif defined __riscv_v_intrinsic
uint32_t aux32[2];
const uint8_t * scales = (const uint8_t *)aux32;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * (float)x[i].d;
const float dmin = -y[i].d * (float)x[i].dmin;
const uint8_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
aux32[0] = sc[0] & 0x0f0f0f0f;
aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
int isum1 = 0;
int isum2 = 0;
size_t vl = 16;
vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
// load Q2
vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl);
vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl));
vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl));
vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl));
vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl));
// load Q8, and take product with Q2
vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl);
vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl);
vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl);
vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl);
isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0];
isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1];
isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2];
isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3];
sumf += d * (isum1 + isum2);
}
*s = sumf;
#else
float sumf = 0;
int isum[4];
for (int i = 0; i < nb; ++i) {
const uint8_t * q2 = x[i].qs;
const int8_t * q8 = y[i].qs;
const uint8_t * sc = x[i].scales;
int summs = 0;
for (int j = 0; j < QK_K/16; ++j) {
summs += y[i].bsums[j] * (sc[j] >> 4);
}
const float dall = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
isum[0] = isum[1] = isum[2] = isum[3] = 0;
for (int l = 0; l < 16; ++l) {
isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
}
for (int l = 0; l < 4; ++l) {
isum[l] *= (sc[l] & 0xF);
}
sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
}
*s = sumf;
#endif
}
#endif
#if QK_K == 256
static void ggml_v3_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
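// q3_K packs sixteen 6-bit sub-block scales into 12 bytes: bytes 0-7 hold the low 4 bits
// (scales 0-7 in the low nibbles, scales 8-15 in the high nibbles) and bytes 8-11 hold the
// top 2 bits of all sixteen scales. kmask1/kmask2 extract these parts before 32 is
// subtracted to recover the signed scales.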
const block_q3_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#ifdef __ARM_NEON
uint32_t aux[3];
uint32_t utmp[4];
const uint8x16_t m3b = vdupq_n_u8(0x3);
const int32x4_t vzero = vdupq_n_s32(0);
const uint8x16_t m0 = vdupq_n_u8(1);
const uint8x16_t m1 = vshlq_n_u8(m0, 1);
const uint8x16_t m2 = vshlq_n_u8(m0, 2);
const uint8x16_t m3 = vshlq_n_u8(m0, 3);
const int8_t m32 = 32;
ggml_v3_int8x16x4_t q3bytes;
float sum = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q3 = x[i].qs;
const uint8_t * restrict qh = x[i].hmask;
const int8_t * restrict q8 = y[i].qs;
ggml_v3_uint8x16x2_t qhbits = ggml_v3_vld1q_u8_x2(qh);
ggml_v3_uint8x16x4_t q3h;
int32_t isum = 0;
// Set up scales
memcpy(aux, x[i].scales, 12);
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
int8_t * scale = (int8_t *)utmp;
for (int j = 0; j < 16; ++j) scale[j] -= m32;
for (int j = 0; j < QK_K/128; ++j) {
const ggml_v3_uint8x16x2_t q3bits = ggml_v3_vld1q_u8_x2(q3); q3 += 32;
const ggml_v3_int8x16x4_t q8bytes_1 = ggml_v3_vld1q_s8_x4(q8); q8 += 64;
const ggml_v3_int8x16x4_t q8bytes_2 = ggml_v3_vld1q_s8_x4(q8); q8 += 64;
q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
scale += 4;
q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
scale += 4;
if (j == 0) {
qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
}
}
sum += d * isum;
}
*s = sum;
#elif defined __AVX2__
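// AVX2 path for q3_K: the 32-byte hmask is loaded once per super-block and the running
// `bit` index (0..7) selects which of its bit-planes supplies the high bit for the current
// group of 32 quants; q3h ends up as 4 where that bit is clear.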
const __m256i m3 = _mm256_set1_epi8(3);
const __m256i mone = _mm256_set1_epi8(1);
const __m128i m32 = _mm_set1_epi8(32);
__m256 acc = _mm256_setzero_ps();
uint32_t aux[3];
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q3 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
// Set up scales
memcpy(aux, x[i].scales, 12);
__m128i scales128 = _mm_set_epi32(
((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
(aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
(aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
scales128 = _mm_sub_epi8(scales128, m32);
const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
// high bit
const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
// integer accumulator
__m256i sumi = _mm256_setzero_si256();
int bit = 0;
int is = 0;
for (int j = 0; j < QK_K/128; ++j) {
// load low 2 bits
const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
// prepare low and high bits
const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
++bit;
const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
++bit;
const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
++bit;
const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
++bit;
// load Q8 quants
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
// Dot product: we multiply the 2 low bits and the high-bit part with q8 separately, so we can use
// _mm256_maddubs_epi16 (which needs an unsigned first operand), and then subtract the two products.
// q3h is 4 where the high bit is clear and 0 where it is set, so (q3l - q3h) gives the signed
// quant value in [-4, 3]
__m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
__m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
__m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
__m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
__m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
__m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
__m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
__m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
// multiply with scales
p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
// accumulate
p16_0 = _mm256_add_epi32(p16_0, p16_1);
p16_2 = _mm256_add_epi32(p16_2, p16_3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
}
// multiply with block scale and accumulate
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
*s = hsum_float_8(acc);
#elif defined __AVX__
const __m128i m3 = _mm_set1_epi8(3);
const __m128i mone = _mm_set1_epi8(1);
const __m128i m32 = _mm_set1_epi8(32);
const __m128i m2 = _mm_set1_epi8(2);
__m256 acc = _mm256_setzero_ps();
const uint32_t *aux;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q3 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
// Set up scales
aux = (const uint32_t *)x[i].scales;
__m128i scales128 = _mm_set_epi32(
((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
(aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
(aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
scales128 = _mm_sub_epi8(scales128, m32);
const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
const __m128i scales[2] = { scales_0, scales_1 };
// high bit *128*2 from block_q3_K.hmask[QK_K/8]
const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
// integer accumulator
__m128i sumi_0 = _mm_setzero_si128();
__m128i sumi_1 = _mm_setzero_si128();
for (int j = 0; j < QK_K/128; ++j) {
// load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
// prepare low and high bits
const int bit = j << 2;
const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
// load Q8 quants from block_q8_K.qs[QK_K]
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
// Dot product: we multiply the 2 low bits and the high-bit part with q8 separately, so we can use
// _mm_maddubs_epi16 (which needs an unsigned first operand), and then subtract the two products.
// q3h is 4 where the high bit is clear and 0 where it is set, so (q3l - q3h) gives the signed
// quant value in [-4, 3]
__m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
__m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
__m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
__m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
__m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
__m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
__m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
__m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
__m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
__m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
__m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
__m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
__m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
__m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
__m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
__m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
p16_4 = _mm_sub_epi16(p16_4, q8s_4);
p16_5 = _mm_sub_epi16(p16_5, q8s_5);
p16_6 = _mm_sub_epi16(p16_6, q8s_6);
p16_7 = _mm_sub_epi16(p16_7, q8s_7);
// multiply with scales
__m128i shuffle = _mm_set1_epi16(0x0100);
p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
shuffle = _mm_add_epi16(shuffle, m2);
p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
shuffle = _mm_add_epi16(shuffle, m2);
p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
shuffle = _mm_add_epi16(shuffle, m2);
p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
shuffle = _mm_add_epi16(shuffle, m2);
p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
shuffle = _mm_add_epi16(shuffle, m2);
p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
shuffle = _mm_add_epi16(shuffle, m2);
p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
shuffle = _mm_add_epi16(shuffle, m2);
p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
// accumulate
p16_0 = _mm_add_epi32(p16_0, p16_1);
p16_2 = _mm_add_epi32(p16_2, p16_3);
p16_4 = _mm_add_epi32(p16_4, p16_5);
p16_6 = _mm_add_epi32(p16_6, p16_7);
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
}
// multiply with block scale and accumulate
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
}
*s = hsum_float_8(acc);
#elif defined __riscv_v_intrinsic
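// RISC-V vector path: instead of building a separate high-bit vector, a comparison mask
// subtracts 4 from the 2-bit quants exactly where the corresponding hmask bit is clear,
// matching the scalar decoding below.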
uint32_t aux[3];
uint32_t utmp[4];
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q3 = x[i].qs;
const uint8_t * restrict qh = x[i].hmask;
const int8_t * restrict q8 = y[i].qs;
memcpy(aux, x[i].scales, 12);
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
int8_t * scale = (int8_t *)utmp;
for (int j = 0; j < 16; ++j) scale[j] -= 32;
size_t vl = 32;
uint8_t m = 1;
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
int sum_t = 0;
for (int j = 0; j < QK_K; j += 128) {
vl = 32;
// load Q3
vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
// compute mask for subtraction
vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl);
m <<= 1;
vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl);
m <<= 1;
vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl);
m <<= 1;
vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl);
m <<= 1;
// load Q8 and take product with Q3
vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
vl = 16;
// retrieve lane to multiply with scale
vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
q3 += 32; q8 += 128; scale += 8;
}
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
sumf += d*sum_t;
}
*s = sumf;
#else
// scalar version
// This function is written like this so the compiler can manage to vectorize most of it
// Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
// manually vectorized version above. Every other version I tried would run at least 4 times slower.
// The ideal situation would be if we could just write the code once, and the compiler would
// automatically produce the best possible set of machine instructions, instead of us having to manually
// write vectorized versions for AVX, ARM_NEON, etc.
int8_t aux8[QK_K];
int16_t aux16[8];
float sums [8];
int32_t aux32[8];
memset(sums, 0, 8*sizeof(float));
uint32_t auxs[4];
const int8_t * scales = (const int8_t*)auxs;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q3 = x[i].qs;
const uint8_t * restrict hm = x[i].hmask;
const int8_t * restrict q8 = y[i].qs;
memset(aux32, 0, 8*sizeof(int32_t));
int8_t * restrict a = aux8;
uint8_t m = 1;
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
a += 32; m <<= 1;
for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
a += 32; m <<= 1;
for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
a += 32; m <<= 1;
for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
a += 32; m <<= 1;
q3 += 32;
}
a = aux8;
memcpy(auxs, x[i].scales, 12);
uint32_t tmp = auxs[2];
auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
for (int j = 0; j < QK_K/16; ++j) {
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
q8 += 8; a += 8;
}
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
#endif
}
#else
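// QK_K == 64 variant of the q3_K kernel: 64 quants fit in 16 bytes of qs plus 8 bytes of
// hmask, and only four 4-bit scales (stored with an offset of 8) are needed per super-block.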
static void ggml_v3_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
const block_q3_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#ifdef __ARM_NEON
const int32x4_t vzero = vdupq_n_s32(0);
const uint8x16_t m3b = vdupq_n_u8(0x3);
const uint8x16_t mh = vdupq_n_u8(4);
ggml_v3_int8x16x4_t q3bytes;
uint16_t aux16[2];
int8_t * scales = (int8_t *)aux16;
float sum = 0;
for (int i = 0; i < nb; ++i) {
ggml_v3_uint8x16x4_t q3h;
const uint8x8_t hbits = vld1_u8(x[i].hmask);
const uint8x16_t q3bits = vld1q_u8(x[i].qs);
const ggml_v3_int8x16x4_t q8bytes = ggml_v3_vld1q_s8_x4(y[i].qs);
const uint16_t a = *(const uint16_t *)x[i].scales;
aux16[0] = a & 0x0f0f;
aux16[1] = (a >> 4) & 0x0f0f;
for (int j = 0; j < 4; ++j) scales[j] -= 8;
int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
const float d = y[i].d * (float)x[i].d;
const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
q3h.val[1] = vandq_u8(mh, htmp);
q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4));
q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0]));
q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
sum += d * isum;
}
*s = sum;
#elif defined __AVX2__
const __m256i m3 = _mm256_set1_epi8(3);
const __m256i m1 = _mm256_set1_epi8(1);
__m256 acc = _mm256_setzero_ps();
uint64_t aux64;
uint16_t aux16[2];
const int8_t * aux8 = (const int8_t *)aux16;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q3 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint16_t a = *(const uint16_t *)x[i].scales;
aux16[0] = a & 0x0f0f;
aux16[1] = (a >> 4) & 0x0f0f;
const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
memcpy(&aux64, x[i].hmask, 8);
const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
__m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux);
__m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
// load low 2 bits
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
// prepare low and high bits
const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits);
const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
// load Q8 quants
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
// Dot product: we multiply the 2 low bits and the high-bit part with q8 separately, so we can use
// _mm256_maddubs_epi16 (which needs an unsigned first operand), and then subtract the two products.
// q3h is 4 where the high bit is clear and 0 where it is set, so (q3l - q3h) gives the signed
// quant value in [-4, 3]
const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
__m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
__m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
// multiply with scales
p16_0 = _mm256_madd_epi16(scale_0, p16_0);
p16_1 = _mm256_madd_epi16(scale_1, p16_1);
p16_0 = _mm256_add_epi32(p16_0, p16_1);
// multiply with block scale and accumulate
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc);
}
*s = hsum_float_8(acc);
#elif defined __AVX__
const __m128i m3 = _mm_set1_epi8(3);
const __m128i m1 = _mm_set1_epi8(1);
__m256 acc = _mm256_setzero_ps();
uint64_t aux64;
uint16_t aux16[2];
const int8_t * aux8 = (const int8_t *)aux16;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q3 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint16_t a = *(const uint16_t *)x[i].scales;
aux16[0] = a & 0x0f0f;
aux16[1] = (a >> 4) & 0x0f0f;
const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
memcpy(&aux64, x[i].hmask, 8);
__m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
__m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
__m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
__m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
// load low 2 bits
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
// prepare low and high bits
const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
// load Q8 quants
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
// Dot product: we multiply the 2 low bits and the high-bit part with q8 separately, so we can use
// _mm_maddubs_epi16 (which needs an unsigned first operand), and then subtract the two products.
// q3h is 4 where the high bit is clear and 0 where it is set, so (q3l - q3h) gives the signed
// quant value in [-4, 3]
const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
__m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
__m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
__m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
__m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
// multiply with scales
p16_0 = _mm_madd_epi16(scale_0, p16_0);
p16_1 = _mm_madd_epi16(scale_1, p16_1);
p16_2 = _mm_madd_epi16(scale_2, p16_2);
p16_3 = _mm_madd_epi16(scale_3, p16_3);
p16_0 = _mm_add_epi32(p16_0, p16_2);
p16_1 = _mm_add_epi32(p16_1, p16_3);
__m256i p16 = MM256_SET_M128I(p16_1, p16_0);
// multiply with block scale and accumulate
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
}
*s = hsum_float_8(acc);
#elif defined __riscv_v_intrinsic
uint16_t aux16[2];
int8_t * scales = (int8_t *)aux16;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q3 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint16_t a = *(const uint16_t *)x[i].scales;
aux16[0] = a & 0x0f0f;
aux16[1] = (a >> 4) & 0x0f0f;
for (int j = 0; j < 4; ++j) scales[j] -= 8;
int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
const float d = y[i].d * (float)x[i].d;
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
// load qh
vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8);
vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
size_t vl = 16;
// extend and combine both qh_x1 and qh_x2
vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl);
vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl);
// load Q3
vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl);
vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl);
vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl);
vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl);
vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl);
vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0);
vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1);
vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2);
vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3);
// load Q8 and take product with Q3
vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0];
isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2];
isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1];
isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3];
sumf += d * isum;
}
*s = sumf;
#else
int8_t aux8[QK_K];
int16_t aux16[8];
float sums [8];
int32_t aux32[8];
int32_t scales[4];
memset(sums, 0, 8*sizeof(float));
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q3 = x[i].qs;
const uint8_t * restrict hm = x[i].hmask;
const int8_t * restrict q8 = y[i].qs;
int8_t * restrict a = aux8;
for (int l = 0; l < 8; ++l) {
a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4);
a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4);
a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4);
a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4);
a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4);
a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4);
a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4);
a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
}
scales[0] = (x[i].scales[0] & 0xF) - 8;
scales[1] = (x[i].scales[0] >> 4) - 8;
scales[2] = (x[i].scales[1] & 0xF) - 8;
scales[3] = (x[i].scales[1] >> 4) - 8;
memset(aux32, 0, 8*sizeof(int32_t));
for (int j = 0; j < QK_K/16; ++j) {
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
}
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
#endif
}
#endif
#if QK_K == 256
static void ggml_v3_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
const block_q4_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
uint32_t utmp[4];
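// q4_K packs eight 6-bit sub-block scales and eight 6-bit mins into the 12 bytes of
// x[i].scales; kmask1/2/3 are used below to unpack them (scales into utmp[0..1], mins into
// utmp[2..3] or, in the NEON branch, into a separate vector).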
#ifdef __ARM_NEON
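// NEON path: adjacent bsums entries are pair-added because each of the eight q4_K mins
// covers 32 quants, i.e. two 16-value bsums.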
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int32x4_t mzero = vdupq_n_s32(0);
ggml_v3_int8x16x2_t q4bytes;
ggml_v3_int8x16x2_t q8bytes;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
memcpy(utmp, x[i].scales, 12);
uint32x2_t mins8 = { 0 };
mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[0] &= kmask1;
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
sumf -= dmin * vaddvq_s32(prod);
const uint8_t * scales = (const uint8_t *)utmp;
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
int32_t sumi1 = 0;
int32_t sumi2 = 0;
for (int j = 0; j < QK_K/64; ++j) {
const ggml_v3_uint8x16x2_t q4bits = ggml_v3_vld1q_u8_x2(q4); q4 += 32;
q8bytes = ggml_v3_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int32x4_t p1 = ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
q8bytes = ggml_v3_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int32x4_t p2 = ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
#elif defined __AVX2__
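// AVX2 path: the bsums are pair-added (hadd) to line up with the eight mins, and that
// product is accumulated in acc_m together with dmin (negated below); each 64-value chunk
// then multiplies its low and high nibbles with q8 and the matching sub-block scales.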
const __m256i m4 = _mm256_set1_epi8(0xF);
__m256 acc = _mm256_setzero_ps();
__m128 acc_m = _mm_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
const __m256i scales = MM256_SET_M128I(sc128, sc128);
__m256i sumi = _mm256_setzero_si256();
for (int j = 0; j < QK_K/64; ++j) {
const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
const __m256i q4l = _mm256_and_si256(q4bits, m4);
const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
__m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
p16l = _mm256_madd_epi16(scale_l, p16l);
const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
__m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
p16h = _mm256_madd_epi16(scale_h, p16h);
const __m256i sumj = _mm256_add_epi32(p16l, p16h);
sumi = _mm256_add_epi32(sumi, sumj);
}
__m256 vd = _mm256_set1_ps(d);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
}
acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
*s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
#elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
const __m128i m2 = _mm_set1_epi8(0x2);
__m256 acc = _mm256_setzero_ps();
__m128 acc_m = _mm_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
const __m128i scales = _mm_cvtepu8_epi16(utmps);
const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
const __m128i prod = _mm_madd_epi16(mins, q8s);
acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
__m128i sumi_0 = _mm_setzero_si128();
__m128i sumi_1 = _mm_setzero_si128();
__m128i shuffle = _mm_set1_epi16(0x0100);
for (int j = 0; j < QK_K/64; ++j) {
const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi16(shuffle, m2);
const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi16(shuffle, m2);
__m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
__m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
p16l = _mm_madd_epi16(scale_l, p16l);
sumi_0 = _mm_add_epi32(sumi_0, p16l);
const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
p16l = _mm_madd_epi16(scale_l, p16l);
sumi_1 = _mm_add_epi32(sumi_1, p16l);
const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
__m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
p16h = _mm_madd_epi16(scale_h, p16h);
sumi_0 = _mm_add_epi32(sumi_0, p16h);
const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
p16h = _mm_madd_epi16(scale_h, p16h);
sumi_1 = _mm_add_epi32(sumi_1, p16h);
}
__m256 vd = _mm256_set1_ps(d);
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
}
acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
*s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
#elif defined __riscv_v_intrinsic
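// RISC-V Vector path: the min correction is computed with a widening multiply/reduce of the unpacked mins against the pair-summed q8 block sums, then each 64-value chunk is handled as two 32-element nibble groups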
const uint8_t * scales = (const uint8_t*)&utmp[0];
const uint8_t * mins = (const uint8_t*)&utmp[2];
float sumf = 0;
for (int i = 0; i < nb; ++i) {
size_t vl = 8;
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
vl = 32;
int32_t sum_1 = 0;
int32_t sum_2 = 0;
vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
for (int j = 0; j < QK_K/64; ++j) {
// load Q4
vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
// load Q8 and multiply it by the lower Q4 nibble
vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
// load Q8 and multiply it by the upper Q4 nibble
vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
q4 += 32; q8 += 64;
}
sumf += d*(sum_1 + sum_2);
}
*s = sumf;
#else
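// scalar reference path: expand the 4-bit quants into aux8, accumulate per-sub-block products into aux32/sums, and subtract the min correction derived from the q8 block sums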
const uint8_t * scales = (const uint8_t*)&utmp[0];
const uint8_t * mins = (const uint8_t*)&utmp[2];
int8_t aux8[QK_K];
int16_t aux16[8];
float sums [8];
int32_t aux32[8];
memset(sums, 0, 8*sizeof(float));
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
memset(aux32, 0, 8*sizeof(int32_t));
int8_t * restrict a = aux8;
for (int j = 0; j < QK_K/64; ++j) {
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
a += 32;
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
a += 32; q4 += 32;
}
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
int sumi = 0;
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
a = aux8;
int is = 0;
for (int j = 0; j < QK_K/32; ++j) {
int32_t scale = scales[is++];
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
}
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
const float dmin = GGML_V3_FP16_TO_FP32(x[i].dmin) * y[i].d;
sumf -= dmin * sumi;
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
#endif
}
#else
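// QK_K == 64 variant of the q4_K * q8_K dot product: each block stores its scale and min in x[i].d[0]/x[i].d[1] and packs the two sub-block scales (low nibbles) and mins (high nibbles) into x[i].scales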
static void ggml_v3_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
const block_q4_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#ifdef __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int32x4_t mzero = vdupq_n_s32(0);
float sumf = 0;
ggml_v3_int8x16x2_t q4bytes;
ggml_v3_int8x16x4_t q8bytes;
float sum_mins = 0.f;
uint16_t aux16[2];
const uint8_t * restrict scales = (const uint8_t *)aux16;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint16_t * restrict a = (const uint16_t *)x[i].scales;
aux16[0] = a[0] & 0x0f0f;
aux16[1] = (a[0] >> 4) & 0x0f0f;
const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
sum_mins += y[i].d * (float)x[i].d[1] * summi;
const float d = y[i].d * (float)x[i].d[0];
const ggml_v3_uint8x16x2_t q4bits = ggml_v3_vld1q_u8_x2(q4);
q8bytes = ggml_v3_vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int32x4_t p1 = ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int32x4_t p2 = ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
sumf += d * (sumi1 + sumi2);
}
*s = sumf - sum_mins;
#elif defined __AVX2__
const __m256i m4 = _mm256_set1_epi8(0xF);
__m256 acc = _mm256_setzero_ps();
float summs = 0;
uint16_t aux16[2];
const uint8_t * scales = (const uint8_t *)aux16;
for (int i = 0; i < nb; ++i) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d[0]) * y[i].d;
const float m = GGML_V3_FP16_TO_FP32(x[i].d[1]) * y[i].d;
const __m256 vd = _mm256_set1_ps(d);
const uint16_t * a = (const uint16_t *)x[i].scales;
aux16[0] = a[0] & 0x0f0f;
aux16[1] = (a[0] >> 4) & 0x0f0f;
summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
const __m256i q4l = _mm256_and_si256(q4bits, m4);
const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);
}
*s = hsum_float_8(acc) - summs;
#elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
__m256 acc = _mm256_setzero_ps();
float summs = 0;
uint16_t aux16[2];
const uint8_t * scales = (const uint8_t *)aux16;
for (int i = 0; i < nb; ++i) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d[0]) * y[i].d;
const float m = GGML_V3_FP16_TO_FP32(x[i].d[1]) * y[i].d;
const __m256 vd = _mm256_set1_ps(d);
const uint16_t * a = (const uint16_t *)x[i].scales;
aux16[0] = a[0] & 0x0f0f;
aux16[1] = (a[0] >> 4) & 0x0f0f;
summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc);
const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc);
}
*s = hsum_float_8(acc) - summs;
#elif defined __riscv_v_intrinsic
uint16_t s16[2];
const uint8_t * restrict scales = (const uint8_t *)s16;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint16_t * restrict b = (const uint16_t *)x[i].scales;
s16[0] = b[0] & 0x0f0f;
s16[1] = (b[0] >> 4) & 0x0f0f;
sumf -= y[i].d * GGML_V3_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d[0]);
size_t vl = 32;
vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
// load Q4
vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
// load Q8 and multiply it by the lower Q4 nibble
vint8m1_t q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl);
vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl);
sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1);
// load Q8 and multiply it by the upper Q4 nibble
vint8m1_t q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl);
vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl);
sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2);
}
*s = sumf;
#else
uint8_t aux8[QK_K];
int16_t aux16[16];
float sums [8];
memset(sums, 0, 8*sizeof(float));
uint16_t s16[2];
const uint8_t * restrict scales = (const uint8_t *)s16;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
uint8_t * restrict a = aux8;
for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4;
const uint16_t * restrict b = (const uint16_t *)x[i].scales;
s16[0] = b[0] & 0x0f0f;
s16[1] = (b[0] >> 4) & 0x0f0f;
sumf -= y[i].d * GGML_V3_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d[0]);
for (int j = 0; j < QK_K/32; ++j) {
for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
q8 += 16; a += 16;
for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
q8 += 16; a += 16;
const float dl = d * scales[j];
for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
}
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
#endif
}
#endif
#if QK_K == 256
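// q5_K * q8_K dot product (QK_K == 256): 5-bit values are a low nibble from x[i].qs plus one high bit from x[i].qh; scales and mins use the same packed 12-byte layout as q4_K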
static void ggml_v3_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
const block_q5_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
uint32_t utmp[4];
#ifdef __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint8x16_t mone = vdupq_n_u8(1);
const uint8x16_t mtwo = vdupq_n_u8(2);
const int32x4_t mzero = vdupq_n_s32(0);
ggml_v3_int8x16x4_t q5bytes;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
int32_t sumi_mins = vaddvq_s32(prod);
const uint8_t * scales = (const uint8_t *)utmp;
const uint8_t * restrict q5 = x[i].qs;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
ggml_v3_uint8x16x2_t qhbits = ggml_v3_vld1q_u8_x2(qh);
ggml_v3_uint8x16x4_t q5h;
int32_t sumi = 0;
for (int j = 0; j < QK_K/64; ++j) {
const ggml_v3_uint8x16x2_t q5bits = ggml_v3_vld1q_u8_x2(q5); q5 += 32;
const ggml_v3_int8x16x4_t q8bytes = ggml_v3_vld1q_s8_x4(q8); q8 += 64;
q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
sumi += vaddvq_s32(ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
sumi += vaddvq_s32(ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
}
sumf += d * sumi - dmin * sumi_mins;
}
*s = sumf;
#elif defined __AVX2__
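// AVX2 path: each 5-bit value is rebuilt as low nibble + (high bit << 4) using a walking hmask/bit pair, then multiplied against q8 with maddubs/madd as in the q4_K kernel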
const __m256i m4 = _mm256_set1_epi8(0xF);
const __m128i mzero = _mm_setzero_si128();
const __m256i mone = _mm256_set1_epi8(1);
__m256 acc = _mm256_setzero_ps();
float summs = 0.f;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q5 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
#if QK_K == 256
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
#else
// TODO (never taken here: this function is only compiled when QK_K == 256, so this inner branch is dead code)
const float d = 0, dmin = 0;
#endif
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
summs += dmin * _mm_extract_epi32(hsum, 0);
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
const __m256i scales = MM256_SET_M128I(sc128, sc128);
const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
__m256i hmask = mone;
__m256i sumi = _mm256_setzero_si256();
int bit = 0;
for (int j = 0; j < QK_K/64; ++j) {
const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
hmask = _mm256_slli_epi16(hmask, 1);
const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
hmask = _mm256_slli_epi16(hmask, 1);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
__m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
__m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
p16_0 = _mm256_madd_epi16(scale_0, p16_0);
p16_1 = _mm256_madd_epi16(scale_1, p16_1);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
}
__m256 vd = _mm256_set1_ps(d);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
}
*s = hsum_float_8(acc) + summs;
#elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
const __m128i mzero = _mm_setzero_si128();
const __m128i mone = _mm_set1_epi8(1);
const __m128i m2 = _mm_set1_epi8(2);
__m256 acc = _mm256_setzero_ps();
float summs = 0.f;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_V3_FP16_TO_FP32(x[i].dmin);
const uint8_t * restrict q5 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
const __m128i scales = _mm_cvtepu8_epi16(utmps);
const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
const __m128i prod = _mm_madd_epi16(mins, q8s);
const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
summs += dmin * _mm_extract_epi32(hsum, 0);
const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
__m128i hmask = mone;
__m128i sumi_0 = _mm_setzero_si128();
__m128i sumi_1 = _mm_setzero_si128();
int bit = 0;
__m128i shuffle = _mm_set1_epi16(0x0100);
for (int j = 0; j < QK_K/64; ++j) {
const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi16(shuffle, m2);
const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi16(shuffle, m2);
const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
__m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
__m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
__m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
__m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
__m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0);
__m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1);
hmask = _mm_slli_epi16(hmask, 1);
__m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
__m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
__m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
__m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
p16_0 = _mm_madd_epi16(scale_0, p16_0);
p16_1 = _mm_madd_epi16(scale_0, p16_1);
q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
q5_0 = _mm_add_epi8(q5l_0, q5h_0);
q5_1 = _mm_add_epi8(q5l_1, q5h_1);
hmask = _mm_slli_epi16(hmask, 1);
q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
__m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
__m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
p16_2 = _mm_madd_epi16(scale_1, p16_2);
p16_3 = _mm_madd_epi16(scale_1, p16_3);
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
}
__m256 vd = _mm256_set1_ps(d);
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
}
*s = hsum_float_8(acc) + summs;
#elif defined __riscv_v_intrinsic
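// RISC-V Vector path: instead of materializing the high bit, lanes whose qh bit is set get +16 added via a masked add before the widening multiply with q8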
const uint8_t * scales = (const uint8_t*)&utmp[0];
const uint8_t * mins = (const uint8_t*)&utmp[2];
float sumf = 0;
float sums = 0.0;
size_t vl;
for (int i = 0; i < nb; ++i) {
vl = 8;
const uint8_t * restrict q5 = x[i].qs;
const uint8_t * restrict hm = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
const float dmin = GGML_V3_FP16_TO_FP32(x[i].dmin) * y[i].d;
vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
vl = 32;
int32_t aux32 = 0;
int is = 0;
uint8_t m = 1;
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
for (int j = 0; j < QK_K/64; ++j) {
// load Q5 and Q8
vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
// compute mask for addition
vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl);
m <<= 1;
vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl);
m <<= 1;
vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
q5 += 32; q8 += 64;
}
vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
}
*s = sumf+sums;
#else
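// scalar reference path: expand to aux8 with the +16 high-bit offset, then accumulate exactly like the q4_K scalar code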
const uint8_t * scales = (const uint8_t*)&utmp[0];
const uint8_t * mins = (const uint8_t*)&utmp[2];
int8_t aux8[QK_K];
int16_t aux16[8];
float sums [8];
int32_t aux32[8];
memset(sums, 0, 8*sizeof(float));
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q4 = x[i].qs;
const uint8_t * restrict hm = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
memset(aux32, 0, 8*sizeof(int32_t));
int8_t * restrict a = aux8;
uint8_t m = 1;
for (int j = 0; j < QK_K/64; ++j) {
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
a += 32; m <<= 1;
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
a += 32; m <<= 1;
q4 += 32;
}
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
int sumi = 0;
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
a = aux8;
int is = 0;
for (int j = 0; j < QK_K/32; ++j) {
int32_t scale = scales[is++];
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
}
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
const float dmin = GGML_V3_FP16_TO_FP32(x[i].dmin) * y[i].d;
sumf -= dmin * sumi;
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
#endif
}
#else
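// QK_K == 64 variant of the q5_K * q8_K dot product: each value is the low nibble of x[i].qs minus 16 unless its high bit in x[i].qh is set; x[i].scales holds four signed 8-bit sub-block scales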
static void ggml_v3_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
const block_q5_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#ifdef __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint8x16_t mh = vdupq_n_u8(16);
const int32x4_t mzero = vdupq_n_s32(0);
ggml_v3_int8x16x4_t q5bytes;
ggml_v3_uint8x16x4_t q5h;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * (float)x[i].d;
const int8_t * sc = x[i].scales;
const uint8_t * restrict q5 = x[i].qs;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
const uint8x8_t qhbits = vld1_u8(qh);
const ggml_v3_uint8x16x2_t q5bits = ggml_v3_vld1q_u8_x2(q5);
const ggml_v3_int8x16x4_t q8bytes = ggml_v3_vld1q_s8_x4(q8);
const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2));
q5h.val[2] = vbicq_u8(mh, htmp);
q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2));
q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0]));
q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1]));
q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
int32_t sumi1 = sc[0] * vaddvq_s32(ggml_v3_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
int32_t sumi2 = sc[1] * vaddvq_s32(ggml_v3_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
int32_t sumi3 = sc[2] * vaddvq_s32(ggml_v3_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
int32_t sumi4 = sc[3] * vaddvq_s32(ggml_v3_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
}
*s = sumf;
#elif defined __AVX2__
const __m256i m4 = _mm256_set1_epi8(0xF);
const __m256i mone = _mm256_set1_epi8(1);
__m256 acc = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q5 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
int64_t aux64;
memcpy(&aux64, x[i].qh, 8);
const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128);
const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));
acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);
}
*s = hsum_float_8(acc);
#elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
const __m128i mone = _mm_set1_epi8(1);
__m256 acc = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q5 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
int64_t aux64;
memcpy(&aux64, x[i].qh, 8);
const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc);
}
*s = hsum_float_8(acc);
#elif defined __riscv_v_intrinsic
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * (float)x[i].d;
const int8_t * sc = x[i].scales;
const uint8_t * restrict q5 = x[i].qs;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
// load qh
vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8);
vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
size_t vl = 16;
// combine both qh_1 and qh_2
vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl);
vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl);
vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0);
vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1);
vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2);
vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3);
// load q5
vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl);
vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl);
vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl));
vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl));
vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl));
vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl));
vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl);
vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl);
vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl);
vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl);
// load Q8 and multiply it by Q5
vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0);
int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1);
int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2);
int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3);
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
}
*s = sumf;
#else
int8_t aux8[QK_K];
int16_t aux16[16];
float sums [8];
memset(sums, 0, 8*sizeof(float));
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q4 = x[i].qs;
const uint8_t * restrict hm = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
int8_t * restrict a = aux8;
for (int l = 0; l < 32; ++l) {
a[l+ 0] = q4[l] & 0xF;
a[l+32] = q4[l] >> 4;
}
for (int is = 0; is < 8; ++is) {
uint8_t m = 1 << is;
for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
}
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const int8_t * restrict sc = x[i].scales;
for (int j = 0; j < QK_K/16; ++j) {
const float dl = d * sc[j];
for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
q8 += 16; a += 16;
}
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
#endif
}
#endif
#if QK_K == 256
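// q6_K * q8_K dot product (QK_K == 256): each 6-bit value is a low nibble from x[i].ql plus two high bits from x[i].qh, offset by -32; x[i].scales holds signed 8-bit scales for the 16-value sub-blocks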
static void ggml_v3_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
const block_q6_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#ifdef __ARM_NEON
float sum = 0;
const uint8x16_t m4b = vdupq_n_u8(0xF);
const int32x4_t vzero = vdupq_n_s32(0);
//const int8x16_t m32s = vdupq_n_s8(32);
const uint8x16_t mone = vdupq_n_u8(3);
ggml_v3_int8x16x4_t q6bytes;
ggml_v3_uint8x16x4_t q6h;
for (int i = 0; i < nb; ++i) {
const float d_all = GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q6 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
const int8_t * restrict scale = x[i].scales;
const ggml_v3_int16x8x2_t q8sums = ggml_v3_vld1q_s16_x2(y[i].bsums);
const int8x16_t scales = vld1q_s8(scale);
const ggml_v3_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
int32_t isum_mins = vaddvq_s32(prod);
int32_t isum = 0;
for (int j = 0; j < QK_K/128; ++j) {
ggml_v3_uint8x16x2_t qhbits = ggml_v3_vld1q_u8_x2(qh); qh += 32;
ggml_v3_uint8x16x4_t q6bits = ggml_v3_vld1q_u8_x4(q6); q6 += 64;
ggml_v3_int8x16x4_t q8bytes = ggml_v3_vld1q_s8_x4(q8); q8 += 64;
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits.val[1], 2);
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
//q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
//q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
//q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
//q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
scale += 4;
q8bytes = ggml_v3_vld1q_s8_x4(q8); q8 += 64;
shifted = vshrq_n_u8(qhbits.val[0], 4);
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits.val[1], 4);
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits.val[0], 6);
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits.val[1], 6);
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
//q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
//q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
//q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
//q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
scale += 4;
}
//sum += isum * d_all * y[i].d;
sum += d_all * y[i].d * (isum - 32 * isum_mins);
}
*s = sum;
#elif defined __AVX2__
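// AVX2 path: the -32 offset is folded in after the unsigned*signed maddubs by subtracting maddubs(32, q8) (i.e. 32 times the pairwise q8 sums) before the per-sub-block madd with the signed scales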
const __m256i m4 = _mm256_set1_epi8(0xF);
const __m256i m2 = _mm256_set1_epi8(3);
const __m256i m32s = _mm256_set1_epi8(32);
__m256 acc = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q4 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
__m256i sumi = _mm256_setzero_si256();
int is = 0;
for (int j = 0; j < QK_K/128; ++j) {
const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
is += 4;
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
__m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
__m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
__m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
__m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
__m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
__m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
__m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
__m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
}
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
*s = hsum_float_8(acc);
#elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
const __m128i m3 = _mm_set1_epi8(3);
const __m128i m32s = _mm_set1_epi8(32);
const __m128i m2 = _mm_set1_epi8(2);
__m256 acc = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q4 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
__m128i sumi_0 = _mm_setzero_si128();
__m128i sumi_1 = _mm_setzero_si128();
__m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
for (int j = 0; j < QK_K/128; ++j) {
const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
__m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
__m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
__m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
__m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
__m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
__m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
__m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
__m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
__m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
__m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
__m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
__m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
__m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
__m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
__m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
__m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
p16_4 = _mm_sub_epi16(p16_4, q8s_4);
p16_5 = _mm_sub_epi16(p16_5, q8s_5);
p16_6 = _mm_sub_epi16(p16_6, q8s_6);
p16_7 = _mm_sub_epi16(p16_7, q8s_7);
const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
}
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
}
*s = hsum_float_8(acc);
#elif defined __riscv_v_intrinsic
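// RISC-V Vector path: reconstruct the 6-bit values, subtract the 32 offset directly, and reduce the widened products scaled by the signed sub-block scales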
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict q6 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
const int8_t * restrict scale = x[i].scales;
size_t vl;
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
int sum_t = 0;
int is = 0;
for (int j = 0; j < QK_K/128; ++j) {
vl = 32;
// load qh
vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
// load Q6
vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
// load Q8 and take product
vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
vl = 16;
vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
q6 += 64; qh += 32; q8 += 128; is=8;
}
sumf += d * sum_t;
}
*s = sumf;
#else
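// scalar reference path: expand each 128-value chunk into signed aux8 (value - 32), then accumulate 16-value sub-blocks weighted by the signed scales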
int8_t aux8[QK_K];
int16_t aux16[8];
float sums [8];
int32_t aux32[8];
memset(sums, 0, 8*sizeof(float));
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q4 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
memset(aux32, 0, 8*sizeof(int32_t));
int8_t * restrict a = aux8;
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
}
a += 128;
q4 += 64;
qh += 32;
}
a = aux8;
int is = 0;
for (int j = 0; j < QK_K/16; ++j) {
int scale = x[i].scales[is++];
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
}
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
#endif
}
#else
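// QK_K == 64 variant of the q6_K * q8_K dot product: a single 64-value block with four signed scales in x[i].scales; low nibbles come from x[i].ql and the packed high bits from a 16-byte load of x[i].qh, with the same -32 offset as the 256 variant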
static void ggml_v3_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
const block_q6_K * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#ifdef __ARM_NEON
float sum = 0;
const uint8x16_t m4b = vdupq_n_u8(0xF);
const int8x16_t m32s = vdupq_n_s8(32);
const int32x4_t vzero = vdupq_n_s32(0);
const uint8x16_t mone = vdupq_n_u8(3);
ggml_v3_int8x16x4_t q6bytes;
ggml_v3_uint8x16x4_t q6h;
for (int i = 0; i < nb; ++i) {
const float d_all = (float)x[i].d;
const uint8_t * restrict q6 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
const int8_t * restrict scale = x[i].scales;
int32_t isum = 0;
uint8x16_t qhbits = vld1q_u8(qh);
ggml_v3_uint8x16x2_t q6bits = ggml_v3_vld1q_u8_x2(q6);
ggml_v3_int8x16x4_t q8bytes = ggml_v3_vld1q_s8_x4(q8);
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits, 4);
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits, 6);
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
isum += vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(ggml_v3_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
sum += isum * d_all * y[i].d;
}
*s = sum;
#elif defined __AVX2__
const __m256i m4 = _mm256_set1_epi8(0xF);
const __m256i m2 = _mm256_set1_epi8(3);
const __m256i m32s = _mm256_set1_epi8(32);
__m256 acc = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q4 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
__m256i sumi = _mm256_setzero_si256();
const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
__m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
__m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
__m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
__m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
*s = hsum_float_8(acc);
#elif defined __AVX__
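    // Same scheme as the AVX2 branch above, but on 128-bit halves: the bias correction
    // and per-half scales are applied separately and the two 32-bit accumulators are
    // merged into the 256-bit float accumulator at the end of each block.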
const __m128i m4 = _mm_set1_epi8(0xF);
const __m128i m2 = _mm_set1_epi8(3);
const __m128i m32s = _mm_set1_epi8(32);
__m256 acc = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_V3_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q4 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
__m128i sumi_0 = _mm_setzero_si128();
__m128i sumi_1 = _mm_setzero_si128();
const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
__m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
__m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
__m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
__m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
__m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
__m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
__m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
__m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc);
}
*s = hsum_float_8(acc);
#elif defined __riscv_v_intrinsic
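    // RVV path: with vl = 16, the four 16-value groups of the block are unpacked,
    // multiplied against q8 with widening multiplies, reduced per group, and each
    // partial sum is weighted by its own 8-bit scale.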
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d_all = (float)x[i].d;
const uint8_t * restrict q6 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
const int8_t * restrict scale = x[i].scales;
int32_t isum = 0;
size_t vl = 16;
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
// load Q6
vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl);
vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl);
// load qh
vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl);
vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl);
vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl);
vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl);
vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl);
vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl);
vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl);
vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl);
vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl);
// load Q8 and take product
vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0];
isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1];
isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2];
isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3];
sumf += isum * d_all * y[i].d;
}
*s = sumf;
#else
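    // Scalar reference path for the small super-block layout: same arithmetic as the
    // fallback above, but a block unpacks into a single 64-value chunk.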
int8_t aux8[QK_K];
int16_t aux16[8];
float sums [8];
int32_t aux32[8];
memset(sums, 0, 8*sizeof(float));
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q4 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
memset(aux32, 0, 8*sizeof(int32_t));
int8_t * restrict a = aux8;
for (int l = 0; l < 16; ++l) {
a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
}
int is = 0;
for (int j = 0; j < QK_K/16; ++j) {
int scale = x[i].scales[is++];
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
q8 += 8; a += 8;
}
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
}
for (int l = 0; l < 8; ++l) sumf += sums[l];
*s = sumf;
#endif
}
#endif
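// Sign patterns shared by the iq2 dot products below: each of the 128 7-bit indices
// expands to 8 signs (+1/-1), the eighth being a parity sign that keeps the number
// of -1 entries in every group even.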
static const int8_t keven_signs_q2xs[1024] = {
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
};
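// iq2_xxs layout, as consumed below: every 32-value sub-block is two 32-bit words,
// the first packing four byte indices into iq2xxs_grid (8 values each), the second
// packing four 7-bit sign-pattern indices (bits 0..27) and a 4-bit scale (bits 28..31)
// that enters the sum as 2*scale + 1 together with the global 0.125 factor.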
static void ggml_v3_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
const block_iq2_xxs * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#if defined(__ARM_NEON)
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
ggml_v3_int8x16x4_t q2u;
ggml_v3_int8x16x4_t q2s;
ggml_v3_int8x16x4_t q8b;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
float sumf1 = 0, sumf2 = 0;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
q8b = ggml_v3_vld1q_s8_x4(q8); q8 += 64;
memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127))));
q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
const int32x4_t p1 = ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
const int32x4_t p2 = ggml_v3_vdotq_s32(ggml_v3_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
}
sumf += d*(sumf1 + sumf2);
}
*s = 0.25f * sumf;
#elif defined(__AVX2__)
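    // AVX2 path: the +-1 rows of keven_signs_q2xs are applied to q8 with
    // _mm256_sign_epi8, so the unsigned grid bytes can go straight into maddubs.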
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
__m256i sumi1 = _mm256_setzero_si256();
__m256i sumi2 = _mm256_setzero_si256();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
const uint16_t ls1 = aux32[1] >> 28;
const uint16_t ls2 = aux32[3] >> 28;
const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
sumi1 = _mm256_add_epi32(sumi1, p1);
sumi2 = _mm256_add_epi32(sumi2, p2);
}
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
}
*s = 0.125f * hsum_float_8(accumf);
#else
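    // Scalar reference path: per 32-value sub-block, look up the four 8-byte grid rows,
    // flip signs according to ksigns_iq2xs/kmask_iq2xs, and weight by 2*scale + 1.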
uint32_t aux32[2];
const uint8_t * aux8 = (const uint8_t *)aux32;
float sumf = 0.f;
for (int i = 0; i < nb; ++i) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
int32_t bsum = 0;
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
memcpy(aux32, q2, 2*sizeof(uint32_t));
q2 += 4;
const uint32_t ls = 2*(aux32[1] >> 28) + 1;
int32_t sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
for (int j = 0; j < 8; ++j) {
sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
}
q8 += 8;
}
bsum += sumi * ls;
}
sumf += d * bsum;
}
*s = 0.125f * sumf;
#endif
}
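// iq2_xs layout, as consumed below: one uint16 per 8 values, with bits 0..8 indexing
// iq2xs_grid and bits 9..15 selecting a sign pattern; 4-bit scales (two per byte in
// x[i].scales) cover 16 values each and enter the sum as 2*scale + 1 with the global
// 0.125 factor.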
static void ggml_v3_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
const block_iq2_xs * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#if defined(__ARM_NEON)
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
    // use the ggml_v3 NEON wrappers here too (as in the q6_K and iq2_xxs paths above),
    // so the x4 load below also builds on toolchains without vld1q_s8_x4
    ggml_v3_int8x16x4_t q2u;
    ggml_v3_int8x16x4_t q2s;
    ggml_v3_int8x16x4_t q8b;
    int32x4x4_t scales32;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const uint8x8_t scales8 = vld1_u8(x[i].scales);
const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
int32x4_t sumi = vdupq_n_s32(0);
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
        q8b = ggml_v3_vld1q_s8_x4(q8); q8 += 64;
q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
const int32x4_t p1 = ggml_v3_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
const int32x4_t p2 = ggml_v3_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
const int32x4_t p3 = ggml_v3_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
const int32x4_t p4 = ggml_v3_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
q2 += 8;
}
sumf += d*vaddvq_s32(sumi);
}
*s = 0.125f * sumf;
#elif defined(__AVX2__)
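    // AVX2 path: the 4-bit scale pairs are expanded to 2*scale + 1 once per block and
    // redistributed to their 16-value groups via get_scale_shuffle before the final madd.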
const __m128i m4 = _mm_set1_epi8(0xf);
const __m128i m1 = _mm_set1_epi8(1);
const __m128i m511 = _mm_set1_epi16(511);
const __m128i m127 = _mm_set1_epi16(127);
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
uint64_t aux64;
// somewhat hacky, but gives a significant boost in performance
__m128i aux_gindex, aux_sindex;
const uint16_t * gindex = (const uint16_t *)&aux_gindex;
const uint16_t * sindex = (const uint16_t *)&aux_sindex;
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
memcpy(&aux64, x[i].scales, 8);
__m128i stmp = _mm_set1_epi64x(aux64);
stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
__m256i sumi1 = _mm256_setzero_si256();
__m256i sumi2 = _mm256_setzero_si256();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8;
aux_gindex = _mm_and_si128(q2_data, m511);
aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127);
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]);
const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]);
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
}
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
}
*s = 0.125f * hsum_float_8(accumf);
#else
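    // Scalar reference path: decode the grid/sign pair from each uint16 and apply the
    // two 4-bit scales of every 32-value sub-block separately.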
float sumf = 0.f;
for (int i = 0; i < nb; ++i) {
const float d = GGML_V3_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const uint8_t * restrict sc = x[i].scales;
const int8_t * restrict q8 = y[i].qs;
int32_t bsum = 0;
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1;
int32_t sumi = 0;
for (int l = 0; l < 2; ++l) {
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
for (int j = 0; j < 8; ++j) {
sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
}
q8 += 8;
}
bsum += sumi * ls1;
sumi = 0;
for (int l = 2; l < 4; ++l) {
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
for (int j = 0; j < 8; ++j) {
sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
}
q8 += 8;
}
bsum += sumi * ls2;
q2 += 4;
}
sumf += d * bsum;
}
*s = 0.125f * sumf;
#endif
}