Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 17:44:38 +00:00)

commit fb13e3e51b
Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	src/llama-context.cpp
#	tests/test-backend-ops.cpp

26 changed files with 543 additions and 447 deletions
@@ -41,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items
     return result;
 }
 
-/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
-class string_view {
-    const std::string & _str;
-    const size_t _start;
-    const size_t _end;
-public:
-    string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
-
-    size_t size() const {
-        return _end - _start;
-    }
-
-    size_t length() const {
-        return size();
-    }
-
-    operator std::string() const {
-        return str();
-    }
-
-    std::string str() const {
-        return _str.substr(_start, _end - _start);
-    }
-
-    string_view substr(size_t pos, size_t len = std::string::npos) const {
-        return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
-    }
-
-    char operator[](size_t pos) const {
-        auto index = _start + pos;
-        if (index >= _end) {
-            throw std::out_of_range("string_view index out of range");
-        }
-        return _str[_start + pos];
-    }
-
-    bool operator==(const string_view & other) const {
-        std::string this_str = *this;
-        std::string other_str = other;
-        return this_str == other_str;
-    }
-};
-
 static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
     auto has_min = min_value != std::numeric_limits<int>::min();
     auto has_max = max_value != std::numeric_limits<int>::max();
@@ -112,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
         }
         out << "}";
     };
-    std::function<void(const string_view &, const string_view &)> uniform_range =
-        [&](const string_view & from, const string_view & to) {
+    std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
+        [&](const std::string_view & from, const std::string_view & to) {
             size_t i = 0;
             while (i < from.length() && i < to.length() && from[i] == to[i]) {
                 i++;
             }
             if (i > 0) {
-                out << "\"" << from.substr(0, i).str() << "\"";
+                out << "\"" << from.substr(0, i) << "\"";
             }
             if (i < from.length() && i < to.length()) {
                 if (i > 0) {
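Note on the removal above: the hand-rolled string_view shim existed only because this code previously could not assume C++17; the replacement relies on std::string_view, which covers the same operations. A minimal illustrative sketch of the substitution (not part of the commit):

#include <iostream>
#include <string>
#include <string_view>

int main() {
    std::string s = "abcdef";
    std::string_view v(s);                    // non-owning view: pointer + length, no copy
    std::string_view sub = v.substr(1, 3);    // "bcd" -- same use the removed shim's substr() served
    std::cout << sub << " size=" << sub.size() << " first=" << sub[0] << "\n";
    // One behavioural difference to keep in mind: std::string_view::operator[] does no
    // bounds checking, whereas the removed shim threw std::out_of_range.
    return 0;
}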
@@ -2193,7 +2193,7 @@ class Llama4VisionModel(MmprojModel):
             name += ".weight"
         if "multi_modal_projector.linear_1" in name:
             # despite the name with number postfix, this is a single fully connected layer
-            return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC], data_torch)]
+            return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC] + '.weight', data_torch)]
         return [(self.map_tensor_name(name), data_torch)]
 
     return []
@@ -245,8 +245,18 @@ static bool fp16_mma_available(const int cc) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
     return false;
 #else
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return true;
+    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+    } else {
+        return false;
+    }
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
@@ -366,6 +376,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template<bool norm>
+static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
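The kernel added above folds the former sum-rows kernel and the new mean into a single template: one warp strides across a row, reduces, and lane 0 writes either the sum or the sum divided by ncols. A CPU reference sketch of the same computation, for illustration only (the helper below is not part of the commit):

#include <cstdio>
#include <vector>

// Reference semantics of reduce_rows_f32<norm>: per-row sum, divided by ncols when norm == true.
template <bool norm>
void reduce_rows_ref(const float * x, float * dst, int ncols, int nrows) {
    for (int r = 0; r < nrows; ++r) {
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) {
            sum += x[r * ncols + c];
        }
        dst[r] = norm ? sum / ncols : sum;   // mean vs. plain sum
    }
}

int main() {
    std::vector<float> x = {1, 2, 3, 4, 5, 6};   // 2 rows x 3 cols
    std::vector<float> sums(2), means(2);
    reduce_rows_ref<false>(x.data(), sums.data(),  3, 2);   // -> 6, 15
    reduce_rows_ref<true >(x.data(), means.data(), 3, 2);   // -> 2, 5
    std::printf("sums: %g %g  means: %g %g\n", sums[0], sums[1], means[0], means[1]);
    return 0;
}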
@@ -39,6 +39,7 @@ bool g_mul_mat_q = true;
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
@@ -101,8 +102,7 @@ int ggml_cuda_get_device() {
 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
     ggml_cuda_set_device(device);
     cudaError_t err;
-    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
-    {
+    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
         err = cudaMallocManaged(ptr, size);
 #if defined(GGML_USE_HIP)
         if (err == hipSuccess) {
@@ -120,9 +120,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
             err = cudaMalloc(ptr, size);
         }
 #endif // defined(GGML_USE_HIP)
-    }
-    else
-    {
+    } else {
         err = cudaMalloc(ptr, size);
     }
     return err;
@@ -2362,6 +2360,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_MEAN:
+            ggml_cuda_op_mean(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_cuda_op_ssm_conv(ctx, dst);
             break;
@@ -3265,6 +3266,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
             return true;

ggml/src/ggml-cuda/mean.cu (new file, 19 lines)
@@ -0,0 +1,19 @@
+#include "mean.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+}

ggml/src/ggml-cuda/mean.cuh (new file, 3 lines)
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

@@ -1,25 +1,9 @@
 #include "sumrows.cuh"
 
-static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col == 0) {
-        dst[row] = sum;
-    }
-}
-
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
-    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+
+    reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
 }

@@ -1,5 +1,4 @@
 #include "common.cuh"
 
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
-
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
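Wiring-wise, the new operator touches three existing places in the CUDA backend above -- the include list, the forward-dispatch switch, and supports_op -- plus the small mean.cu/mean.cuh pair. A stripped-down sketch of that dispatch pattern (all names below are hypothetical stand-ins, not the real ggml symbols):

#include <cstdio>

enum op_kind { OP_SUM_ROWS, OP_MEAN, OP_OTHER };   // stand-in for the ggml op enum

static void op_sum_rows() { std::puts("sum_rows"); }
static void op_mean()     { std::puts("mean"); }   // new handler, analogous to ggml_cuda_op_mean

// analogous to the supports_op hunk: advertise the new op
static bool supports_op(op_kind op) {
    switch (op) {
        case OP_SUM_ROWS:
        case OP_MEAN:
            return true;
        default:
            return false;
    }
}

// analogous to the compute_forward hunk: route the op to its handler
static bool compute_forward(op_kind op) {
    switch (op) {
        case OP_SUM_ROWS: op_sum_rows(); break;
        case OP_MEAN:     op_mean();     break;
        default:          return false;
    }
    return true;
}

int main() {
    if (supports_op(OP_MEAN)) {
        compute_forward(OP_MEAN);
    }
    return 0;
}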
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
     int mtl_device_ref_count;
     id<MTLLibrary> mtl_library;
 
+    NSLock * mtl_lock;
+
     bool has_simdgroup_reduction;
     bool has_simdgroup_mm;
     bool has_residency_sets;
     bool has_bfloat;
     bool use_bfloat;
 
+    size_t max_size;
+
     char name[128];
 } g_ggml_ctx_dev_main = {
     /*.mtl_device =*/ nil,
     /*.mtl_device_ref_count =*/ 0,
     /*.mtl_library =*/ nil,
+    /*.mtl_lock =*/ nil,
     /*.has_simdgroup_reduction =*/ false,
     /*.has_simdgroup_mm =*/ false,
     /*.has_residency_sets =*/ false,
     /*.has_bfloat =*/ false,
     /*.use_bfloat =*/ false,
+    /*.max_size =*/ 0,
     /*.name =*/ "",
 };
 
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
 static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
     assert(ctx != NULL);
 
+    if (ctx->mtl_lock == nil) {
+        ctx->mtl_lock = [[NSLock alloc] init];
+    }
+
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
     }
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
         ctx->use_bfloat = false;
 #endif
 
+        ctx->max_size = ctx->mtl_device.maxBufferLength;
+
         strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
     }
 
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
+        if (ctx->mtl_lock) {
+            [ctx->mtl_lock release];
+            ctx->mtl_lock = nil;
+        }
+
         if (ctx->mtl_library) {
             [ctx->mtl_library release];
             ctx->mtl_library = nil;
@@ -977,7 +994,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
     struct ggml_backend_metal_device_context * ctx_dev = dev->context;
 
-    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
@@ -991,9 +1008,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
     // load library
-    if (ctx_dev->mtl_library == nil) {
-        ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
+    {
+        [ctx_dev->mtl_lock lock];
+
+        if (ctx_dev->mtl_library == nil) {
+            ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
+        }
+
+        [ctx_dev->mtl_lock unlock];
     }
 
     id<MTLLibrary> metal_library = ctx_dev->mtl_library;
     if (metal_library == nil) {
         GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
@@ -5284,7 +5308,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
     }
 
     ggml_backend_metal_buffer_rset_free(ctx);
-    ggml_backend_metal_device_rel(buffer->buft->device->context);
 
     if (ctx->owned) {
 #if TARGET_OS_OSX
@@ -5393,7 +5416,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
     }
 
     struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
-    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+    GGML_ASSERT(ctx_dev->mtl_device != nil);
+
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     ctx->all_data = ggml_metal_host_malloc(size_aligned);
     ctx->all_size = size_aligned;
@@ -5416,14 +5442,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
     if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
         GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
-        ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
     if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
         GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
         free(ctx);
-        ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
@@ -5434,17 +5458,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
 
 static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 32;
 
     GGML_UNUSED(buft);
 }
 
 static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
-    const size_t max_size = device.maxBufferLength;
-    ggml_backend_metal_device_rel(buft->device->context);
+    const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
 
     return max_size;
-
-    GGML_UNUSED(buft);
 }
 
 static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
@@ -5517,7 +5538,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
     }
 
     struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
-    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+    GGML_ASSERT(ctx_dev->mtl_device != nil);
+
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
@@ -5573,7 +5597,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
     if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
         GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
         free(ctx);
-        ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
@@ -5589,10 +5612,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
 }
 
 static void ggml_backend_metal_free(ggml_backend_t backend) {
     struct ggml_backend_metal_context * ctx = backend->context;
-    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
-    ggml_backend_metal_device_rel(ctx_dev);
     ggml_metal_free(ctx);
 
     free(backend);
@@ -5732,6 +5753,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
 
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
+    GGML_ASSERT(ctx_dev->mtl_device != nil);
+
     return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }
 
@@ -5751,10 +5774,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
 }
 
 static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
-    // acq/rel just to populate ctx->name in case it hasn't been done yet
     struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
-    ggml_backend_metal_device_acq(ctx_dev);
-    ggml_backend_metal_device_rel(ctx_dev);
 
     return ctx_dev->name;
 }
@@ -5762,12 +5782,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
 static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
     if (@available(macOS 10.12, iOS 16.0, *)) {
         struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
-        id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+        id<MTLDevice> device = ctx_dev->mtl_device;
 
         *total = device.recommendedMaxWorkingSetSize;
         *free = *total - device.currentAllocatedSize;
-
-        ggml_backend_metal_device_rel(ctx_dev);
     } else {
         *free = 1;
         *total = 1;
@@ -5845,7 +5863,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
     }
 
     struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
-    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+    GGML_ASSERT(ctx_dev->mtl_device != nil);
+
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
@@ -5901,7 +5922,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
     if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
         GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
         free(ctx);
-        ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
@@ -5915,8 +5935,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
 }
 
 static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
-            buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
+    return
+        buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
+        buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
 
     GGML_UNUSED(dev);
 }
@@ -6001,8 +6022,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
     /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
 };
 
+// called upon program exit
+static void ggml_metal_cleanup(void) {
+    ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
+}
+
+// TODO: make thread-safe
 ggml_backend_reg_t ggml_backend_metal_reg(void) {
-    // TODO: make this thread-safe somehow?
+    ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
+
+    // register cleanup callback
+    // TODO: not ideal, but not sure if there is a better way to do this in Objective-C
+    atexit(ggml_metal_cleanup);
+
     {
         g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
             /* .api_version = */ GGML_BACKEND_API_VERSION,
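The Metal hunks above change device lifetime management: the device is acquired once when the backend registry is created and released through an atexit() hook, the one-time library load is guarded by an NSLock, and per-buffer acquire/release calls are replaced by an assertion that the device already exists. An analogous sketch in portable C++ (std::mutex standing in for NSLock; illustrative only, not the Objective-C code from the diff):

#include <cstdio>
#include <cstdlib>
#include <mutex>

static std::mutex g_lock;               // plays the role of ctx->mtl_lock
static bool g_library_loaded = false;   // stands in for ctx->mtl_library

static void load_library_once() {
    std::lock_guard<std::mutex> guard(g_lock);
    if (!g_library_loaded) {            // load exactly once, even with concurrent callers
        g_library_loaded = true;
        std::puts("library loaded");
    }
}

static void cleanup() {                 // runs at program exit, like ggml_metal_cleanup
    std::puts("device released");
}

int main() {
    std::atexit(cleanup);               // mirrors atexit(ggml_metal_cleanup) at registration time
    load_library_once();
    load_library_once();                // second call is a no-op
    return 0;
}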
|
|
@ -1057,6 +1057,14 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
|
||||||
struct vk_instance_t {
|
struct vk_instance_t {
|
||||||
vk::Instance instance;
|
vk::Instance instance;
|
||||||
|
|
||||||
|
bool debug_utils_support = false; // VK_EXT_debug_utils enabled
|
||||||
|
PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
|
||||||
|
PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
|
||||||
|
PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {};
|
||||||
|
PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {};
|
||||||
|
PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
|
||||||
|
PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
|
||||||
|
|
||||||
std::vector<size_t> device_indices;
|
std::vector<size_t> device_indices;
|
||||||
vk_device devices[GGML_VK_MAX_DEVICES];
|
vk_device devices[GGML_VK_MAX_DEVICES];
|
||||||
};
|
};
|
||||||
|
@ -1196,6 +1204,14 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
||||||
}
|
}
|
||||||
pipeline->compiled = true;
|
pipeline->compiled = true;
|
||||||
|
|
||||||
|
if (vk_instance.debug_utils_support) {
|
||||||
|
vk::DebugUtilsObjectNameInfoEXT duoni;
|
||||||
|
duoni.objectType = vk::ObjectType::ePipeline;
|
||||||
|
duoni.pObjectName = pipeline->name.c_str();
|
||||||
|
duoni.objectHandle = reinterpret_cast<uint64_t>(static_cast<VkPipeline_T*>(pipeline->pipeline));
|
||||||
|
vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> guard(device->mutex);
|
std::lock_guard<std::mutex> guard(device->mutex);
|
||||||
device->pipelines.insert({ pipeline->name, pipeline });
|
device->pipelines.insert({ pipeline->name, pipeline });
|
||||||
|
@ -3585,6 +3601,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||||
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
||||||
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
||||||
|
|
||||||
|
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
|
||||||
|
|
||||||
static void ggml_vk_instance_init() {
|
static void ggml_vk_instance_init() {
|
||||||
if (vk_instance_initialized) {
|
if (vk_instance_initialized) {
|
||||||
return;
|
return;
|
||||||
|
@ -3605,7 +3623,7 @@ static void ggml_vk_instance_init() {
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
|
const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
|
||||||
#endif
|
#endif
|
||||||
|
const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
|
||||||
std::vector<const char*> layers;
|
std::vector<const char*> layers;
|
||||||
|
|
||||||
if (validation_ext) {
|
if (validation_ext) {
|
||||||
|
@ -3620,6 +3638,9 @@ static void ggml_vk_instance_init() {
|
||||||
extensions.push_back("VK_KHR_portability_enumeration");
|
extensions.push_back("VK_KHR_portability_enumeration");
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
if (debug_utils_ext) {
|
||||||
|
extensions.push_back("VK_EXT_debug_utils");
|
||||||
|
}
|
||||||
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
|
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
if (portability_enumeration_ext) {
|
if (portability_enumeration_ext) {
|
||||||
|
@ -3643,6 +3664,18 @@ static void ggml_vk_instance_init() {
|
||||||
vk_instance.instance = vk::createInstance(instance_create_info);
|
vk_instance.instance = vk::createInstance(instance_create_info);
|
||||||
vk_instance_initialized = true;
|
vk_instance_initialized = true;
|
||||||
|
|
||||||
|
if (debug_utils_ext) {
|
||||||
|
vk_instance.debug_utils_support = true;
|
||||||
|
vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT");
|
||||||
|
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT");
|
||||||
|
vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT");
|
||||||
|
vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
|
||||||
|
vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
|
||||||
|
vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
|
||||||
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
|
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
|
||||||
|
|
||||||
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
|
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
|
||||||
|
@ -9680,6 +9713,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||||
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
|
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
|
||||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||||
|
|
||||||
|
if (vk_instance.debug_utils_support) {
|
||||||
|
vk::DebugUtilsLabelEXT dul = {};
|
||||||
|
dul.pLabelName = "ggml_backend_vk_graph_compute";
|
||||||
|
dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f};
|
||||||
|
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
|
||||||
|
}
|
||||||
|
|
||||||
uint64_t total_mat_mul_bytes = 0;
|
uint64_t total_mat_mul_bytes = 0;
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
|
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
|
||||||
|
@ -10369,6 +10409,22 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
|
||||||
UNUSED(instance_extensions);
|
UNUSED(instance_extensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extension availability
|
||||||
|
static bool ggml_vk_instance_debug_utils_ext_available(
|
||||||
|
const std::vector<vk::ExtensionProperties> & instance_extensions) {
|
||||||
|
// Check for portability enumeration extension for MoltenVK support
|
||||||
|
for (const auto & properties : instance_extensions) {
|
||||||
|
if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl;
|
||||||
|
return false;
|
||||||
|
|
||||||
|
UNUSED(instance_extensions);
|
||||||
|
}
|
||||||
|
|
||||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
|
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
|
||||||
switch (props.vendorID) {
|
switch (props.vendorID) {
|
||||||
case VK_VENDOR_ID_INTEL:
|
case VK_VENDOR_ID_INTEL:
|
||||||
|
|
|
@ -197,13 +197,23 @@ class SpecialVocab:
|
||||||
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
|
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
|
||||||
if not tokenizer_config:
|
if not tokenizer_config:
|
||||||
special_eos = special_last
|
special_eos = special_last
|
||||||
|
elif special_last != special_eos:
|
||||||
|
if 'eot' not in self.special_token_types:
|
||||||
|
self.special_token_types = tuple(self.special_token_types) + ('eot', )
|
||||||
|
tokenizer_config['eot_token'] = special_eos
|
||||||
|
elif 'eom' not in self.special_token_types:
|
||||||
|
self.special_token_types = tuple(self.special_token_types) + ('eom', )
|
||||||
|
tokenizer_config['eom_token'] = special_eos
|
||||||
|
else:
|
||||||
|
logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
|
||||||
|
tokenizer_config['eos_token'] = special_eos = special_last
|
||||||
self.add_special_token['eos'] = True if special_last == special_eos else False
|
self.add_special_token['eos'] = True if special_last == special_eos else False
|
||||||
if special_last != special_eos:
|
if special_last != special_eos:
|
||||||
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
|
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
|
||||||
if tmpl_pair:
|
if tmpl_pair:
|
||||||
seq_start = 1 if tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
|
seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
|
||||||
seq_stop = -1 if tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
|
seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
|
||||||
if seq_start == 0 or seq_stop is None:
|
if (special_first and seq_start == 0) or (special_last and seq_stop is None):
|
||||||
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
|
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
|
||||||
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
|
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
|
||||||
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
|
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
|
||||||
|
|
|
@ -280,8 +280,8 @@ llama_context::llama_context(
|
||||||
|
|
||||||
// simulate full KV cache
|
// simulate full KV cache
|
||||||
|
|
||||||
const auto mstate = memory->init_full();
|
const auto mctx = memory->init_full();
|
||||||
if (!mstate) {
|
if (!mctx) {
|
||||||
throw std::runtime_error("failed to initialize KV cache");
|
throw std::runtime_error("failed to initialize KV cache");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -289,7 +289,7 @@ llama_context::llama_context(
|
||||||
|
|
||||||
// reserve pp graph first so that buffers are only allocated once
|
// reserve pp graph first so that buffers are only allocated once
|
||||||
{
|
{
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||||
}
|
}
|
||||||
|
@ -300,7 +300,7 @@ llama_context::llama_context(
|
||||||
|
|
||||||
// reserve with tg graph to get the number of splits and nodes
|
// reserve with tg graph to get the number of splits and nodes
|
||||||
{
|
{
|
||||||
auto * gf = graph_reserve(1, 1, 1, mstate.get());
|
auto * gf = graph_reserve(1, 1, 1, mctx.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||||
}
|
}
|
||||||
|
@ -311,7 +311,7 @@ llama_context::llama_context(
|
||||||
|
|
||||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||||
{
|
{
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||||
}
|
}
|
||||||
|
@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
|
||||||
optimize |= memory_force_optimize;
|
optimize |= memory_force_optimize;
|
||||||
memory_force_optimize = false;
|
memory_force_optimize = false;
|
||||||
|
|
||||||
const auto mstate = memory->init_update(this, optimize);
|
const auto mctx = memory->init_update(this, optimize);
|
||||||
switch (mstate->get_status()) {
|
switch (mctx->get_status()) {
|
||||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||||
{
|
{
|
||||||
// noop
|
// noop
|
||||||
|
@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!mstate->apply()) {
|
if (!mctx->apply()) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the memory module did any computation, we have to reserve a new worst-case graph
|
// if the memory module did any computation, we have to reserve a new worst-case graph
|
||||||
{
|
{
|
||||||
const auto mstate = memory->init_full();
|
const auto mctx = memory->init_full();
|
||||||
if (!mstate) {
|
if (!mctx) {
|
||||||
throw std::runtime_error("failed to initialize memory state");
|
throw std::runtime_error("failed to initialize memory context");
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint32_t n_seqs = cparams.n_seq_max;
|
const uint32_t n_seqs = cparams.n_seq_max;
|
||||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||||
|
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
||||||
}
|
}
|
||||||
|
@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
|
||||||
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
|
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
||||||
if (mstate && !mstate->apply()) {
|
if (mctx && !mctx->apply()) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
|
||||||
ret = GGML_STATUS_FAILED;
|
ret = GGML_STATUS_FAILED;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
|
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
|
||||||
if (!res) {
|
if (!res) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
|
||||||
ret = GGML_STATUS_FAILED;
|
ret = GGML_STATUS_FAILED;
|
||||||
|
@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||||
// handle any pending defrags/shifts
|
// handle any pending defrags/shifts
|
||||||
kv_self_update(false);
|
kv_self_update(false);
|
||||||
|
|
||||||
llama_memory_state_ptr mstate;
|
llama_memory_context_ptr mctx;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
|
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
|
||||||
if (!mstate) {
|
if (!mctx) {
|
||||||
return -2;
|
return -2;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (mstate->get_status()) {
|
switch (mctx->get_status()) {
|
||||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||||
{
|
{
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||||
{
|
{
|
||||||
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
|
LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
|
||||||
|
|
||||||
return -2;
|
return -2;
|
||||||
}
|
}
|
||||||
|
@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||||
int64_t n_outputs_prev = 0;
|
int64_t n_outputs_prev = 0;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
const auto & ubatch = mstate->get_ubatch();
|
const auto & ubatch = mctx->get_ubatch();
|
||||||
|
|
||||||
// count the outputs in this ubatch
|
// count the outputs in this ubatch
|
||||||
{
|
{
|
||||||
|
@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||||
|
|
||||||
ggml_status status;
|
ggml_status status;
|
||||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
|
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
|
||||||
|
|
||||||
if (!res) {
|
if (!res) {
|
||||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
||||||
|
@ -1126,7 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||||
}
|
}
|
||||||
|
|
||||||
n_outputs_prev += n_outputs;
|
n_outputs_prev += n_outputs;
|
||||||
} while (mstate->next());
|
} while (mctx->next());
|
||||||
|
|
||||||
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
||||||
n_outputs = n_outputs_all;
|
n_outputs = n_outputs_all;
|
||||||
|
@ -1292,8 +1292,8 @@ ggml_cgraph * llama_context::graph_init() {
|
||||||
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
|
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
|
||||||
//LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
// LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||||
|
|
||||||
if (n_tokens % n_seqs != 0) {
|
if (n_tokens % n_seqs != 0) {
|
||||||
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
|
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
|
||||||
|
@ -1312,7 +1312,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
||||||
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
||||||
|
|
||||||
auto * gf = graph_init();
|
auto * gf = graph_init();
|
||||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
|
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
|
||||||
|
|
||||||
this->n_outputs = save_n_outputs;
|
this->n_outputs = save_n_outputs;
|
||||||
|
|
||||||
|
@ -1333,11 +1333,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_graph_result_ptr llama_context::graph_build(
|
llm_graph_result_ptr llama_context::graph_build(
|
||||||
ggml_context * ctx,
|
ggml_context * ctx,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
llm_graph_type gtype,
|
llm_graph_type gtype,
|
||||||
const llama_memory_state_i * mstate) {
|
const llama_memory_context_i * mctx) {
|
||||||
return model.build_graph(
|
return model.build_graph(
|
||||||
{
|
{
|
||||||
/*.ctx =*/ ctx,
|
/*.ctx =*/ ctx,
|
||||||
|
@ -1349,7 +1349,7 @@ llm_graph_result_ptr llama_context::graph_build(
|
||||||
/*.backend_cpu =*/ backend_cpu,
|
/*.backend_cpu =*/ backend_cpu,
|
||||||
/*.cvec =*/ &cvec,
|
/*.cvec =*/ &cvec,
|
||||||
/*.loras =*/ &loras,
|
/*.loras =*/ &loras,
|
||||||
/*.mstate =*/ mstate,
|
/*.mctx =*/ mctx,
|
||||||
/*.cross =*/ &cross,
|
/*.cross =*/ &cross,
|
||||||
/*.n_outputs =*/ n_outputs,
|
/*.n_outputs =*/ n_outputs,
|
||||||
/*.cb =*/ graph_get_cb(),
|
/*.cb =*/ graph_get_cb(),
|
||||||
|
@ -2042,8 +2042,8 @@ void llama_context::opt_epoch_iter(
|
||||||
|
|
||||||
uint32_t n_outputs_all = n_tokens_all;
|
uint32_t n_outputs_all = n_tokens_all;
|
||||||
|
|
||||||
auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
|
auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
|
||||||
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||||
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -2056,17 +2056,17 @@ void llama_context::opt_epoch_iter(
|
||||||
|
|
||||||
uint32_t pos_batch = 0;
|
uint32_t pos_batch = 0;
|
||||||
do {
|
do {
|
||||||
const auto & ubatch = mstate->get_ubatch();
|
const auto & ubatch = mctx->get_ubatch();
|
||||||
|
|
||||||
n_outputs = ubatch.n_tokens;
|
n_outputs = ubatch.n_tokens;
|
||||||
|
|
||||||
if (!mstate->apply()) {
|
if (!mctx->apply()) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto * gf = graph_init();
|
auto * gf = graph_init();
|
||||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
|
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
|
||||||
|
|
||||||
struct ggml_context * ctx_compute_opt;
|
struct ggml_context * ctx_compute_opt;
|
||||||
{
|
{
|
||||||
|
@ -2101,7 +2101,7 @@ void llama_context::opt_epoch_iter(
|
||||||
ggml_free(ctx_compute_opt);
|
ggml_free(ctx_compute_opt);
|
||||||
|
|
||||||
pos_batch += ubatch.n_tokens;
|
pos_batch += ubatch.n_tokens;
|
||||||
} while (mstate->next());
|
} while (mctx->next());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ class llama_io_read_i;
|
||||||
class llama_io_write_i;
|
class llama_io_write_i;
|
||||||
|
|
||||||
struct llama_memory_i;
|
struct llama_memory_i;
|
||||||
struct llama_memory_state_i;
|
struct llama_memory_context_i;
|
||||||
|
|
||||||
struct llama_context {
|
struct llama_context {
|
||||||
// init scheduler and compute buffers, reserve worst-case graphs
|
// init scheduler and compute buffers, reserve worst-case graphs
|
||||||
|
@ -93,14 +93,14 @@ struct llama_context {
|
||||||
int32_t il_end);
|
int32_t il_end);
|
||||||
|
|
||||||
// process a single ubatch with a specific graph type
|
// process a single ubatch with a specific graph type
|
||||||
// if memory_state is provided, it will be applied first to the context's memory
|
// if memory_context is provided, it will be applied first to the context's memory
|
||||||
// ret contains the status of the graph computation
|
// ret contains the status of the graph computation
|
||||||
// returns nullptr only if ret != GGML_STATUS_SUCCESS
|
// returns nullptr only if ret != GGML_STATUS_SUCCESS
|
||||||
llm_graph_result_ptr process_ubatch(
|
llm_graph_result_ptr process_ubatch(
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
llm_graph_type gtype,
|
llm_graph_type gtype,
|
||||||
llama_memory_state_i * mstate,
|
llama_memory_context_i * mctx,
|
||||||
ggml_status & ret);
|
ggml_status & ret);
|
||||||
|
|
||||||
int encode(const llama_batch & batch_inp);
|
int encode(const llama_batch & batch_inp);
|
||||||
int decode(const llama_batch & batch_inp);
|
int decode(const llama_batch & batch_inp);
|
||||||
|
@ -197,15 +197,15 @@ public:
|
||||||
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
||||||
|
|
||||||
// reserve a graph with a dummy ubatch of the specified size
|
// reserve a graph with a dummy ubatch of the specified size
|
||||||
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
|
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llm_graph_result_ptr graph_build(
|
llm_graph_result_ptr graph_build(
|
||||||
ggml_context * ctx,
|
ggml_context * ctx,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
llm_graph_type gtype,
|
llm_graph_type gtype,
|
||||||
const llama_memory_state_i * mstate);
|
const llama_memory_context_i * mctx);
|
||||||
|
|
||||||
llm_graph_cb graph_get_cb() const;
|
llm_graph_cb graph_get_cb() const;
|
||||||
|
|
||||||
|
|
@ -87,7 +87,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {

void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
if (pos_bucket) {
- kv_state->set_input_pos_bucket(pos_bucket, ubatch);
+ mctx->set_input_pos_bucket(pos_bucket, ubatch);
}
}

@ -221,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);

- const int64_t n_rs = mem_state->get_n_rs();
+ const int64_t n_rs = mctx->get_n_rs();

if (s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@ -229,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {

// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) {
- data[i] = mem_state->s_copy(i);
+ data[i] = mctx->s_copy(i);
}
}
}

@ -282,17 +282,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {

void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
- kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
}

void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
- kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

if (self_kq_mask_swa) {
- kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
}

@ -334,10 +334,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {

void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
- mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

- const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();

if (s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@ -345,7 +345,7 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {

// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) {
- data[i] = mem_state->get_state_recr()->s_copy(i);
+ data[i] = mctx->get_recr()->s_copy(i);
}
}
}
@ -389,7 +389,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
backend_cpu (params.backend_cpu),
cvec (params.cvec),
loras (params.loras),
- mstate (params.mstate),
+ mctx (params.mctx),
cross (params.cross),
cb_func (params.cb),
res (std::make_unique<llm_graph_result>()) {
@ -950,11 +950,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
}

ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);

- auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);

- const auto n_kv = kv_state->get_n_kv();
+ const auto n_kv = mctx_cur->get_n_kv();

auto & cur = inp->pos_bucket;

@ -982,14 +982,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
}

llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
- const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);

- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);

{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");

- const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
+ const auto n_kv = inp->mctx->get_attn()->get_n_kv();

inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
@ -999,7 +999,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
}

{
- const auto n_rs = mem_state->get_state_recr()->get_n_rs();
+ const auto n_rs = mctx_cur->get_recr()->get_n_rs();

inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
ggml_set_input(inp->s_copy);
@ -1183,14 +1183,14 @@ ggml_tensor * llm_graph_context::build_attn(
}

llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);

- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);

{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

- const auto n_kv = kv_state->get_n_kv();
+ const auto n_kv = mctx_cur->get_n_kv();

inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
@ -1220,19 +1220,19 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);

- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);

// store to KV cache
{
- ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
- ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
}

const auto & kq_mask = inp->get_kq_mask();

ggml_tensor * q = q_cur;
- ggml_tensor * k = kv_state->get_k(ctx0, il);
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
- ggml_tensor * v = kv_state->get_v(ctx0, il);
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
@ -1270,23 +1270,23 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);

- const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+ const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);

const bool is_swa = hparams.is_swa(il);

- const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
+ const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();

// store to KV cache
{
- ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
- ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
}

const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

ggml_tensor * q = q_cur;
- ggml_tensor * k = kv_state->get_k(ctx0, il);
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
- ggml_tensor * v = kv_state->get_v(ctx0, il);
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
@ -1379,19 +1379,19 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);

- const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();

// store to KV cache
{
- ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
- ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
}

const auto & kq_mask = inp->get_kq_mask();

ggml_tensor * q = q_cur;
- ggml_tensor * k = kv_state->get_k(ctx0, il);
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
- ggml_tensor * v = kv_state->get_v(ctx0, il);
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
@ -1412,12 +1412,12 @@ ggml_tensor * llm_graph_context::build_attn(
}

llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
- const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);

- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);

{
- const auto n_kv = kv_state->get_base()->get_n_kv();
+ const auto n_kv = mctx_cur->get_base()->get_n_kv();

inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
@ -1429,7 +1429,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");

- const auto n_kv = kv_state->get_swa()->get_n_kv();
+ const auto n_kv = mctx_cur->get_swa()->get_n_kv();

inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@ -1485,11 +1485,11 @@ ggml_tensor * llm_graph_context::build_rs(
}

llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

- auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+ auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);

- const auto n_rs = kv_state->get_n_rs();
+ const auto n_rs = mctx_cur->get_n_rs();

inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
ggml_set_input(inp->s_copy);
@ -1504,9 +1504,9 @@ ggml_tensor * llm_graph_context::build_rs(
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
}

ggml_tensor * llm_graph_context::build_rs(
@ -1516,9 +1516,9 @@ ggml_tensor * llm_graph_context::build_rs(
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
- const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();

- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
}

ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@ -1526,13 +1526,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const {
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

const auto token_shift_count = hparams.token_shift_count;

const int64_t n_seqs = ubatch.n_seqs;

- ggml_tensor * token_shift_all = kv_state->get_r_l(il);
+ ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);

ggml_tensor * token_shift = build_rs(
inp, gf, token_shift_all,
@ -1547,19 +1547,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const {
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd;

const int64_t n_seqs = ubatch.n_seqs;

- const auto kv_head = kv_state->get_head();
+ const auto kv_head = mctx_cur->get_head();

return ggml_cpy(
ctx0,
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
- ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
+ ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
);
}

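Throughout the graph-building hunks above, the builder only stores the abstract llama_memory_context_i pointer (mctx), and each build_* helper static_casts it to the concrete context type it expects. Below is a minimal stand-alone sketch of that downcast pattern, with hypothetical stand-in classes instead of the real ones:

    #include <cstdio>

    // Simplified stand-ins -- not the real llama.cpp classes.
    struct memory_context_i {
        virtual ~memory_context_i() = default;
    };

    struct kv_cache_context : memory_context_i {
        unsigned get_n_kv() const { return 256; }   // placeholder value
    };

    struct recurrent_context : memory_context_i {
        unsigned get_n_rs() const { return 4; }     // placeholder value
    };

    // Mirrors the pattern above: the builder keeps only the abstract pointer and
    // each helper static_casts it to the concrete context type it needs.
    struct graph_builder {
        const memory_context_i * mctx;

        unsigned build_attn_inp() const {
            const auto * mctx_cur = static_cast<const kv_cache_context *>(mctx);
            return mctx_cur->get_n_kv();
        }
    };

    int main() {
        kv_cache_context kv_ctx;
        graph_builder gb{&kv_ctx};
        std::printf("n_kv = %u\n", gb.build_attn_inp());
        return 0;
    }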
@ -17,12 +17,12 @@ struct ggml_tensor;
struct llama_ubatch;
struct llama_cparams;

- struct llama_memory_state_i;
+ struct llama_memory_context_i;

- class llama_kv_cache_unified_state;
+ class llama_kv_cache_unified_context;
- class llama_kv_cache_unified_iswa_state;
+ class llama_kv_cache_unified_iswa_context;
- class llama_memory_recurrent_state;
+ class llama_memory_recurrent_context;
- class llama_memory_hybrid_state;
+ class llama_memory_hybrid_context;

// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@ -136,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket_kv(
const llama_hparams & hparams,
- const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
+ const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
virtual ~llm_graph_input_pos_bucket_kv() = default;

void set_input(const llama_ubatch * ubatch) override;
@ -144,7 +144,8 @@ public:
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]

const llama_hparams & hparams;
- const llama_kv_cache_unified_state * kv_state;
+
+ const llama_kv_cache_unified_context * mctx;
};

class llm_graph_input_out_ids : public llm_graph_input_i {
@ -191,14 +192,14 @@ public:

class llm_graph_input_rs : public llm_graph_input_i {
public:
- llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
+ llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
virtual ~llm_graph_input_rs() = default;

void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * s_copy; // I32 [kv_size]

- const llama_memory_recurrent_state * mem_state;
+ const llama_memory_recurrent_context * mctx;
};

class llm_graph_input_cross_embd : public llm_graph_input_i {
@ -238,10 +239,10 @@ public:
llm_graph_input_attn_kv_unified(
const llama_hparams & hparams,
const llama_cparams & cparams,
- const llama_kv_cache_unified_state * kv_state) :
+ const llama_kv_cache_unified_context * mctx) :
hparams(hparams),
cparams(cparams),
- kv_state(kv_state) {
+ mctx(mctx) {
}
~llm_graph_input_attn_kv_unified() = default;

@ -255,7 +256,7 @@ public:
const llama_hparams & hparams;
const llama_cparams & cparams;

- const llama_kv_cache_unified_state * kv_state;
+ const llama_kv_cache_unified_context * mctx;
};

class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@ -263,10 +264,10 @@ public:
llm_graph_input_attn_kv_unified_iswa(
const llama_hparams & hparams,
const llama_cparams & cparams,
- const llama_kv_cache_unified_iswa_state * kv_state) :
+ const llama_kv_cache_unified_iswa_context * mctx) :
hparams(hparams),
cparams(cparams),
- kv_state(kv_state) {
+ mctx(mctx) {
}
~llm_graph_input_attn_kv_unified_iswa() = default;

@ -283,7 +284,7 @@ public:
const llama_hparams & hparams;
const llama_cparams & cparams;

- const llama_kv_cache_unified_iswa_state * kv_state;
+ const llama_kv_cache_unified_iswa_context * mctx;
};

class llm_graph_input_attn_cross : public llm_graph_input_i {
@ -306,10 +307,10 @@ public:
llm_graph_input_mem_hybrid(
const llama_hparams & hparams,
const llama_cparams & cparams,
- const llama_memory_hybrid_state * mem_state) :
+ const llama_memory_hybrid_context * mctx) :
hparams(hparams),
cparams(cparams),
- mem_state(mem_state) {
+ mctx(mctx) {
}
virtual ~llm_graph_input_mem_hybrid() = default;

@ -325,7 +326,7 @@ public:
const llama_hparams & hparams;
const llama_cparams & cparams;

- const llama_memory_hybrid_state * mem_state;
+ const llama_memory_hybrid_context * mctx;
};

//
@ -401,10 +402,10 @@ struct llm_graph_params {
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu;

const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
- const llama_memory_state_i * mstate;
+ const llama_memory_context_i * mctx;
const llama_cross * cross;

uint32_t n_outputs;

@ -453,10 +454,10 @@ struct llm_graph_context {

ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
- const llama_memory_state_i * mstate;
+ const llama_memory_context_i * mctx;
const llama_cross * cross;

const llm_graph_cb & cb_func;

@ -95,7 +95,7 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id);
}

- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all);

// first try simple split
@ -125,7 +125,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc

assert(heads_base.size() == heads_swa.size());

- return std::make_unique<llama_kv_cache_unified_iswa_state>(
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
} while (false);

@ -156,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc

assert(heads_base.size() == heads_swa.size());

- return std::make_unique<llama_kv_cache_unified_iswa_state>(
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
} while (false);

// TODO: if we fail again, we should attempt different splitting strategies
// but to do that properly, we first have to refactor the batches to be more flexible

- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
- return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
}

- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
- return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
}

bool llama_kv_cache_unified_iswa::get_can_shift() const {
@ -197,46 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
}

//
- // llama_kv_cache_unified_iswa_state
+ // llama_kv_cache_unified_iswa_context
//

- llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}

- llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv) :
- state_base(kv->get_base()->init_full()),
+ ctx_base(kv->get_base()->init_full()),
- state_swa (kv->get_swa ()->init_full()),
+ ctx_swa (kv->get_swa ()->init_full()),
- status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+ status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

- llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_context * lctx,
bool optimize) :
- state_base(kv->get_base()->init_update(lctx, optimize)),
+ ctx_base(kv->get_base()->init_update(lctx, optimize)),
- state_swa (kv->get_swa ()->init_update(lctx, optimize)),
+ ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
- status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+ status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

- llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
- state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)),
+ ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
- state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+ ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
- status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+ status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

- llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
+ llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;

- bool llama_kv_cache_unified_iswa_state::next() {
+ bool llama_kv_cache_unified_iswa_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

- state_base->next();
+ ctx_base->next();
- state_swa ->next();
+ ctx_swa ->next();

if (++i_next >= ubatches.size()) {
return false;
@ -245,35 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
return true;
}

- bool llama_kv_cache_unified_iswa_state::apply() {
+ bool llama_kv_cache_unified_iswa_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

bool res = true;

- res = res & state_base->apply();
+ res = res & ctx_base->apply();
- res = res & state_swa ->apply();
+ res = res & ctx_swa ->apply();

return res;
}

- llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
+ llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
return status;
}

- const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
+ const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

return ubatches[i_next];
}

- const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
+ const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

- return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
+ return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
}

- const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
+ const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

- return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
+ return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
}

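The constructors above compose the iSWA context out of one child context per underlying cache and derive the combined status from the children. Below is a minimal sketch of that composition, assuming simplified stand-in types (status_combine, child_context and iswa_context are illustrative names, not the real API):

    #include <cstdio>
    #include <memory>

    // Simplified stand-ins -- not the real llama.cpp types.
    enum mem_status { MEM_SUCCESS, MEM_FAILED };

    static mem_status status_combine(mem_status a, mem_status b) {
        return (a == MEM_SUCCESS && b == MEM_SUCCESS) ? MEM_SUCCESS : MEM_FAILED;
    }

    struct child_context {
        mem_status status = MEM_SUCCESS;

        mem_status get_status() const { return status; }
        bool apply() { return status == MEM_SUCCESS; }
    };

    // Mirrors the composition above: one child context per underlying cache
    // (base and SWA), a combined status, and apply() forwarded to both children.
    struct iswa_context {
        std::unique_ptr<child_context> ctx_base = std::make_unique<child_context>();
        std::unique_ptr<child_context> ctx_swa  = std::make_unique<child_context>();
        mem_status status = status_combine(ctx_base->get_status(), ctx_swa->get_status());

        bool apply() {
            bool res = true;
            res = res & ctx_base->apply();
            res = res & ctx_swa ->apply();
            return res;
        }
    };

    int main() {
        iswa_context ctx;
        std::printf("status=%d apply=%d\n", (int) ctx.status, (int) ctx.apply());
        return 0;
    }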
@ -31,14 +31,14 @@ public:
// llama_memory_i
//

- llama_memory_state_ptr init_batch(
+ llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;

- llama_memory_state_ptr init_full() override;
+ llama_memory_context_ptr init_full() override;

- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

bool get_can_shift() const override;

@ -72,32 +72,32 @@ private:
std::unique_ptr<llama_kv_cache_unified> kv_swa;
};

- class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
+ class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public:
// used for errors
- llama_kv_cache_unified_iswa_state(llama_memory_status status);
+ llama_kv_cache_unified_iswa_context(llama_memory_status status);

- // used to create a full-cache state
+ // used to create a full-cache context
- llama_kv_cache_unified_iswa_state(
+ llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv);

- // used to create an update state
+ // used to create an update context
- llama_kv_cache_unified_iswa_state(
+ llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_context * lctx,
bool optimize);

- // used to create a state from a batch
+ // used to create a batch processing context from a batch
- llama_kv_cache_unified_iswa_state(
+ llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
std::vector<llama_ubatch> ubatches);

- virtual ~llama_kv_cache_unified_iswa_state();
+ virtual ~llama_kv_cache_unified_iswa_context();

//
- // llama_memory_state_i
+ // llama_memory_context_i
//

bool next() override;
@ -107,11 +107,11 @@ public:
const llama_ubatch & get_ubatch() const override;

//
- // llama_kv_cache_unified_iswa_state specific API
+ // llama_kv_cache_unified_iswa_context specific API
//

- const llama_kv_cache_unified_state * get_base() const;
+ const llama_kv_cache_unified_context * get_base() const;
- const llama_kv_cache_unified_state * get_swa() const;
+ const llama_kv_cache_unified_context * get_swa() const;

private:
//llama_kv_cache_unified_iswa * kv;
@ -121,8 +121,8 @@ private:

std::vector<llama_ubatch> ubatches;

- const llama_memory_state_ptr state_base;
+ const llama_memory_context_ptr ctx_base;
- const llama_memory_state_ptr state_swa;
+ const llama_memory_context_ptr ctx_swa;

const llama_memory_status status;
};

@ -307,7 +307,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
return cells.seq_pos_max(seq_id);
}

- llama_memory_state_ptr llama_kv_cache_unified::init_batch(
+ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) {
@ -332,18 +332,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
break;
}

- return std::make_unique<llama_kv_cache_unified_state>(
+ return std::make_unique<llama_kv_cache_unified_context>(
this, std::move(heads), std::move(ubatches));
} while (false);

- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+ return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

- llama_memory_state_ptr llama_kv_cache_unified::init_full() {
+ llama_memory_context_ptr llama_kv_cache_unified::init_full() {
- return std::make_unique<llama_kv_cache_unified_state>(this);
+ return std::make_unique<llama_kv_cache_unified_context>(this);
}

- llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
bool do_shift = get_has_shift();

defrag_info dinfo;
@ -373,7 +373,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
}
}

- return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
+ return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
}

llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
@ -1710,18 +1710,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
}

//
- // llama_kv_cache_unified_state
+ // llama_kv_cache_unified_context
//

- llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}

- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
n_kv = kv->get_size();
head = 0;
}

- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
@ -1731,15 +1731,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
}
}

- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_unified::ubatch_heads heads,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
}

- llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
+ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;

- bool llama_kv_cache_unified_state::next() {
+ bool llama_kv_cache_unified_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

if (++i_next >= ubatches.size()) {
@ -1749,7 +1749,7 @@ bool llama_kv_cache_unified_state::next() {
return true;
}

- bool llama_kv_cache_unified_state::apply() {
+ bool llama_kv_cache_unified_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

// no ubatches -> this is a KV cache update
@ -1767,45 +1767,45 @@ bool llama_kv_cache_unified_state::apply() {
return true;
}

- llama_memory_status llama_kv_cache_unified_state::get_status() const {
+ llama_memory_status llama_kv_cache_unified_context::get_status() const {
return status;
}

- const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
+ const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

return ubatches[i_next];
}

- uint32_t llama_kv_cache_unified_state::get_n_kv() const {
+ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
return n_kv;
}

- ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
+ ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
return kv->get_k(ctx, il, n_kv);
}

- ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
+ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
return kv->get_v(ctx, il, n_kv);
}

- ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
+ ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
return kv->cpy_k(ctx, k_cur, il, head);
}

- ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
+ ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
return kv->cpy_v(ctx, v_cur, il, head);
}

- void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
+ void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
kv->set_input_k_shift(dst);
}

- void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+ void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
kv->set_input_kq_mask(dst, ubatch, causal_attn);
}

- void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}

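The accessors above show what the renamed per-batch context adds on top of the cache itself: an (n_kv, head) cursor, with get_k/get_v exposing the visible cells and cpy_k/cpy_v writing at the head position. A toy sketch of that cache-plus-cursor idea, with vectors standing in for tensors (all names below are hypothetical):

    #include <cstdio>
    #include <vector>

    // Simplified stand-in -- not the real llama.cpp cache; tensors become vectors.
    struct kv_cache {
        std::vector<float> k = std::vector<float>(16, 0.0f); // 16 cells, 1 value per cell
    };

    // Pairs the cache with a cursor (n_kv = how many cells the graph should see,
    // head = where the current ubatch is written); accessors delegate to the cache.
    struct kv_cache_context {
        kv_cache * kv;
        unsigned   n_kv;   // number of cells visible to the attention graph
        unsigned   head;   // first cell used by the current ubatch

        // read: a "view" of the first n_kv cells
        const float * get_k() const { return kv->k.data(); }

        // write: store the current token's K value at the head cursor
        void cpy_k(float k_cur) { kv->k[head] = k_cur; }
    };

    int main() {
        kv_cache cache;
        kv_cache_context ctx{&cache, 8, 3};

        ctx.cpy_k(1.5f);
        std::printf("cell[3] = %.1f, visible cells = %u\n", ctx.get_k()[3], ctx.n_kv);
        return 0;
    }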
@ -56,14 +56,14 @@ public:
// llama_memory_i
//

- llama_memory_state_ptr init_batch(
+ llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;

- llama_memory_state_ptr init_full() override;
+ llama_memory_context_ptr init_full() override;

- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

bool get_can_shift() const override;

@ -208,36 +208,36 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

- class llama_kv_cache_unified_state : public llama_memory_state_i {
+ class llama_kv_cache_unified_context : public llama_memory_context_i {
public:
// some shorthands
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
using defrag_info = llama_kv_cache_unified::defrag_info;

// used for errors
- llama_kv_cache_unified_state(llama_memory_status status);
+ llama_kv_cache_unified_context(llama_memory_status status);

- // used to create a full-cache state
+ // used to create a full-cache context
- llama_kv_cache_unified_state(
+ llama_kv_cache_unified_context(
llama_kv_cache_unified * kv);

- // used to create an update state
+ // used to create an update context
- llama_kv_cache_unified_state(
+ llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo);

- // used to create a decode state from a batch
+ // used to create a batch procesing context from a batch
- llama_kv_cache_unified_state(
+ llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
ubatch_heads heads,
std::vector<llama_ubatch> ubatches);

- virtual ~llama_kv_cache_unified_state();
+ virtual ~llama_kv_cache_unified_context();

//
- // llama_memory_state_i
+ // llama_memory_context_i
//

bool next() override;
@ -247,7 +247,7 @@ public:
const llama_ubatch & get_ubatch() const override;

//
- // llama_kv_cache_unified_state specific API
+ // llama_kv_cache_unified_context specific API
//

uint32_t get_n_kv() const;
@ -272,7 +272,7 @@ private:
llama_context * lctx;

//
- // update state
+ // update context
//

bool do_shift = false;
@ -280,7 +280,7 @@ private:
defrag_info dinfo;

//
- // batch processing state
+ // batch processing context
//

// the index of the next ubatch to process

@ -56,7 +56,7 @@ llama_memory_hybrid::llama_memory_hybrid(
n_seq_max
)) {}

- llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
do {
balloc.split_reset();

@ -82,31 +82,31 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ball

// prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) {
- // TODO: will the recurrent cache be in an undefined state at this point?
+ // TODO: will the recurrent cache be in an undefined context at this point?
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
- return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

// prepare the attention cache
auto heads_attn = mem_attn->prepare(ubatches);
if (heads_attn.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
- return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

- return std::make_unique<llama_memory_hybrid_state>(
+ return std::make_unique<llama_memory_hybrid_context>(
this, std::move(heads_attn), std::move(ubatches));
} while(false);

- return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

- llama_memory_state_ptr llama_memory_hybrid::init_full() {
+ llama_memory_context_ptr llama_memory_hybrid::init_full() {
- return std::make_unique<llama_memory_hybrid_state>(this);
+ return std::make_unique<llama_memory_hybrid_context>(this);
}

- llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
+ llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
- return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
+ return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
}

bool llama_memory_hybrid::get_can_shift() const {
@ -176,39 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
return mem_recr.get();
}

- llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
+ llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}

- llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
+ llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
- state_attn(mem->get_mem_attn()->init_full()),
+ ctx_attn(mem->get_mem_attn()->init_full()),
- state_recr(mem->get_mem_recr()->init_full()),
+ ctx_recr(mem->get_mem_recr()->init_full()),
- status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}

- llama_memory_hybrid_state::llama_memory_hybrid_state(
+ llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem,
llama_context * lctx,
bool optimize) :
- state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+ ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
- state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+ ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
- status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}

- llama_memory_hybrid_state::llama_memory_hybrid_state(
+ llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem,
std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
- state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+ ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
- state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), this->ubatches)),
+ ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
- status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}

- bool llama_memory_hybrid_state::next() {
+ bool llama_memory_hybrid_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

- state_attn->next();
+ ctx_attn->next();
- state_recr->next();
+ ctx_recr->next();

if (++i_next >= ubatches.size()) {
return false;
@ -217,30 +217,30 @@ bool llama_memory_hybrid_state::next() {
return true;
}

- bool llama_memory_hybrid_state::apply() {
+ bool llama_memory_hybrid_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

bool res = true;

- res = res & state_attn->apply();
+ res = res & ctx_attn->apply();
|
||||||
res = res & state_recr->apply();
|
res = res & ctx_recr->apply();
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_status llama_memory_hybrid_state::get_status() const {
|
llama_memory_status llama_memory_hybrid_context::get_status() const {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
|
const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
return ubatches[i_next];
|
return ubatches[i_next];
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
|
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
|
||||||
return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
|
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
|
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
|
||||||
return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
|
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
|
||||||
}
|
}
|
||||||
|
|
|
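The hybrid context above is essentially a thin wrapper: next() and apply() are forwarded to the attention and recurrent child contexts, and the overall status is the combination of their statuses. The following is only a minimal standalone sketch of that delegation pattern, using simplified stand-in types rather than the real llama.cpp classes, and a deliberately naive status-combining rule (the real llama_memory_status_combine handles more status values, and the real class also tracks a shared ubatch index):

#include <memory>

// simplified stand-ins for llama_memory_status / llama_memory_context_i;
// the real definitions live in src/llama-memory.h
enum class mem_status { success, failed_prepare };

struct mem_context_i {
    virtual ~mem_context_i() = default;
    virtual bool next()  = 0;
    virtual bool apply() = 0;
    virtual mem_status get_status() const = 0;
};

// naive combination rule used only for this sketch: succeed iff both
// children succeeded
static mem_status status_combine(mem_status s0, mem_status s1) {
    return (s0 == mem_status::success && s1 == mem_status::success)
        ? mem_status::success
        : mem_status::failed_prepare;
}

// delegation pattern: forward every call to both children and fold the results
class hybrid_context : public mem_context_i {
public:
    hybrid_context(std::unique_ptr<mem_context_i> attn,
                   std::unique_ptr<mem_context_i> recr) :
        ctx_attn(std::move(attn)),
        ctx_recr(std::move(recr)),
        status(status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {}

    bool next() override {
        // advance both children; report "done" as soon as either side is
        // exhausted (the real class instead checks its shared ubatch index)
        const bool more_attn = ctx_attn->next();
        const bool more_recr = ctx_recr->next();
        return more_attn && more_recr;
    }

    bool apply() override {
        bool res = true;
        res = res & ctx_attn->apply();
        res = res & ctx_recr->apply();
        return res;
    }

    mem_status get_status() const override { return status; }

private:
    const std::unique_ptr<mem_context_i> ctx_attn;
    const std::unique_ptr<mem_context_i> ctx_recr;
    const mem_status status;
};

In the actual diff the wrapper additionally exposes typed getters (get_attn()/get_recr()) so that graph construction can reach the concrete per-component contexts.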
@@ -49,14 +49,14 @@ public:
     // llama_memory_i
     //

-    llama_memory_state_ptr init_batch(
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;

-    llama_memory_state_ptr init_full() override;
+    llama_memory_context_ptr init_full() override;

-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

     bool get_can_shift() const override;

@@ -90,27 +90,27 @@ private:
     const std::unique_ptr<llama_memory_recurrent> mem_recr;
 };

-class llama_memory_hybrid_state : public llama_memory_state_i {
+class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
     // init failure
-    explicit llama_memory_hybrid_state(llama_memory_status status);
+    explicit llama_memory_hybrid_context(llama_memory_status status);

     // init full
-    explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
+    explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);

     // init update
-    explicit llama_memory_hybrid_state(
+    explicit llama_memory_hybrid_context(
             llama_memory_hybrid * mem,
             llama_context * lctx,
             bool optimize);

     // init success
-    llama_memory_hybrid_state(
+    llama_memory_hybrid_context(
             llama_memory_hybrid * mem,
             std::vector<uint32_t> heads_attn,
             std::vector<llama_ubatch> ubatches);

-    ~llama_memory_hybrid_state() = default;
+    ~llama_memory_hybrid_context() = default;

     bool next() override;
     bool apply() override;

@@ -119,11 +119,11 @@ public:
     const llama_ubatch & get_ubatch() const override;

     //
-    // llama_memory_hybrid_state
+    // llama_memory_hybrid_context
     //

-    const llama_kv_cache_unified_state * get_state_attn() const;
-    const llama_memory_recurrent_state * get_state_recr() const;
+    const llama_kv_cache_unified_context * get_attn() const;
+    const llama_memory_recurrent_context * get_recr() const;

 private:
     // the index of the next ubatch to process

@@ -131,8 +131,8 @@ private:

     std::vector<llama_ubatch> ubatches;

-    const llama_memory_state_ptr state_attn;
-    const llama_memory_state_ptr state_recr;
+    const llama_memory_context_ptr ctx_attn;
+    const llama_memory_context_ptr ctx_recr;

     const llama_memory_status status;
 };
@@ -362,7 +362,7 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }

-llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     std::vector<llama_ubatch> ubatches;

     while (true) {
@@ -383,21 +383,21 @@ llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & b
     }

     if (!prepare(ubatches)) {
-        return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }

-    return std::make_unique<llama_memory_recurrent_state>(this, std::move(ubatches));
+    return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
 }

-llama_memory_state_ptr llama_memory_recurrent::init_full() {
-    return std::make_unique<llama_memory_recurrent_state>(this);
+llama_memory_context_ptr llama_memory_recurrent::init_full() {
+    return std::make_unique<llama_memory_recurrent_context>(this);
 }

-llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
+llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
     GGML_UNUSED(lctx);
     GGML_UNUSED(optimize);

-    return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
 }

 bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -1040,22 +1040,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 }

 //
-// llama_memory_recurrent_state
+// llama_memory_recurrent_context
 //

-llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
+llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}

-llama_memory_recurrent_state::llama_memory_recurrent_state(
+llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
 }

-llama_memory_recurrent_state::llama_memory_recurrent_state(
+llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
         std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}

-llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
+llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;

-bool llama_memory_recurrent_state::next() {
+bool llama_memory_recurrent_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

     if (++i_next >= ubatches.size()) {
@@ -1065,7 +1065,7 @@ bool llama_memory_recurrent_state::next() {
     return true;
 }

-bool llama_memory_recurrent_state::apply() {
+bool llama_memory_recurrent_context::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

     mem->find_slot(ubatches[i_next]);
@@ -1073,40 +1073,40 @@ bool llama_memory_recurrent_state::apply() {
     return true;
 }

-llama_memory_status llama_memory_recurrent_state::get_status() const {
+llama_memory_status llama_memory_recurrent_context::get_status() const {
     return status;
 }

-const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
+const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

     return ubatches[i_next];
 }

-uint32_t llama_memory_recurrent_state::get_n_rs() const {
+uint32_t llama_memory_recurrent_context::get_n_rs() const {
     return is_full ? mem->size : mem->n;
 }

-uint32_t llama_memory_recurrent_state::get_head() const {
+uint32_t llama_memory_recurrent_context::get_head() const {
     return is_full ? 0 : mem->head;
 }

-int32_t llama_memory_recurrent_state::get_rs_z() const {
+int32_t llama_memory_recurrent_context::get_rs_z() const {
     return is_full ? 0 : mem->rs_z;
 }

-uint32_t llama_memory_recurrent_state::get_size() const {
+uint32_t llama_memory_recurrent_context::get_size() const {
     return mem->size;
 }

-ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
+ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
     return mem->r_l[il];
 }

-ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
+ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
     return mem->s_l[il];
 }

-int32_t llama_memory_recurrent_state::s_copy(int i) const {
+int32_t llama_memory_recurrent_context::s_copy(int i) const {
     return mem->cells[i + mem->head].src0;
 }
@@ -11,8 +11,8 @@
 // llama_memory_recurrent
 //

-// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
-// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
+// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
+// see the implementation of llama_kv_cache_unified_context_i for an example how to do it
 class llama_memory_recurrent : public llama_memory_i {
 public:

@@ -34,14 +34,14 @@ public:
     // llama_memory_i
     //

-    llama_memory_state_ptr init_batch(
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;

-    llama_memory_state_ptr init_full() override;
+    llama_memory_context_ptr init_full() override;

-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

     void clear(bool data) override;

@@ -125,24 +125,24 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };

-class llama_memory_recurrent_state : public llama_memory_state_i {
+class llama_memory_recurrent_context : public llama_memory_context_i {
 public:
     // used for errors
-    llama_memory_recurrent_state(llama_memory_status status);
+    llama_memory_recurrent_context(llama_memory_status status);

-    // used to create a full-cache state
-    llama_memory_recurrent_state(
+    // used to create a full-cache or update context
+    llama_memory_recurrent_context(
             llama_memory_recurrent * mem);

-    // used to create a state from a batch
-    llama_memory_recurrent_state(
+    // used to create a batch processing context from a batch
+    llama_memory_recurrent_context(
             llama_memory_recurrent * mem,
             std::vector<llama_ubatch> ubatches);

-    virtual ~llama_memory_recurrent_state();
+    virtual ~llama_memory_recurrent_context();

     //
-    // llama_memory_state_i
+    // llama_memory_context_i
     //

     bool next() override;
@@ -152,7 +152,7 @@ public:
     const llama_ubatch & get_ubatch() const override;

     //
-    // llama_memory_recurrent_state specific API
+    // llama_memory_recurrent_context specific API
     //

     uint32_t get_n_rs() const;
@@ -3,7 +3,6 @@
 #include "llama.h"

 #include <memory>
-#include <vector>

 struct llama_ubatch;

@@ -28,23 +27,21 @@ enum llama_memory_status {
     LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
 };

-// helper function for combining the status of two memory states
+// helper function for combining the status of two memory contexts
 // useful for implementing hybrid memory types (e.g. iSWA)
 llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);

-// the interface for managing the memory state during batch processing
+// the interface for managing the memory context during batch processing
 // this interface is implemented per memory type. see:
-//   - llama_kv_cache_unified_state
-//   - llama_kv_cache_unified_iswa_state
+//   - llama_kv_cache_unified_context
+//   - llama_kv_cache_unified_iswa_context
 //   ...
 //
-// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
-//
-// TODO: rename to llama_memory_context_i ?
-struct llama_memory_state_i {
-    virtual ~llama_memory_state_i() = default;
+// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
+struct llama_memory_context_i {
+    virtual ~llama_memory_context_i() = default;

-    // consume the current ubatch from the state and proceed to the next one
+    // consume the current ubatch from the context and proceed to the next one
     // return false if we are done
     virtual bool next() = 0;

@@ -55,11 +52,11 @@ struct llama_memory_state_i {
     // get the current ubatch
     virtual const llama_ubatch & get_ubatch() const = 0;

-    // get the status of the memory state - used for error handling and checking if any updates would be applied
+    // get the status of the memory context - used for error handling and checking if any updates would be applied
     virtual llama_memory_status get_status() const = 0;
 };

-using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
+using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;

 // general concept of LLM memory
 // the KV cache is a type of LLM memory, but there can be other types
@@ -67,19 +64,19 @@ struct llama_memory_i {
     virtual ~llama_memory_i() = default;

     // split the input batch into a set of ubatches and verify that they can fit into the cache
-    // return a state object containing the ubatches and KV cache state required to process them
-    // check the llama_memory_state_i::get_status() for the result
-    virtual llama_memory_state_ptr init_batch(
+    // return a context object containing the ubatches and memory state required to process them
+    // check the llama_memory_context_i::get_status() for the result
+    virtual llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) = 0;

     // simulate full cache, used for allocating worst-case compute buffers
-    virtual llama_memory_state_ptr init_full() = 0;
+    virtual llama_memory_context_ptr init_full() = 0;

     // prepare for any pending memory updates, such as shifts, defrags, etc.
     // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
-    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
+    virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;

     // getters
     virtual bool get_can_shift() const = 0;
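Taken together, the renamed interfaces imply a simple driving loop on the caller side: llama_memory_i::init_batch() hands back a context, the caller checks get_status(), then alternates apply() and get_ubatch() until next() reports that all ubatches are consumed. Below is only a minimal standalone sketch of that loop, with simplified stand-in types instead of the real llama_memory_context_i / llama_ubatch and an apply() that trivially always succeeds:

#include <cstdio>
#include <utility>
#include <vector>

// simplified stand-ins; the real types are llama_memory_context_i and
// llama_ubatch from src/llama-memory.h
struct ubatch { int n_tokens; };

enum class mem_status { success, failed_prepare };

class mem_context {
public:
    explicit mem_context(std::vector<ubatch> ubs) : ubatches(std::move(ubs)) {}

    mem_status get_status() const {
        return ubatches.empty() ? mem_status::failed_prepare : mem_status::success;
    }

    // reserve/commit memory for the current ubatch (always succeeds here)
    bool apply() { return true; }

    const ubatch & get_ubatch() const { return ubatches[i_next]; }

    // advance to the next ubatch; false once everything is consumed
    bool next() { return ++i_next < ubatches.size(); }

private:
    std::vector<ubatch> ubatches;
    size_t i_next = 0;
};

int main() {
    // in llama.cpp this object would come from llama_memory_i::init_batch()
    mem_context mctx({{32}, {32}, {17}});

    if (mctx.get_status() != mem_status::success) {
        std::fprintf(stderr, "failed to prepare ubatches\n");
        return 1;
    }

    do {
        if (!mctx.apply()) { // the only call that mutates the memory
            return 1;
        }
        std::printf("processing ubatch with %d tokens\n", mctx.get_ubatch().n_tokens);
    } while (mctx.next());

    return 0;
}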
@@ -9271,9 +9271,9 @@ struct llm_build_mamba : public llm_graph_context {
             ggml_tensor * cur,
             const llama_ubatch & ubatch,
             int il) const {
-        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

-        const auto kv_head = kv_state->get_head();
+        const auto kv_head = mctx_cur->get_head();

         const int64_t d_conv = hparams.ssm_d_conv;
         const int64_t d_inner = hparams.ssm_d_inner;
@@ -9291,8 +9291,8 @@ struct llm_build_mamba : public llm_graph_context {
         GGML_ASSERT(ubatch.equal_seqs);
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

-        ggml_tensor * conv_states_all = kv_state->get_r_l(il);
-        ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
+        ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+        ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);

         // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_rs(
@@ -12016,7 +12016,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             ggml_tensor * x_prev,
             const llama_ubatch & ubatch,
             int il) const {
-        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12026,7 +12026,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         const auto n_head = n_embd / head_size;
         const auto n_head_kv = hparams.n_head_kv(il);

-        const auto kv_head = kv_state->get_head();
+        const auto kv_head = mctx_cur->get_head();

         const auto & layer = model.layers[il];

@@ -12138,7 +12138,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         }

         ggml_tensor * wkv_state = build_rs(
-                inp, gf, kv_state->get_s_l(il),
+                inp, gf, mctx_cur->get_s_l(il),
                 hparams.n_embd_s(), n_seqs);

         ggml_tensor * wkv_output;
@@ -12157,9 +12157,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_state->get_s_l(il),
+                        mctx_cur->get_s_l(il),
                         hparams.n_embd_s() * n_seqs,
-                        hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
+                        hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
                     )
                 )
             );
@@ -12413,7 +12413,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
             ggml_tensor *& first_layer_value,
             const llama_ubatch & ubatch,
             int il) const {
-        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12422,7 +12422,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         const auto head_count = n_embd / head_size;
         const auto n_seq_tokens = ubatch.n_seq_tokens;

-        const auto kv_head = kv_state->get_head();
+        const auto kv_head = mctx_cur->get_head();

         const auto & layer = model.layers[il];

@@ -12493,7 +12493,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);

         ggml_tensor * wkv_state = build_rs(
-                inp, gf, kv_state->get_s_l(il),
+                inp, gf, mctx_cur->get_s_l(il),
                 hparams.n_embd_s(), n_seqs);

         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12507,9 +12507,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_state->get_s_l(il),
+                        mctx_cur->get_s_l(il),
                         hparams.n_embd_s() * n_seqs,
-                        hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
+                        hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
                     )
                 )
             );
@@ -2257,6 +2257,9 @@ struct clip_model_loader {
             {
                 hparams.rope_theta = 10000.0f;
                 hparams.warmup_image_size = hparams.patch_size * 8;
+                // Mistral Small 2506 needs 1024x1024 image size cap to prevent OOM
+                // ref: https://github.com/ggml-org/llama.cpp/issues/14310
+                hparams.image_size = 1024;
                 get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
             } break;
         case PROJECTOR_TYPE_GEMMA3: