mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-18 23:49:46 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .devops/vulkan.Dockerfile # .github/workflows/build-self-hosted.yml # .github/workflows/build.yml # .github/workflows/release.yml # .github/workflows/server-self-hosted.yml # docs/build.md # ggml/src/ggml-hexagon/htp/CMakeLists.txt # ggml/src/ggml-hexagon/htp/hex-utils.h # ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c # ggml/src/ggml-hexagon/htp/hmx-utils.h # ggml/src/ggml-hexagon/htp/htp-ctx.h # ggml/src/ggml-hexagon/htp/htp-ops.h # ggml/src/ggml-hexagon/htp/hvx-base.h # ggml/src/ggml-hexagon/htp/main.c # ggml/src/ggml-webgpu/ggml-webgpu.cpp # tests/test-backend-ops.cpp # tests/test-mtmd-c-api.c
This commit is contained in:
commit
ac29e6f0c0
32 changed files with 553 additions and 250 deletions
|
|
@ -348,6 +348,53 @@ extern "C" {
|
|||
// Set a callback to be called for each resulting node during graph compute
|
||||
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
|
||||
|
||||
//
|
||||
// Meta backend
|
||||
//
|
||||
|
||||
#define GGML_BACKEND_META_MAX_DEVICES 16
|
||||
|
||||
enum ggml_backend_meta_split_axis {
|
||||
// tensor split by tensor dimensions:
|
||||
GGML_BACKEND_SPLIT_AXIS_0 = 0,
|
||||
GGML_BACKEND_SPLIT_AXIS_1 = 1,
|
||||
GGML_BACKEND_SPLIT_AXIS_2 = 2,
|
||||
GGML_BACKEND_SPLIT_AXIS_3 = 3,
|
||||
|
||||
GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends
|
||||
GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum
|
||||
|
||||
// for internal bookkeeping only:
|
||||
GGML_BACKEND_SPLIT_AXIS_NONE = 98,
|
||||
GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99,
|
||||
};
|
||||
GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis);
|
||||
|
||||
struct ggml_backend_meta_split_state {
|
||||
enum ggml_backend_meta_split_axis axis;
|
||||
|
||||
// for tensors with axis >= 0 && axis < GGML_MAX_DIMS:
|
||||
// - each device has a slice of the tensor along the split axis
|
||||
// - most tensors have n_segments == 1 and a contiguous slice of the tensor data
|
||||
// - some tensors have an inhomogenenous data layout along the split axis,
|
||||
// those tensors are divided into segments which are each individually split across devices
|
||||
// - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis,
|
||||
// the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1],
|
||||
// - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments
|
||||
// that each need to be split individually across devices so that each device gets a slice of Q, K, and V
|
||||
int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES];
|
||||
uint32_t n_segments;
|
||||
};
|
||||
|
||||
// function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible:
|
||||
typedef struct ggml_backend_meta_split_state(*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
// create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this:
|
||||
// TODO: this looks a bit strange - a backend API creates a device. I think we should try
|
||||
// express this as a backend registry functionality instead
|
||||
GGML_API ggml_backend_dev_t ggml_backend_meta_device(
|
||||
ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud);
|
||||
|
||||
//
|
||||
// Utils
|
||||
//
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
#include "ggml-backend-impl.h"
|
||||
#include "ggml.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <limits.h>
|
||||
#include <stdarg.h>
|
||||
|
|
|
|||
|
|
@ -5,9 +5,6 @@
|
|||
#include "ggml-alloc.h"
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
// TODO: tmp
|
||||
#include "ggml-ext.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
|
|
|||
|
|
@ -1,56 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
// This is a "staging" header for new ggml API
|
||||
// It is not publicly available and it should not be used by 3rd party projects
|
||||
//
|
||||
// When the API matures enough, it will be moved to the official public API
|
||||
|
||||
//
|
||||
// Meta backend
|
||||
//
|
||||
|
||||
#define GGML_BACKEND_META_MAX_DEVICES 16
|
||||
|
||||
enum ggml_backend_meta_split_axis {
|
||||
// tensor split by tensor dimensions:
|
||||
GGML_BACKEND_SPLIT_AXIS_0 = 0,
|
||||
GGML_BACKEND_SPLIT_AXIS_1 = 1,
|
||||
GGML_BACKEND_SPLIT_AXIS_2 = 2,
|
||||
GGML_BACKEND_SPLIT_AXIS_3 = 3,
|
||||
|
||||
GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends
|
||||
GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum
|
||||
|
||||
// for internal bookkeeping only:
|
||||
GGML_BACKEND_SPLIT_AXIS_NONE = 98,
|
||||
GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99,
|
||||
};
|
||||
GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis);
|
||||
|
||||
struct ggml_backend_meta_split_state {
|
||||
enum ggml_backend_meta_split_axis axis;
|
||||
|
||||
// for tensors with axis >= 0 && axis < GGML_MAX_DIMS:
|
||||
// - each device has a slice of the tensor along the split axis
|
||||
// - most tensors have n_segments == 1 and a contiguous slice of the tensor data
|
||||
// - some tensors have an inhomogenenous data layout along the split axis,
|
||||
// those tensors are divided into segments which are each individually split across devices
|
||||
// - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis,
|
||||
// the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1],
|
||||
// - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments
|
||||
// that each need to be split individually across devices so that each device gets a slice of Q, K, and V
|
||||
int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES];
|
||||
uint32_t n_segments;
|
||||
};
|
||||
|
||||
// function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible:
|
||||
typedef struct ggml_backend_meta_split_state(*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
// create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this:
|
||||
// TODO: this looks a bit strange - a backend API creates a device. I think we should try
|
||||
// express this as a backend registry functionality instead
|
||||
GGML_API ggml_backend_dev_t ggml_backend_meta_device(
|
||||
ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud);
|
||||
158
ggml/src/ggml-hexagon/htp/hmx-queue.c
Normal file
158
ggml/src/ggml-hexagon/htp/hmx-queue.c
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <qurt_thread.h>
|
||||
#include <qurt_futex.h>
|
||||
|
||||
#include <HAP_compute_res.h>
|
||||
|
||||
#include "hmx-queue.h"
|
||||
|
||||
#define QURT_LOWEST_PRIO (254)
|
||||
|
||||
static inline void hmx_lock(struct hmx_queue *q)
|
||||
{
|
||||
if (!q->hmx_locked) {
|
||||
HAP_compute_res_hmx_lock(q->hap_rctx);
|
||||
q->hmx_locked = true;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hmx_unlock(struct hmx_queue *q)
|
||||
{
|
||||
if (q->hmx_locked) {
|
||||
HAP_compute_res_hmx_unlock(q->hap_rctx);
|
||||
q->hmx_locked = false;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) {
|
||||
unsigned int ir = atomic_load(&q->idx_read);
|
||||
|
||||
while (ir != atomic_load(&q->idx_write)) {
|
||||
struct hmx_queue_desc *d = &q->desc[ir];
|
||||
if (!d->done) {
|
||||
FARF(HIGH, "hmx-queue-process: ir %u func %p data %p", ir, d->func, d->data);
|
||||
|
||||
enum hmx_queue_signal sig = (enum hmx_queue_signal) (unsigned int) d->func;
|
||||
switch (sig) {
|
||||
case HMX_QUEUE_NOOP: /* noop */; break;
|
||||
case HMX_QUEUE_KILL: *killed = true; break;
|
||||
case HMX_QUEUE_SUSPEND: hmx_unlock(q); break;
|
||||
default:
|
||||
hmx_lock(q);
|
||||
d->func(d->data);
|
||||
break;
|
||||
}
|
||||
|
||||
atomic_fetch_add(&d->done, 1);
|
||||
}
|
||||
|
||||
ir = (ir + 1) & q->idx_mask;
|
||||
atomic_store(&q->idx_read, ir);
|
||||
}
|
||||
}
|
||||
|
||||
static void hmx_queue_thread(void * arg) {
|
||||
struct hmx_queue * q = (struct hmx_queue *) arg;
|
||||
|
||||
FARF(HIGH, "hmx-queue-thread: started");
|
||||
|
||||
bool killed = false;
|
||||
|
||||
unsigned int poll_cnt = HMX_QUEUE_POLL_COUNT;
|
||||
unsigned int prev_seqn = 0;
|
||||
while (!killed) {
|
||||
unsigned int seqn = atomic_load(&q->seqn);
|
||||
if (seqn == prev_seqn) {
|
||||
if (--poll_cnt) { hex_pause(); continue; }
|
||||
FARF(HIGH, "hmx-queue-thread: sleeping");
|
||||
qurt_futex_wait(&q->seqn, prev_seqn);
|
||||
continue;
|
||||
}
|
||||
prev_seqn = seqn;
|
||||
poll_cnt = HMX_QUEUE_POLL_COUNT;
|
||||
|
||||
FARF(HIGH, "hmx-queue-thread: new work");
|
||||
|
||||
hmx_queue_process(q, &killed);
|
||||
}
|
||||
|
||||
FARF(HIGH, "hmx-queue-thread: stopped");
|
||||
}
|
||||
|
||||
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx) {
|
||||
capacity = hex_ceil_pow2(capacity);
|
||||
|
||||
struct hmx_queue * q = (struct hmx_queue *) memalign(32, sizeof(struct hmx_queue));
|
||||
if (q == NULL) {
|
||||
FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__);
|
||||
return NULL;
|
||||
}
|
||||
memset(q, 0, sizeof(struct hmx_queue));
|
||||
q->capacity = capacity;
|
||||
q->idx_mask = capacity - 1;
|
||||
q->hap_rctx = hap_rctx;
|
||||
|
||||
q->desc = (struct hmx_queue_desc *) memalign(64, capacity * sizeof(struct hmx_queue_desc));
|
||||
if (!q->desc) {
|
||||
FARF(ERROR, "hmx-queue: failed to allocate HMX queue descriptors\n");
|
||||
return NULL;
|
||||
}
|
||||
memset(q->desc, 0, capacity * sizeof(struct hmx_queue_desc));
|
||||
|
||||
const size_t stack_size = HMX_QUEUE_THREAD_STACK_SIZE;
|
||||
q->stack = (unsigned char *) memalign(64, stack_size);
|
||||
if (!q->stack) {
|
||||
FARF(ERROR, "hmx-queue: thread stack allocation failed (%zu bytes)", stack_size);
|
||||
return NULL;
|
||||
}
|
||||
memset(q->stack, 0, stack_size);
|
||||
|
||||
// Match caller thread priority (same pattern as worker-pool.c).
|
||||
int prio = qurt_thread_get_priority(qurt_thread_get_id());
|
||||
if (prio < 1) {
|
||||
prio = 1;
|
||||
}
|
||||
if (prio > QURT_LOWEST_PRIO) {
|
||||
prio = QURT_LOWEST_PRIO;
|
||||
}
|
||||
|
||||
qurt_thread_attr_t attr;
|
||||
qurt_thread_attr_init(&attr);
|
||||
qurt_thread_attr_set_stack_addr(&attr, q->stack);
|
||||
qurt_thread_attr_set_stack_size(&attr, stack_size);
|
||||
qurt_thread_attr_set_priority(&attr, prio);
|
||||
qurt_thread_attr_set_name(&attr, "hmx-queue");
|
||||
|
||||
int err = qurt_thread_create(&q->thread, &attr, hmx_queue_thread, q);
|
||||
if (err) {
|
||||
FARF(ERROR, "hmx-worker: thread create failed (%d)", err);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
FARF(HIGH, "hmx-queue: capacity %u\n", capacity);
|
||||
|
||||
return q;
|
||||
}
|
||||
|
||||
void hmx_queue_delete(struct hmx_queue * q) {
|
||||
if (!q) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Tell the worker to exit.
|
||||
hmx_queue_flush(q);
|
||||
hmx_queue_signal(q, HMX_QUEUE_KILL);
|
||||
hmx_queue_flush(q);
|
||||
|
||||
int status;
|
||||
qurt_thread_join(q->thread, &status);
|
||||
|
||||
free(q->desc);
|
||||
free(q->stack);
|
||||
free(q);
|
||||
}
|
||||
134
ggml/src/ggml-hexagon/htp/hmx-queue.h
Normal file
134
ggml/src/ggml-hexagon/htp/hmx-queue.h
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
#ifndef HMX_QUEUE_H
|
||||
#define HMX_QUEUE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdatomic.h>
|
||||
|
||||
#include <hexagon_types.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <qurt_futex.h>
|
||||
#include <HAP_farf.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define HMX_QUEUE_THREAD_STACK_SIZE (16 * 1024)
|
||||
#define HMX_QUEUE_POLL_COUNT 2000
|
||||
|
||||
typedef void (*hmx_queue_func)(void *);
|
||||
|
||||
// Dummy funcs used as signals
|
||||
enum hmx_queue_signal {
|
||||
HMX_QUEUE_NOOP = 0, // aka NULL
|
||||
HMX_QUEUE_SUSPEND,
|
||||
HMX_QUEUE_KILL
|
||||
};
|
||||
|
||||
struct hmx_queue_desc {
|
||||
hmx_queue_func func;
|
||||
void * data;
|
||||
atomic_uint done;
|
||||
};
|
||||
|
||||
struct hmx_queue {
|
||||
struct hmx_queue_desc * desc;
|
||||
atomic_uint idx_write; // updated by producer (push)
|
||||
atomic_uint idx_read; // updated by consumer (process)
|
||||
unsigned int idx_pop; // updated by producer (pop)
|
||||
uint32_t idx_mask;
|
||||
uint32_t capacity;
|
||||
|
||||
atomic_uint seqn; // incremented for all pushes, used with futex
|
||||
qurt_thread_t thread;
|
||||
void * stack;
|
||||
uint32_t hap_rctx;
|
||||
bool hmx_locked;
|
||||
};
|
||||
|
||||
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx);
|
||||
void hmx_queue_delete(struct hmx_queue * q);
|
||||
|
||||
static inline struct hmx_queue_desc hmx_queue_make_desc(hmx_queue_func func, void * data) {
|
||||
struct hmx_queue_desc d = { func, data };
|
||||
return d;
|
||||
}
|
||||
|
||||
static inline bool hmx_queue_push(struct hmx_queue * q, struct hmx_queue_desc d) {
|
||||
unsigned int ir = atomic_load(&q->idx_read);
|
||||
unsigned int iw = q->idx_write;
|
||||
|
||||
if (((iw + 1) & q->idx_mask) == ir) {
|
||||
FARF(HIGH, "hmx-queue-push: queue is full\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
atomic_store(&d.done, 0);
|
||||
|
||||
FARF(HIGH, "hmx-queue-push: iw %u func %p data %p\n", iw, d.func, d.data);
|
||||
|
||||
q->desc[iw] = d;
|
||||
atomic_store(&q->idx_write, (iw + 1) & q->idx_mask);
|
||||
// wake up our thread
|
||||
atomic_fetch_add(&q->seqn, 1);
|
||||
qurt_futex_wake(&q->seqn, 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline bool hmx_queue_signal(struct hmx_queue *q, enum hmx_queue_signal sig) {
|
||||
return hmx_queue_push(q, hmx_queue_make_desc((hmx_queue_func) sig, NULL));
|
||||
}
|
||||
|
||||
static inline bool hmx_queue_empty(struct hmx_queue * q) {
|
||||
return q->idx_pop == q->idx_write;
|
||||
}
|
||||
|
||||
static inline uint32_t hmx_queue_depth(struct hmx_queue * q) {
|
||||
return (q->idx_read - q->idx_read) & q->idx_mask;
|
||||
}
|
||||
|
||||
static inline uint32_t hmx_queue_capacity(struct hmx_queue * q) {
|
||||
return q->capacity;
|
||||
}
|
||||
|
||||
static inline struct hmx_queue_desc hmx_queue_pop(struct hmx_queue * q) {
|
||||
unsigned int ip = q->idx_pop;
|
||||
unsigned int iw = q->idx_write;
|
||||
|
||||
struct hmx_queue_desc rd = { NULL, NULL };
|
||||
if (ip == iw) {
|
||||
return rd;
|
||||
}
|
||||
|
||||
// Wait for desc to complete
|
||||
struct hmx_queue_desc * d = &q->desc[ip];
|
||||
while (!atomic_load(&d->done)) {
|
||||
FARF(HIGH, "hmx-queue-pop: waiting for HMX queue : %u\n", ip);
|
||||
hex_pause();
|
||||
}
|
||||
|
||||
rd = *d;
|
||||
q->idx_pop = (ip + 1) & q->idx_mask;
|
||||
|
||||
FARF(HIGH, "hmx-queue-pop: ip %u func %p data %p\n", ip, rd.func, rd.data);
|
||||
return rd;
|
||||
}
|
||||
|
||||
static inline void hmx_queue_flush(struct hmx_queue * q) {
|
||||
while (hmx_queue_pop(q).func != NULL) ;
|
||||
}
|
||||
|
||||
static inline void hmx_queue_suspend(struct hmx_queue *q) {
|
||||
hmx_queue_signal(q, HMX_QUEUE_SUSPEND);
|
||||
hmx_queue_flush(q);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
#endif /* HMX_QUEUE_H */
|
||||
|
|
@ -250,6 +250,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
|
|||
case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break;
|
||||
case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break;
|
||||
case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break;
|
||||
case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
} break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
|
|
|
|||
|
|
@ -1049,6 +1049,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
case GGML_UNARY_OP_XIELU:
|
||||
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
|
||||
default:
|
||||
return false;
|
||||
|
|
@ -1165,6 +1166,23 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
if (op->src[1]->type != op->src[2]->type) {
|
||||
return false;
|
||||
}
|
||||
switch (op->src[1]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
break;
|
||||
case GGML_TYPE_BF16:
|
||||
if (!has_bfloat) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@
|
|||
#define OP_UNARY_NUM_CEIL 118
|
||||
#define OP_UNARY_NUM_ROUND 119
|
||||
#define OP_UNARY_NUM_TRUNC 120
|
||||
#define OP_UNARY_NUM_XIELU 121
|
||||
|
||||
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
|
||||
#define OP_SUM_ROWS_NUM_MEAN 11
|
||||
|
|
|
|||
|
|
@ -787,6 +787,13 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
|||
args.max = ggml_get_op_params_f32(op, 1);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) {
|
||||
args.slope = ggml_get_op_params_f32(op, 1); // alpha_n
|
||||
args.scale = ggml_get_op_params_f32(op, 2); // alpha_p
|
||||
args.bias = ggml_get_op_params_f32(op, 3); // beta
|
||||
args.val = ggml_get_op_params_f32(op, 4); // eps
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
if (pipeline.c4) {
|
||||
|
|
|
|||
|
|
@ -1177,6 +1177,15 @@ kernel void kernel_unary_impl(
|
|||
if (FC_OP == OP_UNARY_NUM_TRUNC) {
|
||||
dst_ptr[i0] = (T) trunc(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_XIELU) {
|
||||
const TC xi = x;
|
||||
const TC gate = TC(xi > TC(0.0f));
|
||||
const TC clamped = fmin(xi, TC(args.val));
|
||||
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
|
||||
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
|
||||
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
|
|
|
|||
|
|
@ -32,6 +32,9 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
|||
#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
|
||||
|
||||
#include <vulkan/vulkan.hpp>
|
||||
// SPIRV-Headers: LunarG Windows SDK uses Include/spirv-headers/spirv.hpp (not spirv/unified1/). MinGW/MSYS2 and
|
||||
// Linux packages use Khronos layout spirv/unified1/spirv.hpp. See docs/build.md#vulkan.
|
||||
#include <spirv-headers/spirv.hpp> //kcpp: use bundled spirv.hpp to ensure same version
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
|
@ -2147,6 +2150,66 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
|||
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
|
||||
|
||||
vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
|
||||
|
||||
// Patch SPIR-V to enable RTE rounding for FP16, avoiding the need for
|
||||
// separate shader variants compiled with -DRTE16.
|
||||
std::vector<uint32_t> spv;
|
||||
if (device->float_controls_rte_fp16) {
|
||||
const uint32_t* spv_words = reinterpret_cast<const uint32_t *>(spv_data);
|
||||
size_t word_count = spv_size / sizeof(uint32_t);
|
||||
spv.assign(spv_words, spv_words + word_count);
|
||||
|
||||
// Find insertion points respecting SPIR-V layout order:
|
||||
// Header(5) -> OpCapability -> OpExtension -> ... -> OpEntryPoint -> OpExecutionMode -> ...
|
||||
size_t pos = 5; // skip header
|
||||
size_t cap_insert_pos = pos;
|
||||
size_t ext_insert_pos = pos;
|
||||
size_t exec_insert_pos = pos;
|
||||
uint32_t entry_point_id = 0;
|
||||
|
||||
while (pos < spv.size()) {
|
||||
uint32_t opcode = spv[pos] & spv::OpCodeMask;
|
||||
uint32_t len = spv[pos] >> spv::WordCountShift;
|
||||
if (len == 0) break;
|
||||
|
||||
if (opcode == spv::OpCapability) {
|
||||
cap_insert_pos = pos + len;
|
||||
ext_insert_pos = pos + len;
|
||||
} else if (opcode == spv::OpExtension) {
|
||||
ext_insert_pos = pos + len;
|
||||
} else if (opcode == spv::OpEntryPoint) {
|
||||
entry_point_id = spv[pos + 2];
|
||||
exec_insert_pos = pos + len;
|
||||
} else if (opcode == spv::OpExecutionMode || opcode == spv::OpExecutionModeId) {
|
||||
exec_insert_pos = pos + len;
|
||||
} else if (entry_point_id != 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
pos += len;
|
||||
}
|
||||
|
||||
// Insert from latest position first so earlier indices stay valid.
|
||||
|
||||
// OpExecutionMode %entrypoint RoundingModeRTE 16
|
||||
uint32_t exec_mode[] = { (4u << spv::WordCountShift) | spv::OpExecutionMode, entry_point_id, spv::ExecutionModeRoundingModeRTE, 16 };
|
||||
spv.insert(spv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode));
|
||||
|
||||
// OpExtension "SPV_KHR_float_controls"
|
||||
const char ext_str[] = "SPV_KHR_float_controls";
|
||||
size_t ext_str_words = CEIL_DIV(sizeof(ext_str), sizeof(uint32_t));
|
||||
std::vector<uint32_t> extension(1 + ext_str_words, 0);
|
||||
extension[0] = (uint32_t)((1 + ext_str_words) << spv::WordCountShift) | spv::OpExtension;
|
||||
memcpy(&extension[1], ext_str, sizeof(ext_str));
|
||||
spv.insert(spv.begin() + ext_insert_pos, extension.begin(), extension.end());
|
||||
|
||||
// OpCapability RoundingModeRTE
|
||||
uint32_t capability[] = { (2u << spv::WordCountShift) | spv::OpCapability, spv::CapabilityRoundingModeRTE };
|
||||
spv.insert(spv.begin() + cap_insert_pos, std::begin(capability), std::end(capability));
|
||||
|
||||
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spv.size() * sizeof(uint32_t), spv.data());
|
||||
}
|
||||
|
||||
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
|
||||
|
||||
vk::PushConstantRange pcr(
|
||||
|
|
@ -4360,10 +4423,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
||||
|
||||
if (device->float_controls_rte_fp16 &&
|
||||
sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
|
||||
if (sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_len, rms_norm_mul_rope_f32_f16_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
|
@ -4388,43 +4450,28 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_rte_len, cpy_f32_q1_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
|
||||
#define SET_ROWS(itype, rte) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## rte ## _len, set_rows_q1_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
#define SET_ROWS(itype) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## _len, set_rows_f32 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## _len, set_rows_f16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## _len, set_rows_bf16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## _len, set_rows_q1_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## _len, set_rows_q4_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## _len, set_rows_q4_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## _len, set_rows_q5_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
|
||||
if (device->float_controls_rte_fp16) {
|
||||
SET_ROWS(_i32, _rte)
|
||||
SET_ROWS(_i64, _rte)
|
||||
} else {
|
||||
SET_ROWS(_i32, )
|
||||
SET_ROWS(_i64, )
|
||||
}
|
||||
SET_ROWS(_i32)
|
||||
SET_ROWS(_i64)
|
||||
#undef SET_ROWS
|
||||
|
||||
|
||||
|
|
@ -4444,11 +4491,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
return s;
|
||||
};
|
||||
|
||||
bool rte = device->float_controls_rte_fp16;
|
||||
#define CREATE_BINARY(name, namemod, spec, bindings) \
|
||||
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
|
||||
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
|
||||
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
|
||||
"main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
|
||||
|
||||
CREATE_BINARY(add, , {0}, 4)
|
||||
|
|
@ -4491,13 +4537,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
|
@ -4538,19 +4579,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
CREATE_UNARY(floor)
|
||||
CREATE_UNARY(trunc)
|
||||
CREATE_UNARY(sgn)
|
||||
CREATE_UNARY(exp)
|
||||
#undef CREATE_UNARY
|
||||
|
||||
#define CREATE_UNARY_RTE(name) \
|
||||
if (device->float_controls_rte_fp16) { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
||||
}
|
||||
CREATE_UNARY_RTE(exp)
|
||||
#undef CREATE_UNARY_RTE
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
|
@ -4560,13 +4591,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
#define CREATE_GLU(name) \
|
||||
if (device->float_controls_rte_fp16) { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
|
||||
|
||||
CREATE_GLU(geglu)
|
||||
CREATE_GLU(reglu)
|
||||
|
|
@ -4599,25 +4625,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||
|
||||
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
|
||||
uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
|
||||
|
|
@ -4679,13 +4694,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
#define IM2COL(bda) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
if (device->float_controls_rte_fp16) { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
||||
if (device->shader_int64 && device->buffer_device_address) {
|
||||
IM2COL(_bda)
|
||||
} else {
|
||||
|
|
@ -14381,8 +14391,7 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
|
|||
}
|
||||
|
||||
// conditions for pipeline creation
|
||||
if (!(ctx->device->float_controls_rte_fp16 &&
|
||||
sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
|
||||
if (sizeof(vk_op_rms_norm_mul_rope_push_constants) > ctx->device->properties.limits.maxPushConstantsSize) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#if defined(SET_ROWS) && QUANT_K == 1
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "utils.glsl"
|
||||
#if RMS_NORM_ROPE_FUSION
|
||||
#include "rope_params.glsl"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#include "rte.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
#include "utils.glsl"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "rope_params.glsl"
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
#if !defined(GGML_ROPE_PARAMS)
|
||||
#define GGML_ROPE_PARAMS
|
||||
|
||||
#include "rte.glsl"
|
||||
|
||||
struct rope_params {
|
||||
uint rope_mode;
|
||||
uint nrows;
|
||||
|
|
|
|||
|
|
@ -1,5 +0,0 @@
|
|||
|
||||
#if RTE16
|
||||
#extension GL_EXT_spirv_intrinsics : enable
|
||||
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
||||
#endif // RTE16
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
|
|
|
|||
|
|
@ -762,7 +762,7 @@ void process_shaders() {
|
|||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
|
||||
string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}}));
|
||||
string_to_spv("rms_norm_mul_rope_f32_f16", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
|
|
@ -786,15 +786,12 @@ void process_shaders() {
|
|||
|
||||
for (std::string t : {"q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
||||
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
}
|
||||
|
||||
for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
||||
string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
}
|
||||
|
||||
auto get_type_str = [](bool f16) {
|
||||
|
|
@ -811,12 +808,10 @@ void process_shaders() {
|
|||
for (auto src0_f16 : {false, true}) {
|
||||
for (auto src1_f16 : {false, true}) {
|
||||
for (auto dst_f16 : {false, true}) {
|
||||
for (auto rte : {false, true}) {
|
||||
auto source = op == "add_rms" ? std::string("add") : op;
|
||||
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
|
||||
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
|
||||
auto add_rms = op == "add_rms" ? "1" : "0";
|
||||
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
|
||||
}
|
||||
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , add_rms}});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -864,14 +859,11 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
for (auto rte : {false, true}) {
|
||||
std::string suffix = rte ? "_rte" : "";
|
||||
string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
}
|
||||
string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
|
|
@ -925,21 +917,18 @@ void process_shaders() {
|
|||
string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
for (auto rte : {false, true}) {
|
||||
std::string suffix = rte ? "_rte" : "";
|
||||
string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
}
|
||||
string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("swiglu_oai_f16", "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("swiglu_oai_f32", "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
|
@ -959,25 +948,18 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
||||
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
|
||||
|
|
@ -1000,7 +982,6 @@ void process_shaders() {
|
|||
std::string bda_def = bda ? "1" : "0";
|
||||
string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
|
||||
string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
|
||||
string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1053,8 +1034,8 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
|
||||
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "0"}});
|
||||
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "1"}});
|
||||
|
||||
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
|
||||
string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
|
||||
|
|
@ -1110,8 +1091,8 @@ void write_output_files() {
|
|||
|
||||
std::string suffixes[2] = {"_f32", "_f16"};
|
||||
for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
|
||||
hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
|
||||
hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
|
||||
hdr << "extern const void * " << op << "_data[2][2][2];\n";
|
||||
hdr << "extern const uint64_t " << op << "_len[2][2][2];\n";
|
||||
|
||||
std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp";
|
||||
// if (basename(input_filepath) != op_file) {
|
||||
|
|
@ -1119,8 +1100,8 @@ void write_output_files() {
|
|||
// }
|
||||
std::stringstream data = make_generic_stringstream();
|
||||
std::stringstream len = make_generic_stringstream();
|
||||
data << "const void * " << op << "_data[2][2][2][2] = ";
|
||||
len << "const uint64_t " << op << "_len[2][2][2][2] = ";
|
||||
data << "const void * " << op << "_data[2][2][2] = ";
|
||||
len << "const uint64_t " << op << "_len[2][2][2] = ";
|
||||
for (uint32_t t0 = 0; t0 < 2; ++t0) {
|
||||
if (t0 == 0) {
|
||||
data << "{";
|
||||
|
|
@ -1136,20 +1117,10 @@ void write_output_files() {
|
|||
data << "{";
|
||||
len << "{";
|
||||
}
|
||||
for (uint32_t rte = 0; rte < 2; ++rte) {
|
||||
if (rte == 0) {
|
||||
data << "{";
|
||||
len << "{";
|
||||
}
|
||||
data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
|
||||
len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
|
||||
data << "_data,";
|
||||
len << "_len,";
|
||||
if (rte == 1) {
|
||||
data << "}, ";
|
||||
len << "}, ";
|
||||
}
|
||||
}
|
||||
data << op << suffixes[t0] << suffixes[t1] << suffixes[t2];
|
||||
len << op << suffixes[t0] << suffixes[t1] << suffixes[t2];
|
||||
data << "_data,";
|
||||
len << "_len,";
|
||||
if (t2 == 1) {
|
||||
data << "}, ";
|
||||
len << "}, ";
|
||||
|
|
|
|||
|
|
@ -18,9 +18,6 @@
|
|||
#include "ggml.h"
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
// TODO: tmp until the ggml meta backend matures and becomes public
|
||||
#include "../src/ggml-ext.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cfloat>
|
||||
|
|
|
|||
|
|
@ -31,9 +31,6 @@ static bool old_mixtral_warning_showed = false;
|
|||
#include "ggml-backend.h"
|
||||
#include "gguf.h"
|
||||
|
||||
// TODO: tmp until the ggml meta backend matures and becomes public
|
||||
#include "../src/ggml-ext.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cinttypes>
|
||||
|
|
|
|||
|
|
@ -114,6 +114,13 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
|
|||
return n_pos;
|
||||
}
|
||||
|
||||
void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * chunks, mtmd_decoder_pos * out_pos) {
|
||||
size_t n_tokens = mtmd_image_tokens_get_n_tokens(chunks);
|
||||
for (size_t i = 0; i < n_tokens; i++) {
|
||||
out_pos[i] = mtmd_image_tokens_get_decoder_pos(chunks, i);
|
||||
}
|
||||
}
|
||||
|
||||
// helper struct to make working with embd batch easier
|
||||
// note: this will be removed after llama_batch_ext refactoring
|
||||
struct decode_embd_batch {
|
||||
|
|
@ -156,18 +163,15 @@ struct decode_embd_batch {
|
|||
}
|
||||
|
||||
// M-RoPE for image
|
||||
void set_position_mrope_2d(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
|
||||
void set_position_mrope_2d(llama_pos pos_0, const std::vector<mtmd_decoder_pos> & rel_pos, llama_seq_id seq_id) {
|
||||
GGML_ASSERT(n_pos_per_embd == 4);
|
||||
GGML_ASSERT(nx > 0 && ny > 0 && nx * ny == batch.n_tokens);
|
||||
GGML_ASSERT(!rel_pos.empty() && (int32_t)rel_pos.size() == batch.n_tokens);
|
||||
seq_id_0[0] = seq_id;
|
||||
for (int y = 0; y < ny; y++) {
|
||||
for (int x = 0; x < nx; x++) {
|
||||
int i = y * nx + x;
|
||||
pos[i ] = pos_0;
|
||||
pos[i + batch.n_tokens ] = pos_0 + y;
|
||||
pos[i + batch.n_tokens * 2] = pos_0 + x;
|
||||
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
|
||||
}
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
pos[i ] = pos_0 + rel_pos[i].t;
|
||||
pos[i + batch.n_tokens ] = pos_0 + rel_pos[i].y;
|
||||
pos[i + batch.n_tokens * 2] = pos_0 + rel_pos[i].x;
|
||||
pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused
|
||||
}
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
batch.n_seq_id[i] = 1;
|
||||
|
|
@ -262,9 +266,10 @@ int32_t mtmd_helper_decode_image_chunk(
|
|||
LOG_ERR("failed to decode chunk: image tokens are null\n");
|
||||
return -1;
|
||||
}
|
||||
const int nx = mtmd_image_tokens_get_nx(image_tokens);
|
||||
const int ny = mtmd_image_tokens_get_ny(image_tokens);
|
||||
batch_embd.set_position_mrope_2d(n_past, nx, ny, seq_id);
|
||||
const auto n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
|
||||
std::vector<mtmd_decoder_pos> rel_pos(n_tokens);
|
||||
mtmd_helper_image_get_decoder_pos(image_tokens, rel_pos.data());
|
||||
batch_embd.set_position_mrope_2d(n_past, rel_pos, seq_id);
|
||||
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
batch_embd.set_position_mrope_1d(n_past, seq_id);
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -47,6 +47,10 @@ MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
|
|||
// normally, n_pos is equal to n_tokens, but for M-RoPE it is different
|
||||
MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks);
|
||||
|
||||
// helper to get the list of relative positions corresponding to the embedding tokens, to be used by M-RoPE
|
||||
// out_pos must have length == mtmd_helper_get_n_tokens(image)
|
||||
MTMD_API void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * image, mtmd_decoder_pos * out_pos);
|
||||
|
||||
// helper function that automatically:
|
||||
// 1. run llama_decode() on text chunks
|
||||
// 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
|
||||
|
|
|
|||
|
|
@ -1249,6 +1249,14 @@ size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
|
|||
return image_tokens->ny;
|
||||
}
|
||||
|
||||
mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i) {
|
||||
mtmd_decoder_pos pos;
|
||||
pos.t = 0;
|
||||
pos.x = i % image_tokens->nx;
|
||||
pos.y = i / image_tokens->nx;
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
|
||||
return image_tokens->id.c_str();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -186,12 +186,25 @@ MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk);
|
|||
// the instance will be constructed via mtmd_tokenize()
|
||||
// it will be freed along with mtmd_input_chunk
|
||||
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); // TODO: deprecate
|
||||
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
|
||||
// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
|
||||
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate
|
||||
|
||||
DEPRECATED(MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens),
|
||||
"use mtmd_image_tokens_get_decoder_pos() instead");
|
||||
DEPRECATED(MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens),
|
||||
"use mtmd_image_tokens_get_decoder_pos() instead");
|
||||
|
||||
struct mtmd_decoder_pos {
|
||||
uint32_t t;
|
||||
uint32_t x;
|
||||
uint32_t y;
|
||||
};
|
||||
// get position for decoder attention, to be used by M-RoPE models
|
||||
// i is the index of the embedding token, ranging from 0 to mtmd_image_tokens_get_n_tokens() - 1
|
||||
// return relative position (for example, embedding 0 will have position (0, 0, 0); remember to adjust it to the current absolute position)
|
||||
MTMD_API struct mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i);
|
||||
|
||||
// tokenize an input text prompt and a list of bitmaps (images/audio)
|
||||
// the prompt must have the input image marker (default: "<__media__>") in it
|
||||
// the default marker is defined by mtmd_default_marker()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue