diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index d38a74f95..637891f50 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -41,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items return result; } -/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */ -class string_view { - const std::string & _str; - const size_t _start; - const size_t _end; -public: - string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {} - - size_t size() const { - return _end - _start; - } - - size_t length() const { - return size(); - } - - operator std::string() const { - return str(); - } - - std::string str() const { - return _str.substr(_start, _end - _start); - } - - string_view substr(size_t pos, size_t len = std::string::npos) const { - return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len); - } - - char operator[](size_t pos) const { - auto index = _start + pos; - if (index >= _end) { - throw std::out_of_range("string_view index out of range"); - } - return _str[_start + pos]; - } - - bool operator==(const string_view & other) const { - std::string this_str = *this; - std::string other_str = other; - return this_str == other_str; - } -}; - static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { auto has_min = min_value != std::numeric_limits::min(); auto has_max = max_value != std::numeric_limits::max(); @@ -112,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & } out << "}"; }; - std::function uniform_range = - [&](const string_view & from, const string_view & to) { + std::function uniform_range = + [&](const std::string_view & from, const std::string_view & to) { size_t i = 0; while (i < from.length() && i < to.length() && from[i] == to[i]) { i++; } if (i > 0) { - out << "\"" << from.substr(0, i).str() << "\""; + out << "\"" << from.substr(0, i) << "\""; } if (i < from.length() && i < to.length()) { if (i > 0) { diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2fe76589e..bbf8b30ff 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2193,7 +2193,7 @@ class Llama4VisionModel(MmprojModel): name += ".weight" if "multi_modal_projector.linear_1" in name: # despite the name with number postfix, this is a single fully connected layer - return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC], data_torch)] + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC] + '.weight', data_torch)] return [(self.map_tensor_name(name), data_torch)] return [] diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index e18f1f2db..d90731e1e 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -245,8 +245,18 @@ static bool fp16_mma_available(const int cc) { #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) return false; #else - return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); + if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) { + return true; 
+ } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { +#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) + } else { + return false; + } #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } @@ -366,6 +376,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // FP16_AVAILABLE } +// Row reduction kernel template - compute sum (norm=false) or mean (norm=true) +template <bool norm> +static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) { + const int row = blockIdx.x; + const int col = threadIdx.x; + + float sum = 0.0f; + for (int i = col; i < ncols; i += blockDim.x) { + sum += x[row * ncols + i]; + } + + sum = warp_reduce_sum(sum); + + if (col != 0) { + return; + } + + dst[row] = norm ? sum / ncols : sum; +} + template<int width = WARP_SIZE> static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b05c40a0e..495ae8e16 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -39,6 +39,7 @@ bool g_mul_mat_q = true; #include "ggml-cuda/ssm-scan.cuh" #include "ggml-cuda/sum.cuh" #include "ggml-cuda/sumrows.cuh" +#include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" @@ -101,8 +102,7 @@ int ggml_cuda_get_device() { static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { ggml_cuda_set_device(device); cudaError_t err; - if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) - { + if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) { err = cudaMallocManaged(ptr, size); #if defined(GGML_USE_HIP) if (err == hipSuccess) { @@ -120,9 +120,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) err = cudaMalloc(ptr, size); } #endif // defined(GGML_USE_HIP) - } - else - { + } else { err = cudaMalloc(ptr, size); } return err; @@ -2362,6 +2360,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SUM_ROWS: ggml_cuda_op_sum_rows(ctx, dst); break; + case GGML_OP_MEAN: + ggml_cuda_op_mean(ctx, dst); + break; case GGML_OP_SSM_CONV: ggml_cuda_op_ssm_conv(ctx, dst); break; @@ -3265,6 +3266,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_POOL_2D: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGSORT: case GGML_OP_ACC: return true; diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu new file mode 100644 index 000000000..4b238a399 --- /dev/null +++ b/ggml/src/ggml-cuda/mean.cu @@ -0,0 +1,19 @@ +#include "mean.cuh" + +void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + reduce_rows_f32<true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); +} diff --git a/ggml/src/ggml-cuda/mean.cuh b/ggml/src/ggml-cuda/mean.cuh new file mode 100644 index
000000000..2b9b10433 --- /dev/null +++ b/ggml/src/ggml-cuda/mean.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 38dbf1b5e..2eee08fa0 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -1,25 +1,9 @@ #include "sumrows.cuh" -static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) { - const int row = blockIdx.x; - const int col = threadIdx.x; - - float sum = 0.0f; - for (int i = col; i < ncols; i += blockDim.x) { - sum += x[row * ncols + i]; - } - - sum = warp_reduce_sum(sum); - - if (col == 0) { - dst[row] = sum; - } -} - void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_nums(nrows, 1, 1); - k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); + reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); } void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ncols = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); - sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream); + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + + reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); } diff --git a/ggml/src/ggml-cuda/sumrows.cuh b/ggml/src/ggml-cuda/sumrows.cuh index 191db1c13..3431c599b 100644 --- a/ggml/src/ggml-cuda/sumrows.cuh +++ b/ggml/src/ggml-cuda/sumrows.cuh @@ -1,5 +1,4 @@ #include "common.cuh" void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream); - void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 1ef3aaac0..c9f6e91df 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context { int mtl_device_ref_count; id<MTLLibrary> mtl_library; + NSLock * mtl_lock; + bool has_simdgroup_reduction; bool has_simdgroup_mm; bool has_residency_sets; bool has_bfloat; bool use_bfloat; + size_t max_size; + char name[128]; } g_ggml_ctx_dev_main = { /*.mtl_device =*/ nil, /*.mtl_device_ref_count =*/ 0, /*.mtl_library =*/ nil, + /*.mtl_lock =*/ nil, /*.has_simdgroup_reduction =*/ false, /*.has_simdgroup_mm =*/ false, /*.has_residency_sets =*/ false, /*.has_bfloat =*/ false, /*.use_bfloat =*/ false, + /*.max_size =*/ 0, /*.name =*/ "", }; @@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context { static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) { assert(ctx != NULL); + if (ctx->mtl_lock == nil) { + ctx->mtl_lock = [[NSLock alloc] init]; + } + if (ctx->mtl_device == nil) { ctx->mtl_device = MTLCreateSystemDefaultDevice(); } @@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev ctx->use_bfloat = false; #endif + ctx->max_size = ctx->mtl_device.maxBufferLength; + strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1); } @@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte ctx->mtl_device_ref_count--; if (ctx->mtl_device_ref_count == 0) { + if (ctx->mtl_lock) { + [ctx->mtl_lock release]; + ctx->mtl_lock = nil; + } + if (ctx->mtl_library) { [ctx->mtl_library release];
ctx->mtl_library = nil; @@ -977,7 +994,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context)); struct ggml_backend_metal_device_context * ctx_dev = dev->context; - id device = ggml_backend_metal_device_acq(ctx_dev); + id device = ctx_dev->mtl_device; GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); @@ -991,9 +1008,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); // load library - if (ctx_dev->mtl_library == nil) { - ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat); + { + [ctx_dev->mtl_lock lock]; + + if (ctx_dev->mtl_library == nil) { + ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat); + } + + [ctx_dev->mtl_lock unlock]; } + id metal_library = ctx_dev->mtl_library; if (metal_library == nil) { GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__); @@ -5284,7 +5308,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) } ggml_backend_metal_buffer_rset_free(ctx); - ggml_backend_metal_device_rel(buffer->buft->device->context); if (ctx->owned) { #if TARGET_OS_OSX @@ -5393,7 +5416,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba } struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context; - id device = ggml_backend_metal_device_acq(ctx_dev); + + GGML_ASSERT(ctx_dev->mtl_device != nil); + + id device = ctx_dev->mtl_device; ctx->all_data = ggml_metal_host_malloc(size_aligned); ctx->all_size = size_aligned; @@ -5416,14 +5442,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) { GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); free(ctx); - ggml_backend_metal_device_rel(ctx_dev); return NULL; } if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); free(ctx); - ggml_backend_metal_device_rel(ctx_dev); return NULL; } @@ -5434,17 +5458,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { return 32; + GGML_UNUSED(buft); } static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { - id device = ggml_backend_metal_device_acq(buft->device->context); - const size_t max_size = device.maxBufferLength; - ggml_backend_metal_device_rel(buft->device->context); + const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size; return max_size; - - GGML_UNUSED(buft); } static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { @@ -5517,7 +5538,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz } struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main; - id device = ggml_backend_metal_device_acq(ctx_dev); + + GGML_ASSERT(ctx_dev->mtl_device != nil); + + id device = ctx_dev->mtl_device; // the buffer fits into the max buffer size allowed by the device if (size_aligned <= device.maxBufferLength) { @@ -5573,7 
+5597,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); free(ctx); - ggml_backend_metal_device_rel(ctx_dev); return NULL; } @@ -5589,10 +5612,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) { } static void ggml_backend_metal_free(ggml_backend_t backend) { - struct ggml_backend_metal_context * ctx = backend->context; - struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + struct ggml_backend_metal_context * ctx = backend->context; - ggml_backend_metal_device_rel(ctx_dev); ggml_metal_free(ctx); free(backend); @@ -5732,6 +5753,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + GGML_ASSERT(ctx_dev->mtl_device != nil); + return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; } @@ -5751,10 +5774,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { } static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { - // acq/rel just to populate ctx->name in case it hasn't been done yet struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; - ggml_backend_metal_device_acq(ctx_dev); - ggml_backend_metal_device_rel(ctx_dev); return ctx_dev->name; } @@ -5762,12 +5782,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { if (@available(macOS 10.12, iOS 16.0, *)) { struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; - id device = ggml_backend_metal_device_acq(ctx_dev); + id device = ctx_dev->mtl_device; *total = device.recommendedMaxWorkingSetSize; *free = *total - device.currentAllocatedSize; - - ggml_backend_metal_device_rel(ctx_dev); } else { *free = 1; *total = 1; @@ -5845,7 +5863,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back } struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; - id device = ggml_backend_metal_device_acq(ctx_dev); + + GGML_ASSERT(ctx_dev->mtl_device != nil); + + id device = ctx_dev->mtl_device; // the buffer fits into the max buffer size allowed by the device if (size_aligned <= device.maxBufferLength) { @@ -5901,7 +5922,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); free(ctx); - ggml_backend_metal_device_rel(ctx_dev); return NULL; } @@ -5915,8 +5935,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const } static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name || - buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name; + return + buft->iface.get_name == ggml_backend_metal_buffer_type_get_name || + buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name; GGML_UNUSED(dev); } @@ -6001,8 +6022,19 @@ static struct ggml_backend_reg_i 
ggml_backend_metal_reg_i = { /* .get_proc_address = */ ggml_backend_metal_get_proc_address, }; +// called upon program exit +static void ggml_metal_cleanup(void) { + ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main); +} + +// TODO: make thread-safe ggml_backend_reg_t ggml_backend_metal_reg(void) { - // TODO: make this thread-safe somehow? + ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main); + + // register cleanup callback + // TODO: not ideal, but not sure if there is a better way to do this in Objective-C + atexit(ggml_metal_cleanup); + { g_ggml_backend_metal_reg = (struct ggml_backend_reg) { /* .api_version = */ GGML_BACKEND_API_VERSION, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 652b7ca7d..23e520c80 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1057,6 +1057,14 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { struct vk_instance_t { vk::Instance instance; + bool debug_utils_support = false; // VK_EXT_debug_utils enabled + PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {}; + PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {}; + PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {}; + PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {}; + PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {}; + PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {}; + std::vector device_indices; vk_device devices[GGML_VK_MAX_DEVICES]; }; @@ -1196,6 +1204,14 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } pipeline->compiled = true; + if (vk_instance.debug_utils_support) { + vk::DebugUtilsObjectNameInfoEXT duoni; + duoni.objectType = vk::ObjectType::ePipeline; + duoni.pObjectName = pipeline->name.c_str(); + duoni.objectHandle = reinterpret_cast(static_cast(pipeline->pipeline)); + vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast(duoni)); + } + { std::lock_guard guard(device->mutex); device->pipelines.insert({ pipeline->name, pipeline }); @@ -3585,6 +3601,8 @@ static void ggml_vk_print_gpu_info(size_t idx) { static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); +static bool ggml_vk_instance_debug_utils_ext_available(const std::vector & instance_extensions); + static void ggml_vk_instance_init() { if (vk_instance_initialized) { return; @@ -3605,7 +3623,7 @@ static void ggml_vk_instance_init() { #ifdef __APPLE__ const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); #endif - + const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr; std::vector layers; if (validation_ext) { @@ -3620,6 +3638,9 @@ static void ggml_vk_instance_init() { extensions.push_back("VK_KHR_portability_enumeration"); } #endif + if (debug_utils_ext) { + extensions.push_back("VK_EXT_debug_utils"); + } vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); #ifdef __APPLE__ if (portability_enumeration_ext) { @@ -3643,6 +3664,18 @@ static void ggml_vk_instance_init() { vk_instance.instance = vk::createInstance(instance_create_info); vk_instance_initialized = true; + if (debug_utils_ext) { + 
vk_instance.debug_utils_support = true; + vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT"); + vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT"); + vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT"); + vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT"); + vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT"); + vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT"); + + } + + size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan @@ -9680,6 +9713,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if (vk_instance.debug_utils_support) { + vk::DebugUtilsLabelEXT dul = {}; + dul.pLabelName = "ggml_backend_vk_graph_compute"; + dul.color = std::array{1.0f, 1.0f, 1.0f, 1.0f}; + vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast(&dul)); + } + uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false); @@ -10369,6 +10409,22 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve UNUSED(instance_extensions); } +// Extension availability +static bool ggml_vk_instance_debug_utils_ext_available( + const std::vector & instance_extensions) { + // Check for portability enumeration extension for MoltenVK support + for (const auto & properties : instance_extensions) { + if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) { + return true; + } + } + + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." 
<< std::endl; + return false; + + UNUSED(instance_extensions); +} + static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { switch (props.vendorID) { case VK_VENDOR_ID_INTEL: diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 6c4d3a422..3b08f6134 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -197,13 +197,23 @@ class SpecialVocab: if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'): if not tokenizer_config: special_eos = special_last + elif special_last != special_eos: + if 'eot' not in self.special_token_types: + self.special_token_types = tuple(self.special_token_types) + ('eot', ) + tokenizer_config['eot_token'] = special_eos + elif 'eom' not in self.special_token_types: + self.special_token_types = tuple(self.special_token_types) + ('eom', ) + tokenizer_config['eom_token'] = special_eos + else: + logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!') + tokenizer_config['eos_token'] = special_eos = special_last self.add_special_token['eos'] = True if special_last == special_eos else False if special_last != special_eos: logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing') if tmpl_pair: - seq_start = 1 if tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0 - seq_stop = -1 if tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None - if seq_start == 0 or seq_stop is None: + seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0 + seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None + if (special_first and seq_start == 0) or (special_last and seq_stop is None): logger.warning('TemplateProcessing leading/trailing special tokens do not match TemplateProcessing') if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]: tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id') diff --git a/src/llama-context.cpp b/src/llama-context.cpp index abbd6ba46..463abb77e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -280,8 +280,8 @@ llama_context::llama_context( // simulate full KV cache - const auto mstate = memory->init_full(); - if (!mstate) { + const auto mctx = memory->init_full(); + if (!mctx) { throw std::runtime_error("failed to initialize KV cache"); } @@ -289,7 +289,7 @@ llama_context::llama_context( // reserve pp graph first so that buffers are only allocated once { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -300,7 +300,7 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - auto * gf = graph_reserve(1, 1, 1, mstate.get()); + auto * gf = graph_reserve(1, 1, 1, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -311,7 +311,7 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool 
optimize) { optimize |= memory_force_optimize; memory_force_optimize = false; - const auto mstate = memory->init_update(this, optimize); - switch (mstate->get_status()) { + const auto mctx = memory->init_update(this, optimize); + switch (mctx->get_status()) { case LLAMA_MEMORY_STATUS_SUCCESS: { // noop @@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) { } } - if (!mstate->apply()) { + if (!mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); } } // if the memory module did any computation, we have to reserve a new worst-case graph { - const auto mstate = memory->init_full(); - if (!mstate) { - throw std::runtime_error("failed to initialize memory state"); + const auto mctx = memory->init_full(); + if (!mctx) { + throw std::runtime_error("failed to initialize memory context"); } const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); } @@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) { - if (mstate && !mstate->apply()) { - LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__); +llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { + if (mctx && !mctx->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; return nullptr; } @@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, return nullptr; } - auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate); + auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx); if (!res) { LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); ret = GGML_STATUS_FAILED; @@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) { // handle any pending defrags/shifts kv_self_update(false); - llama_memory_state_ptr mstate; + llama_memory_context_ptr mctx; while (true) { - mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all); - if (!mstate) { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + if (!mctx) { return -2; } - switch (mstate->get_status()) { + switch (mctx->get_status()) { case LLAMA_MEMORY_STATUS_SUCCESS: { } break; case LLAMA_MEMORY_STATUS_NO_UPDATE: { - LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status()); + LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status()); return -2; } @@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) { int64_t n_outputs_prev = 0; do { - const auto & ubatch = mstate->get_ubatch(); + const auto & ubatch = mctx->get_ubatch(); // count the outputs in this ubatch { @@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) { ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_status status; - const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status); + const auto res = process_ubatch(ubatch, 
LLM_GRAPH_TYPE_DECODER, mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache @@ -1126,7 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } n_outputs_prev += n_outputs; - } while (mstate->next()); + } while (mctx->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith n_outputs = n_outputs_all; @@ -1292,8 +1292,8 @@ ggml_cgraph * llama_context::graph_init() { return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); } -ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) { - //LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); +ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) { + // LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); if (n_tokens % n_seqs != 0) { n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs @@ -1312,7 +1312,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx); this->n_outputs = save_n_outputs; @@ -1333,11 +1333,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u } llm_graph_result_ptr llama_context::graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype, - const llama_memory_state_i * mstate) { + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_context_i * mctx) { return model.build_graph( { /*.ctx =*/ ctx, @@ -1349,7 +1349,7 @@ llm_graph_result_ptr llama_context::graph_build( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.mstate =*/ mstate, + /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -2042,8 +2042,8 @@ void llama_context::opt_epoch_iter( uint32_t n_outputs_all = n_tokens_all; - auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true); - if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { + auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true); + if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); break; } @@ -2056,17 +2056,17 @@ void llama_context::opt_epoch_iter( uint32_t pos_batch = 0; do { - const auto & ubatch = mstate->get_ubatch(); + const auto & ubatch = mctx->get_ubatch(); n_outputs = ubatch.n_tokens; - if (!mstate->apply()) { - LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); + if (!mctx->apply()) { + LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__); break; } auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get()); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get()); struct ggml_context * ctx_compute_opt; { @@ -2101,7 +2101,7 @@ 
void llama_context::opt_epoch_iter( ggml_free(ctx_compute_opt); pos_batch += ubatch.n_tokens; - } while (mstate->next()); + } while (mctx->next()); } } diff --git a/src/llama-context.h b/src/llama-context.h index 7d300c145..9ce05715a 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -18,7 +18,7 @@ class llama_io_read_i; class llama_io_write_i; struct llama_memory_i; -struct llama_memory_state_i; +struct llama_memory_context_i; struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs @@ -93,14 +93,14 @@ struct llama_context { int32_t il_end); // process a single ubatch with a specific graph type - // if memory_state is provided, it will be applied first to the context's memory + // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation // returns nullptr only if ret != GGML_STATUS_SUCCESS llm_graph_result_ptr process_ubatch( - const llama_ubatch & ubatch, - llm_graph_type gtype, - llama_memory_state_i * mstate, - ggml_status & ret); + const llama_ubatch & ubatch, + llm_graph_type gtype, + llama_memory_context_i * mctx, + ggml_status & ret); int encode(const llama_batch & batch_inp); int decode(const llama_batch & batch_inp); @@ -197,15 +197,15 @@ public: ggml_status graph_compute(ggml_cgraph * gf, bool batched); // reserve a graph with a dummy ubatch of the specified size - ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate); + ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); private: llm_graph_result_ptr graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype, - const llama_memory_state_i * mstate); + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_context_i * mctx); llm_graph_cb graph_get_cb() const; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7e162c555..48589a50a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -87,7 +87,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { if (pos_bucket) { - kv_state->set_input_pos_bucket(pos_bucket, ubatch); + mctx->set_input_pos_bucket(pos_bucket, ubatch); } } @@ -221,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); - const int64_t n_rs = mem_state->get_n_rs(); + const int64_t n_rs = mctx->get_n_rs(); if (s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); @@ -229,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_rs; ++i) { - data[i] = mem_state->s_copy(i); + data[i] = mctx->s_copy(i); } } } @@ -282,17 +282,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } } void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - 
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } if (self_kq_mask_swa) { - kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } } @@ -334,10 +334,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { - mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } - const int64_t n_rs = mem_state->get_state_recr()->get_n_rs(); + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); @@ -345,7 +345,7 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_rs; ++i) { - data[i] = mem_state->get_state_recr()->s_copy(i); + data[i] = mctx->get_recr()->s_copy(i); } } } @@ -389,7 +389,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : backend_cpu (params.backend_cpu), cvec (params.cvec), loras (params.loras), - mstate (params.mstate), + mctx (params.mctx), cross (params.cross), cb_func (params.cb), res (std::make_unique()) { @@ -950,11 +950,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { } ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(hparams, kv_state); + auto inp = std::make_unique(hparams, mctx_cur); - const auto n_kv = kv_state->get_n_kv(); + const auto n_kv = mctx_cur->get_n_kv(); auto & cur = inp->pos_bucket; @@ -982,14 +982,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t } llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { - const auto * mem_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(hparams, cparams, mem_state); + auto inp = std::make_unique(hparams, cparams, mctx_cur); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); - const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv(); + const auto n_kv = inp->mctx->get_attn()->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -999,7 +999,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { } { - const auto n_rs = mem_state->get_state_recr()->get_n_rs(); + const auto n_rs = mctx_cur->get_recr()->get_n_rs(); inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); ggml_set_input(inp->s_copy); @@ -1183,14 +1183,14 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(hparams, cparams, kv_state); + auto inp = std::make_unique(hparams, cparams, mctx_cur); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for 
SWA"); - const auto n_kv = kv_state->get_n_kv(); + const auto n_kv = mctx_cur->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1220,19 +1220,19 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); // store to KV cache { - ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv_state->get_k(ctx0, il); - ggml_tensor * v = kv_state->get_v(ctx0, il); + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = mctx_cur->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1270,23 +1270,23 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto * kv_state_iswa = static_cast(mstate); + const auto * mctx_iswa = static_cast(mctx); const bool is_swa = hparams.is_swa(il); - const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base(); + const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base(); // store to KV cache { - ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = is_swa ? 
inp->get_kq_mask_swa() : inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv_state->get_k(ctx0, il); - ggml_tensor * v = kv_state->get_v(ctx0, il); + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = mctx_cur->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1379,19 +1379,19 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto * kv_state = static_cast(mstate)->get_state_attn(); + const auto * mctx_cur = static_cast(mctx)->get_attn(); // store to KV cache { - ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); } const auto & kq_mask = inp->get_kq_mask(); ggml_tensor * q = q_cur; - ggml_tensor * k = kv_state->get_k(ctx0, il); - ggml_tensor * v = kv_state->get_v(ctx0, il); + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = mctx_cur->get_v(ctx0, il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); @@ -1412,12 +1412,12 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(hparams, cparams, kv_state); + auto inp = std::make_unique(hparams, cparams, mctx_cur); { - const auto n_kv = kv_state->get_base()->get_n_kv(); + const auto n_kv = mctx_cur->get_base()->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1429,7 +1429,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); - const auto n_kv = kv_state->get_swa()->get_n_kv(); + const auto n_kv = mctx_cur->get_swa()->get_n_kv(); inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); @@ -1485,11 +1485,11 @@ ggml_tensor * llm_graph_context::build_rs( } llm_graph_input_rs * llm_graph_context::build_rs_inp() const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(kv_state); + auto inp = std::make_unique(mctx_cur); - const auto n_rs = kv_state->get_n_rs(); + const auto n_rs = mctx_cur->get_n_rs(); inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); ggml_set_input(inp->s_copy); @@ -1504,9 +1504,9 @@ ggml_tensor * llm_graph_context::build_rs( int32_t state_size, int32_t n_seqs, bool avoid_copies) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies); + return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies); } ggml_tensor * llm_graph_context::build_rs( @@ -1516,9 +1516,9 @@ ggml_tensor * llm_graph_context::build_rs( int32_t state_size, int32_t n_seqs, 
bool avoid_copies) const { - const auto * kv_state = static_cast(mstate)->get_state_recr(); + const auto * mctx_cur = static_cast(mctx)->get_recr(); - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies); + return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies); } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( @@ -1526,13 +1526,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_cgraph * gf, const llama_ubatch & ubatch, int il) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); const auto token_shift_count = hparams.token_shift_count; const int64_t n_seqs = ubatch.n_seqs; - ggml_tensor * token_shift_all = kv_state->get_r_l(il); + ggml_tensor * token_shift_all = mctx_cur->get_r_l(il); ggml_tensor * token_shift = build_rs( inp, gf, token_shift_all, @@ -1547,19 +1547,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); const auto token_shift_count = hparams.token_shift_count; const auto n_embd = hparams.n_embd; const int64_t n_seqs = ubatch.n_seqs; - const auto kv_head = kv_state->get_head(); + const auto kv_head = mctx_cur->get_head(); return ggml_cpy( ctx0, ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), - ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il))) + ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il))) ); } diff --git a/src/llama-graph.h b/src/llama-graph.h index 9e62fa607..b433f266d 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -17,12 +17,12 @@ struct ggml_tensor; struct llama_ubatch; struct llama_cparams; -struct llama_memory_state_i; +struct llama_memory_context_i; -class llama_kv_cache_unified_state; -class llama_kv_cache_unified_iswa_state; -class llama_memory_recurrent_state; -class llama_memory_hybrid_state; +class llama_kv_cache_unified_context; +class llama_kv_cache_unified_iswa_context; +class llama_memory_recurrent_context; +class llama_memory_hybrid_context; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -136,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { public: llm_graph_input_pos_bucket_kv( const llama_hparams & hparams, - const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {} + const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {} virtual ~llm_graph_input_pos_bucket_kv() = default; void set_input(const llama_ubatch * ubatch) override; @@ -144,7 +144,8 @@ public: ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] const llama_hparams & hparams; - const llama_kv_cache_unified_state * kv_state; + + const llama_kv_cache_unified_context * mctx; }; class llm_graph_input_out_ids : public llm_graph_input_i { @@ -191,14 +192,14 @@ public: class llm_graph_input_rs : public llm_graph_input_i { public: - llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {} + llm_graph_input_rs(const llama_memory_recurrent_context * 
mctx) : mctx(mctx) {} virtual ~llm_graph_input_rs() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_copy; // I32 [kv_size] - const llama_memory_recurrent_state * mem_state; + const llama_memory_recurrent_context * mctx; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -238,10 +239,10 @@ public: llm_graph_input_attn_kv_unified( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_state * kv_state) : + const llama_kv_cache_unified_context * mctx) : hparams(hparams), cparams(cparams), - kv_state(kv_state) { + mctx(mctx) { } ~llm_graph_input_attn_kv_unified() = default; @@ -255,7 +256,7 @@ public: const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified_state * kv_state; + const llama_kv_cache_unified_context * mctx; }; class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { @@ -263,10 +264,10 @@ public: llm_graph_input_attn_kv_unified_iswa( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_iswa_state * kv_state) : + const llama_kv_cache_unified_iswa_context * mctx) : hparams(hparams), cparams(cparams), - kv_state(kv_state) { + mctx(mctx) { } ~llm_graph_input_attn_kv_unified_iswa() = default; @@ -283,7 +284,7 @@ public: const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified_iswa_state * kv_state; + const llama_kv_cache_unified_iswa_context * mctx; }; class llm_graph_input_attn_cross : public llm_graph_input_i { @@ -306,10 +307,10 @@ public: llm_graph_input_mem_hybrid( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_memory_hybrid_state * mem_state) : + const llama_memory_hybrid_context * mctx) : hparams(hparams), cparams(cparams), - mem_state(mem_state) { + mctx(mctx) { } virtual ~llm_graph_input_mem_hybrid() = default; @@ -325,7 +326,7 @@ public: const llama_hparams & hparams; const llama_cparams & cparams; - const llama_memory_hybrid_state * mem_state; + const llama_memory_hybrid_context * mctx; }; // @@ -401,10 +402,10 @@ struct llm_graph_params { ggml_backend_sched_t sched; ggml_backend_t backend_cpu; - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_state_i * mstate; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_context_i * mctx; + const llama_cross * cross; uint32_t n_outputs; @@ -453,10 +454,10 @@ struct llm_graph_context { ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? 
- const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_state_i * mstate; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_context_i * mctx; + const llama_cross * cross; const llm_graph_cb & cb_func; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index 0ced340de..b9169299c 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -95,7 +95,7 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { return kv_swa->seq_pos_max(seq_id); } -llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { +llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { GGML_UNUSED(embd_all); // first try simple split @@ -125,7 +125,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc assert(heads_base.size() == heads_swa.size()); - return std::make_unique( + return std::make_unique( this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); } while (false); @@ -156,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc assert(heads_base.size() == heads_swa.size()); - return std::make_unique( + return std::make_unique( this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); } while (false); // TODO: if we fail again, we should attempt different splitting strategies // but to do that properly, we first have to refactor the batches to be more flexible - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } -llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() { + return std::make_unique(this); } -llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { - return std::make_unique(this, lctx, optimize); +llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); } bool llama_kv_cache_unified_iswa::get_can_shift() const { @@ -197,46 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { } // -// llama_kv_cache_unified_iswa_state +// llama_kv_cache_unified_iswa_context // -llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} +llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {} -llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( +llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv) : - state_base(kv->get_base()->init_full()), - state_swa (kv->get_swa ()->init_full()), - status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) { + ctx_base(kv->get_base()->init_full()), + ctx_swa (kv->get_swa ()->init_full()), + status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( +llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, llama_context * 
lctx, bool optimize) : - state_base(kv->get_base()->init_update(lctx, optimize)), - state_swa (kv->get_swa ()->init_update(lctx, optimize)), - status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) { + ctx_base(kv->get_base()->init_update(lctx, optimize)), + ctx_swa (kv->get_swa ()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( +llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, std::vector heads_base, std::vector heads_swa, std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)), - state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa), this->ubatches)), - status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) { + ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)), + ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)), + status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; +llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default; -bool llama_kv_cache_unified_iswa_state::next() { +bool llama_kv_cache_unified_iswa_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - state_base->next(); - state_swa ->next(); + ctx_base->next(); + ctx_swa ->next(); if (++i_next >= ubatches.size()) { return false; @@ -245,35 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() { return true; } -bool llama_kv_cache_unified_iswa_state::apply() { +bool llama_kv_cache_unified_iswa_context::apply() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); bool res = true; - res = res & state_base->apply(); - res = res & state_swa ->apply(); + res = res & ctx_base->apply(); + res = res & ctx_swa ->apply(); return res; } -llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { +llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const { return status; } -const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { +const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_next]; } -const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { +const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return static_cast(state_base.get()); + return static_cast(ctx_base.get()); } -const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { +const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return static_cast(state_swa.get()); + return static_cast(ctx_swa.get()); } diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index 071041585..46c1ed614 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -31,14 +31,14 @@ public: // llama_memory_i // - llama_memory_state_ptr init_batch( + 
llama_memory_context_ptr init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; - llama_memory_state_ptr init_full() override; + llama_memory_context_ptr init_full() override; - llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; bool get_can_shift() const override; @@ -72,32 +72,32 @@ private: std::unique_ptr kv_swa; }; -class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { +class llama_kv_cache_unified_iswa_context : public llama_memory_context_i { public: // used for errors - llama_kv_cache_unified_iswa_state(llama_memory_status status); + llama_kv_cache_unified_iswa_context(llama_memory_status status); - // used to create a full-cache state - llama_kv_cache_unified_iswa_state( + // used to create a full-cache context + llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv); - // used to create an update state - llama_kv_cache_unified_iswa_state( + // used to create an update context + llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, llama_context * lctx, bool optimize); - // used to create a state from a batch - llama_kv_cache_unified_iswa_state( + // used to create a batch processing context from a batch + llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, std::vector heads_base, std::vector heads_swa, std::vector ubatches); - virtual ~llama_kv_cache_unified_iswa_state(); + virtual ~llama_kv_cache_unified_iswa_context(); // - // llama_memory_state_i + // llama_memory_context_i // bool next() override; @@ -107,11 +107,11 @@ public: const llama_ubatch & get_ubatch() const override; // - // llama_kv_cache_unified_iswa_state specific API + // llama_kv_cache_unified_iswa_context specific API // - const llama_kv_cache_unified_state * get_base() const; - const llama_kv_cache_unified_state * get_swa() const; + const llama_kv_cache_unified_context * get_base() const; + const llama_kv_cache_unified_context * get_swa() const; private: //llama_kv_cache_unified_iswa * kv; @@ -121,8 +121,8 @@ private: std::vector ubatches; - const llama_memory_state_ptr state_base; - const llama_memory_state_ptr state_swa; + const llama_memory_context_ptr ctx_base; + const llama_memory_context_ptr ctx_swa; const llama_memory_status status; }; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index f4459cbc1..a75f40ac8 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -307,7 +307,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return cells.seq_pos_max(seq_id); } -llama_memory_state_ptr llama_kv_cache_unified::init_batch( +llama_memory_context_ptr llama_kv_cache_unified::init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { @@ -332,18 +332,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch( break; } - return std::make_unique( + return std::make_unique( this, std::move(heads), std::move(ubatches)); } while (false); - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } -llama_memory_state_ptr llama_kv_cache_unified::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_kv_cache_unified::init_full() { + return std::make_unique(this); } -llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) { +llama_memory_context_ptr 
llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) { bool do_shift = get_has_shift(); defrag_info dinfo; @@ -373,7 +373,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, } } - return std::make_unique(this, lctx, do_shift, std::move(dinfo)); + return std::make_unique(this, lctx, do_shift, std::move(dinfo)); } llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector & ubatches) { @@ -1710,18 +1710,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell } // -// llama_kv_cache_unified_state +// llama_kv_cache_unified_context // -llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} +llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {} -llama_kv_cache_unified_state::llama_kv_cache_unified_state( +llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { n_kv = kv->get_size(); head = 0; } -llama_kv_cache_unified_state::llama_kv_cache_unified_state( +llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, @@ -1731,15 +1731,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state( } } -llama_kv_cache_unified_state::llama_kv_cache_unified_state( +llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_kv_cache_unified::ubatch_heads heads, std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) { } -llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; +llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; -bool llama_kv_cache_unified_state::next() { +bool llama_kv_cache_unified_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); if (++i_next >= ubatches.size()) { @@ -1749,7 +1749,7 @@ bool llama_kv_cache_unified_state::next() { return true; } -bool llama_kv_cache_unified_state::apply() { +bool llama_kv_cache_unified_context::apply() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); // no ubatches -> this is a KV cache update @@ -1767,45 +1767,45 @@ bool llama_kv_cache_unified_state::apply() { return true; } -llama_memory_status llama_kv_cache_unified_state::get_status() const { +llama_memory_status llama_kv_cache_unified_context::get_status() const { return status; } -const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const { +const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_next]; } -uint32_t llama_kv_cache_unified_state::get_n_kv() const { +uint32_t llama_kv_cache_unified_context::get_n_kv() const { return n_kv; } -ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const { +ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv); } -ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const { +ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const { return kv->get_v(ctx, il, n_kv); } -ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { +ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, 
int32_t il) const { return kv->cpy_k(ctx, k_cur, il, head); } -ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { +ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { return kv->cpy_v(ctx, v_cur, il, head); } -void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const { +void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } -void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { +void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { kv->set_input_kq_mask(dst, ubatch, causal_attn); } -void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { +void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 156064004..4c53f1273 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -56,14 +56,14 @@ public: // llama_memory_i // - llama_memory_state_ptr init_batch( + llama_memory_context_ptr init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; - llama_memory_state_ptr init_full() override; + llama_memory_context_ptr init_full() override; - llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; bool get_can_shift() const override; @@ -208,36 +208,36 @@ private: bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; -class llama_kv_cache_unified_state : public llama_memory_state_i { +class llama_kv_cache_unified_context : public llama_memory_context_i { public: // some shorthands using ubatch_heads = llama_kv_cache_unified::ubatch_heads; using defrag_info = llama_kv_cache_unified::defrag_info; // used for errors - llama_kv_cache_unified_state(llama_memory_status status); + llama_kv_cache_unified_context(llama_memory_status status); - // used to create a full-cache state - llama_kv_cache_unified_state( + // used to create a full-cache context + llama_kv_cache_unified_context( llama_kv_cache_unified * kv); - // used to create an update state - llama_kv_cache_unified_state( + // used to create an update context + llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, defrag_info dinfo); - // used to create a decode state from a batch - llama_kv_cache_unified_state( + // used to create a batch processing context from a batch + llama_kv_cache_unified_context( llama_kv_cache_unified * kv, ubatch_heads heads, std::vector ubatches); - virtual ~llama_kv_cache_unified_state(); + virtual ~llama_kv_cache_unified_context(); // - // llama_memory_state_i + // llama_memory_context_i // bool next() override; @@ -247,7 +247,7 @@ public: const llama_ubatch & get_ubatch() const override; // - // llama_kv_cache_unified_state specific API + // llama_kv_cache_unified_context specific API // uint32_t get_n_kv() const; @@ -272,7 +272,7 @@ private: llama_context * lctx; // - // update state + // update context // bool do_shift = false; @@ -280,7 +280,7 @@ private: defrag_info dinfo; // - // batch processing state + // batch processing context
// // the index of the next ubatch to process diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 1b1668681..15cde98d1 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -56,7 +56,7 @@ llama_memory_hybrid::llama_memory_hybrid( n_seq_max )) {} -llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { +llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { do { balloc.split_reset(); @@ -82,31 +82,31 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ball // prepare the recurrent batches first if (!mem_recr->prepare(ubatches)) { - // TODO: will the recurrent cache be in an undefined state at this point? + // TODO: will the recurrent cache be in an undefined context at this point? LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } // prepare the attention cache auto heads_attn = mem_attn->prepare(ubatches); if (heads_attn.empty()) { LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__); - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - return std::make_unique( + return std::make_unique( this, std::move(heads_attn), std::move(ubatches)); } while(false); - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } -llama_memory_state_ptr llama_memory_hybrid::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_memory_hybrid::init_full() { + return std::make_unique(this); } -llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) { - return std::make_unique(this, lctx, optimize); +llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); } bool llama_memory_hybrid::get_can_shift() const { @@ -176,39 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const { return mem_recr.get(); } -llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {} +llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {} -llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) : - state_attn(mem->get_mem_attn()->init_full()), - state_recr(mem->get_mem_recr()->init_full()), - status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) { +llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) : + ctx_attn(mem->get_mem_attn()->init_full()), + ctx_recr(mem->get_mem_recr()->init_full()), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } -llama_memory_hybrid_state::llama_memory_hybrid_state( +llama_memory_hybrid_context::llama_memory_hybrid_context( llama_memory_hybrid * mem, llama_context * lctx, bool optimize) : - state_attn(mem->get_mem_attn()->init_update(lctx, optimize)), - state_recr(mem->get_mem_recr()->init_update(lctx, optimize)), - status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) { + ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)), + 
ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } -llama_memory_hybrid_state::llama_memory_hybrid_state( +llama_memory_hybrid_context::llama_memory_hybrid_context( llama_memory_hybrid * mem, std::vector heads_attn, std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)), - state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), this->ubatches)), - status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) { + ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } -bool llama_memory_hybrid_state::next() { +bool llama_memory_hybrid_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - state_attn->next(); - state_recr->next(); + ctx_attn->next(); + ctx_recr->next(); if (++i_next >= ubatches.size()) { return false; @@ -217,30 +217,30 @@ bool llama_memory_hybrid_state::next() { return true; } -bool llama_memory_hybrid_state::apply() { +bool llama_memory_hybrid_context::apply() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); bool res = true; - res = res & state_attn->apply(); - res = res & state_recr->apply(); + res = res & ctx_attn->apply(); + res = res & ctx_recr->apply(); return res; } -llama_memory_status llama_memory_hybrid_state::get_status() const { +llama_memory_status llama_memory_hybrid_context::get_status() const { return status; } -const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const { +const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_next]; } -const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const { - return static_cast(state_attn.get()); +const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const { + return static_cast(ctx_attn.get()); } -const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const { - return static_cast(state_recr.get()); +const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const { + return static_cast(ctx_recr.get()); } diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h index 4d27ab896..f0c2420e9 100644 --- a/src/llama-memory-hybrid.h +++ b/src/llama-memory-hybrid.h @@ -49,14 +49,14 @@ public: // llama_memory_i // - llama_memory_state_ptr init_batch( + llama_memory_context_ptr init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; - llama_memory_state_ptr init_full() override; + llama_memory_context_ptr init_full() override; - llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; bool get_can_shift() const override; @@ -90,27 +90,27 @@ private: const std::unique_ptr mem_recr; }; -class llama_memory_hybrid_state : public llama_memory_state_i { +class llama_memory_hybrid_context : public llama_memory_context_i { public: // init failure - explicit llama_memory_hybrid_state(llama_memory_status status); + explicit llama_memory_hybrid_context(llama_memory_status 
status); // init full - explicit llama_memory_hybrid_state(llama_memory_hybrid * mem); + explicit llama_memory_hybrid_context(llama_memory_hybrid * mem); // init update - explicit llama_memory_hybrid_state( + explicit llama_memory_hybrid_context( llama_memory_hybrid * mem, llama_context * lctx, bool optimize); // init success - llama_memory_hybrid_state( + llama_memory_hybrid_context( llama_memory_hybrid * mem, std::vector heads_attn, std::vector ubatches); - ~llama_memory_hybrid_state() = default; + ~llama_memory_hybrid_context() = default; bool next() override; bool apply() override; @@ -119,11 +119,11 @@ public: const llama_ubatch & get_ubatch() const override; // - // llama_memory_hybrid_state + // llama_memory_hybrid_context // - const llama_kv_cache_unified_state * get_state_attn() const; - const llama_memory_recurrent_state * get_state_recr() const; + const llama_kv_cache_unified_context * get_attn() const; + const llama_memory_recurrent_context * get_recr() const; private: // the index of the next ubatch to process @@ -131,8 +131,8 @@ private: std::vector ubatches; - const llama_memory_state_ptr state_attn; - const llama_memory_state_ptr state_recr; + const llama_memory_context_ptr ctx_attn; + const llama_memory_context_ptr ctx_recr; const llama_memory_status status; }; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 8055bd2a7..74a920204 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -362,7 +362,7 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } -llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { +llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { std::vector ubatches; while (true) { @@ -383,21 +383,21 @@ llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & b } if (!prepare(ubatches)) { - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - return std::make_unique(this, std::move(ubatches)); + return std::make_unique(this, std::move(ubatches)); } -llama_memory_state_ptr llama_memory_recurrent::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_memory_recurrent::init_full() { + return std::make_unique(this); } -llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) { +llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) { GGML_UNUSED(lctx); GGML_UNUSED(optimize); - return std::make_unique(LLAMA_MEMORY_STATUS_NO_UPDATE); + return std::make_unique(LLAMA_MEMORY_STATUS_NO_UPDATE); } bool llama_memory_recurrent::prepare(const std::vector & ubatches) { @@ -1040,22 +1040,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell } // -// llama_memory_recurrent_state +// llama_memory_recurrent_context // -llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {} +llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {} -llama_memory_recurrent_state::llama_memory_recurrent_state( +llama_memory_recurrent_context::llama_memory_recurrent_context( llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) { } 
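Illustrative aside (not part of the patch): the iSWA and hybrid contexts above derive their own status from their two child contexts via llama_memory_status_combine. The helper's body lives in src/llama-memory.cpp and is not shown in this diff; the sketch below only illustrates the semantics implied by the call sites (a failure in either child wins, otherwise success if at least one child has work to do) and should not be read as the actual implementation.

    // sketch only - assumed semantics; the real definition is in src/llama-memory.cpp (not in this diff)
    llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
        // a failure in either child makes the combined context unusable
        if (s0 == LLAMA_MEMORY_STATUS_FAILED_PREPARE || s1 == LLAMA_MEMORY_STATUS_FAILED_PREPARE) {
            return LLAMA_MEMORY_STATUS_FAILED_PREPARE;
        }
        if (s0 == LLAMA_MEMORY_STATUS_FAILED_COMPUTE || s1 == LLAMA_MEMORY_STATUS_FAILED_COMPUTE) {
            return LLAMA_MEMORY_STATUS_FAILED_COMPUTE;
        }
        // success if at least one child has pending work, otherwise there is nothing to update
        if (s0 == LLAMA_MEMORY_STATUS_SUCCESS || s1 == LLAMA_MEMORY_STATUS_SUCCESS) {
            return LLAMA_MEMORY_STATUS_SUCCESS;
        }
        return LLAMA_MEMORY_STATUS_NO_UPDATE;
    }
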
-llama_memory_recurrent_state::llama_memory_recurrent_state( +llama_memory_recurrent_context::llama_memory_recurrent_context( llama_memory_recurrent * mem, std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {} -llama_memory_recurrent_state::~llama_memory_recurrent_state() = default; +llama_memory_recurrent_context::~llama_memory_recurrent_context() = default; -bool llama_memory_recurrent_state::next() { +bool llama_memory_recurrent_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); if (++i_next >= ubatches.size()) { @@ -1065,7 +1065,7 @@ bool llama_memory_recurrent_state::next() { return true; } -bool llama_memory_recurrent_state::apply() { +bool llama_memory_recurrent_context::apply() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); mem->find_slot(ubatches[i_next]); @@ -1073,40 +1073,40 @@ bool llama_memory_recurrent_state::apply() { return true; } -llama_memory_status llama_memory_recurrent_state::get_status() const { +llama_memory_status llama_memory_recurrent_context::get_status() const { return status; } -const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const { +const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_next]; } -uint32_t llama_memory_recurrent_state::get_n_rs() const { +uint32_t llama_memory_recurrent_context::get_n_rs() const { return is_full ? mem->size : mem->n; } -uint32_t llama_memory_recurrent_state::get_head() const { +uint32_t llama_memory_recurrent_context::get_head() const { return is_full ? 0 : mem->head; } -int32_t llama_memory_recurrent_state::get_rs_z() const { +int32_t llama_memory_recurrent_context::get_rs_z() const { return is_full ? 0 : mem->rs_z; } -uint32_t llama_memory_recurrent_state::get_size() const { +uint32_t llama_memory_recurrent_context::get_size() const { return mem->size; } -ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const { +ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const { return mem->r_l[il]; } -ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const { +ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { return mem->s_l[il]; } -int32_t llama_memory_recurrent_state::s_copy(int i) const { +int32_t llama_memory_recurrent_context::s_copy(int i) const { return mem->cells[i + mem->head].src0; } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index be58dae7c..4d094f9a0 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -11,8 +11,8 @@ // llama_memory_recurrent // -// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i -// see the implementation of llama_kv_cache_unified_state_i for an example how to do it +// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i +// see the implementation of llama_kv_cache_unified_context_i for an example how to do it class llama_memory_recurrent : public llama_memory_i { public: @@ -34,14 +34,14 @@ public: // llama_memory_i // - llama_memory_state_ptr init_batch( + llama_memory_context_ptr init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; - llama_memory_state_ptr init_full() override; + llama_memory_context_ptr init_full() override; - llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; void 
clear(bool data) override; @@ -125,24 +125,24 @@ private: bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; -class llama_memory_recurrent_state : public llama_memory_state_i { +class llama_memory_recurrent_context : public llama_memory_context_i { public: // used for errors - llama_memory_recurrent_state(llama_memory_status status); + llama_memory_recurrent_context(llama_memory_status status); - // used to create a full-cache state - llama_memory_recurrent_state( + // used to create a full-cache or update context + llama_memory_recurrent_context( llama_memory_recurrent * mem); - // used to create a state from a batch - llama_memory_recurrent_state( + // used to create a batch processing context from a batch + llama_memory_recurrent_context( llama_memory_recurrent * mem, std::vector ubatches); - virtual ~llama_memory_recurrent_state(); + virtual ~llama_memory_recurrent_context(); // - // llama_memory_state_i + // llama_memory_context_i // bool next() override; @@ -152,7 +152,7 @@ public: const llama_ubatch & get_ubatch() const override; // - // llama_memory_recurrent_state specific API + // llama_memory_recurrent_context specific API // uint32_t get_n_rs() const; diff --git a/src/llama-memory.h b/src/llama-memory.h index d2ef0c2a3..16b7e5ee2 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -3,7 +3,6 @@ #include "llama.h" #include -#include struct llama_ubatch; @@ -28,23 +27,21 @@ enum llama_memory_status { LLAMA_MEMORY_STATUS_FAILED_COMPUTE, }; -// helper function for combining the status of two memory states +// helper function for combining the status of two memory contexts // useful for implementing hybrid memory types (e.g. iSWA) llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); -// the interface for managing the memory state during batch processing +// the interface for managing the memory context during batch processing // this interface is implemented per memory type. see: -// - llama_kv_cache_unified_state -// - llama_kv_cache_unified_iswa_state +// - llama_kv_cache_unified_context +// - llama_kv_cache_unified_iswa_context // ... // -// the only method that can mutate the memory and the memory state is llama_memory_i::apply() -// -// TODO: rename to llama_memory_context_i ? 
-struct llama_memory_state_i { - virtual ~llama_memory_state_i() = default; +// the only method that should mutate the memory and the memory context is llama_memory_i::apply() +struct llama_memory_context_i { + virtual ~llama_memory_context_i() = default; - // consume the current ubatch from the state and proceed to the next one + // consume the current ubatch from the context and proceed to the next one // return false if we are done virtual bool next() = 0; @@ -55,11 +52,11 @@ struct llama_memory_state_i { // get the current ubatch virtual const llama_ubatch & get_ubatch() const = 0; - // get the status of the memory state - used for error handling and checking if any updates would be applied + // get the status of the memory context - used for error handling and checking if any updates would be applied virtual llama_memory_status get_status() const = 0; }; -using llama_memory_state_ptr = std::unique_ptr; +using llama_memory_context_ptr = std::unique_ptr; // general concept of LLM memory // the KV cache is a type of LLM memory, but there can be other types @@ -67,19 +64,19 @@ struct llama_memory_i { virtual ~llama_memory_i() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache - // return a state object containing the ubatches and KV cache state required to process them - // check the llama_memory_state_i::get_status() for the result - virtual llama_memory_state_ptr init_batch( + // return a context object containing the ubatches and memory state required to process them + // check the llama_memory_context_i::get_status() for the result + virtual llama_memory_context_ptr init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) = 0; // simulate full cache, used for allocating worst-case compute buffers - virtual llama_memory_state_ptr init_full() = 0; + virtual llama_memory_context_ptr init_full() = 0; // prepare for any pending memory updates, such as shifts, defrags, etc. 
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update - virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0; + virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0; // getters virtual bool get_can_shift() const = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index fb1a06c1d..f4cd9fedc 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9271,9 +9271,9 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * cur, const llama_ubatch & ubatch, int il) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); - const auto kv_head = kv_state->get_head(); + const auto kv_head = mctx_cur->get_head(); const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -9291,8 +9291,8 @@ struct llm_build_mamba : public llm_graph_context { GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - ggml_tensor * conv_states_all = kv_state->get_r_l(il); - ggml_tensor * ssm_states_all = kv_state->get_s_l(il); + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); // (ab)using the KV cache to store the states ggml_tensor * conv = build_rs( @@ -12016,7 +12016,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * x_prev, const llama_ubatch & ubatch, int il) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -12026,7 +12026,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { const auto n_head = n_embd / head_size; const auto n_head_kv = hparams.n_head_kv(il); - const auto kv_head = kv_state->get_head(); + const auto kv_head = mctx_cur->get_head(); const auto & layer = model.layers[il]; @@ -12138,7 +12138,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { } ggml_tensor * wkv_state = build_rs( - inp, gf, kv_state->get_s_l(il), + inp, gf, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs); ggml_tensor * wkv_output; @@ -12157,9 +12157,9 @@ struct llm_build_rwkv6_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_state->get_s_l(il), + mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs, - hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il)) + hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il)) ) ) ); @@ -12413,7 +12413,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { - const auto * kv_state = static_cast(mstate); + const auto * mctx_cur = static_cast(mctx); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -12422,7 +12422,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { const auto head_count = n_embd / head_size; const auto n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = kv_state->get_head(); + const auto kv_head = mctx_cur->get_head(); const auto & layer = model.layers[il]; @@ -12493,7 +12493,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); ggml_tensor * wkv_state = build_rs( - inp, gf, kv_state->get_s_l(il), + inp, gf, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -12507,9 +12507,9 @@ 
struct llm_build_rwkv7_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_state->get_s_l(il), + mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs, - hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il)) + hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il)) ) ) ); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index ca2575ae3..afe83e128 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2257,6 +2257,9 @@ struct clip_model_loader { { hparams.rope_theta = 10000.0f; hparams.warmup_image_size = hparams.patch_size * 8; + // Mistral Small 2506 needs 1024x1024 image size cap to prevent OOM + // ref: https://github.com/ggml-org/llama.cpp/issues/14310 + hparams.image_size = 1024; get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false); } break; case PROJECTOR_TYPE_GEMMA3:
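
Illustrative aside (not part of the patch): the rename does not change how the memory interface is consumed, only the names. The sketch below shows how a caller might drive a llama_memory_context_ptr returned by init_batch, using only the declarations from src/llama-memory.h shown above; process_batch is a hypothetical helper, error handling is simplified, and the real decode loop in the library is more involved.

    // sketch only: hypothetical driver for the renamed interface, simplified error handling
    #include "llama-memory.h"

    static bool process_batch(llama_memory_i & mem, llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
        llama_memory_context_ptr mctx = mem.init_batch(balloc, n_ubatch, embd_all);

        if (mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            return false; // e.g. LLAMA_MEMORY_STATUS_FAILED_PREPARE - the batch could not be fit into the cache
        }

        do {
            if (!mctx->apply()) { // commit the reserved cells for the current ubatch
                return false;
            }

            const llama_ubatch & ubatch = mctx->get_ubatch();
            // ... build and compute the graph for `ubatch` here ...
            (void) ubatch;
        } while (mctx->next()); // advance until all ubatches are consumed

        return true;
    }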