Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 01:24:36 +00:00)

Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	README.md
#	ci/run.sh
#	examples/embedding/embedding.cpp
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	src/CMakeLists.txt
Commit d33c88b1f4

18 changed files with 508 additions and 276 deletions
@@ -3709,8 +3709,7 @@ class BertModel(TextModel):
         self._try_set_pooling_type()
 
         if self.cls_out_labels:
-            key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
-            self.gguf_writer.add_array(key_name, [v for k, v in sorted(self.cls_out_labels.items())])
+            self.gguf_writer.add_classifier_output_labels([v for k, v in sorted(self.cls_out_labels.items())])
 
     def set_vocab(self):
         tokens, toktypes, tokpre = self.get_vocab_base()
@@ -212,6 +212,7 @@ enum vk_device_architecture {
     AMD_RDNA1,
     AMD_RDNA2,
     AMD_RDNA3,
+    INTEL_XE2,
 };
 
 static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
@@ -262,6 +263,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
             }
             return vk_device_architecture::AMD_RDNA2;
         }
+    } else if (props.vendorID == VK_VENDOR_ID_INTEL) {
+        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
+
+        bool subgroup_size_control = false;
+
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
+                subgroup_size_control = true;
+            }
+        }
+
+        if (!subgroup_size_control) {
+            return vk_device_architecture::OTHER;
+        }
+
+        vk::PhysicalDeviceProperties2 props2;
+        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+
+        props2.pNext = &subgroup_size_control_props;
+        device.getProperties2(&props2);
+
+        if (subgroup_size_control_props.minSubgroupSize == 16) {
+            // Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8.
+            // Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value.
+            // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
+            // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
+            return vk_device_architecture::INTEL_XE2;
+        }
     }
     return vk_device_architecture::OTHER;
 }
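The new Intel branch above classifies a GPU as Xe2 purely from the minimum subgroup size the driver reports (SIMD16 on Xe2 vs. SIMD8 on older Xe/Gen parts). A minimal standalone sketch of the same query, assuming a Vulkan 1.1+ device with VK_EXT_subgroup_size_control available; the function name is illustrative and not part of the patch:

#include <vulkan/vulkan.hpp>

// Returns true if an Intel device looks like Xe2, judged only by its minimum subgroup size.
static bool is_probably_intel_xe2(const vk::PhysicalDevice & device) {
    vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_props;
    vk::PhysicalDeviceProperties2 props2;
    props2.pNext = &subgroup_props;              // chain the EXT struct so the driver fills it in
    device.getProperties2(&props2);
    return subgroup_props.minSubgroupSize == 16; // Xe2 is SIMD16, earlier Xe/Gen are SIMD8
}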
@@ -4103,7 +4132,33 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
     return s;
 }
 
-static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
+template <typename T> size_t push_constant_size(const T &t) {
+    static_assert(std::is_class<T>::value, "T must be a struct/class");
+    GGML_UNUSED(t);
+    return sizeof(T);
+}
+template <typename T> size_t push_constant_size(const std::vector<T> &t) {
+    GGML_UNUSED(t);
+    return sizeof(T) * t.size();
+}
+template <typename T, uint32_t N> size_t push_constant_size(const std::array<T, N> &t) {
+    GGML_UNUSED(t);
+    return sizeof(T) * N;
+}
+
+template <typename T> const T *push_constant_data(const T &t) {
+    static_assert(std::is_class<T>::value, "T must be a struct/class");
+    return &t;
+}
+template <typename T> const T *push_constant_data(const std::vector<T> &t) {
+    return t.data();
+}
+template <typename T, uint32_t N> const T *push_constant_data(const std::array<T, N> &t) {
+    return t.data();
+}
+
+template <typename T>
+static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) {
     const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
     const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
     const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
@@ -4119,7 +4174,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
     vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
     ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
 
-    subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
+    subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
     subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
     subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
         pipeline->layout,
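The refactor above replaces the raw (size, void*) push-constant pair with a single typed argument; the overloaded push_constant_size/push_constant_data helpers recover the byte count and data pointer for a plain struct, a std::vector, or a std::array, so call sites can drop the explicit sizeof(...)/&pc arguments. A minimal self-contained sketch of the same pattern, using toy names rather than the actual ggml types:

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <type_traits>
#include <vector>

template <typename T> size_t pc_size(const T &) {
    static_assert(std::is_class<T>::value, "T must be a struct/class");
    return sizeof(T);
}
template <typename T> size_t pc_size(const std::vector<T> &v) { return sizeof(T) * v.size(); }
template <typename T, std::size_t N> size_t pc_size(const std::array<T, N> &) { return sizeof(T) * N; }

template <typename T> const void * pc_data(const T &t) { return &t; }
template <typename T> const void * pc_data(const std::vector<T> &v) { return v.data(); }
template <typename T, std::size_t N> const void * pc_data(const std::array<T, N> &a) { return a.data(); }

struct toy_push_constants { uint32_t m, n, k; };

// Stand-in for ggml_vk_dispatch_pipeline: one typed argument instead of (size, void*).
template <typename T> void dispatch(const T &pc) {
    std::printf("pushing %zu bytes\n", pc_size(pc));
    (void) pc_data(pc); // would be forwarded to vkCmdPushConstants
}

int main() {
    dispatch(toy_push_constants{1, 2, 3});        // struct overload -> sizeof(struct)
    dispatch(std::vector<uint32_t>{4, 5, 6, 7});  // vector overload -> element size * count
    dispatch(std::array<uint32_t, 2>{8, 9});      // array overload  -> element size * N
}

Overload resolution picks the most specialized helper, which is why the later hunks can pass structs, vectors, and arrays to the same dispatch function without any wrapper at the call site.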
@@ -4582,7 +4637,7 @@ static void ggml_vk_matmul(
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
         const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
         return;
     }
 
@@ -4590,10 +4645,10 @@ static void ggml_vk_matmul(
 
     const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
     // Make sure enough workgroups get assigned for split k to work
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
     ggml_vk_sync_buffers(subctx);
     const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
 }
 
 static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
@@ -4641,7 +4696,7 @@ static void ggml_vk_matmul_id(
     ggml_vk_sync_buffers(subctx);
     const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
                                               nei0, nei1, nbi1, ne11, padded_n };
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
 }
 
 static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
@@ -4762,7 +4817,7 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     };
     init_pushconst_fastdiv(pc);
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
 }
 
 static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
@@ -4781,7 +4836,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
     vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
 
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
 }
 
 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -4981,7 +5036,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     } else if (qx_needs_dequant) {
         const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
     }
     if (y_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
@@ -5197,7 +5252,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
         { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
-        sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
+        pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
 }
 
 static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5285,7 +5340,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     }
 
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z });
 }
 
 static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5368,7 +5423,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
-        { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+        { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
 }
 
 static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5584,7 +5639,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
-            { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+            { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
     }
     if (y_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
@@ -5804,7 +5859,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
         { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
           vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
-        sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
+        pc, { groups_x, (uint32_t)nei0, groups_z });
 }
 
 static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
@@ -6154,7 +6209,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             // there's no more than one tile of rows (i.e. workgroups_x would have been
             // one). We reuse workgroups_x to mean the number of splits, so we need to
             // cancel out the divide by wg_denoms[0].
-            sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
+            pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(subctx);
         const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
@@ -6163,7 +6218,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
-            pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
+            pc2, { (uint32_t)ne1, 1, 1 });
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             {
@@ -6173,7 +6228,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                 vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
-            sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
+            pc, { workgroups_x, workgroups_y, workgroups_z });
     }
 }
 
@@ -6851,7 +6906,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
 
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
         vk_subbuffer subbuf_z;
@@ -6862,26 +6917,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
 
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_IM2COL) {
         // im2col uses only src1 and dst buffers
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_COUNT_EQUAL) {
         ggml_vk_sync_buffers(subctx);
         // count_equal assumes that destination buffer is initialized with zeroes
         ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (use_src2) {
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (use_src1) {
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else {
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     }
 }
 
@@ -7050,7 +7105,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
             vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
             vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
             vk_subbuffer{ d_D, dst_offset, dst_size }
-        }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+        }, pc, elements);
     } else if (version == 7) {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
             vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
@@ -7061,7 +7116,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
            vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
            vk_subbuffer{ d_D, dst_offset, dst_size }
-        }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
+        }, pc, elements);
     } else {
         // shouldn't happen
         GGML_ASSERT(false);
@@ -7198,7 +7253,7 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
         vk_subbuffer{ d_GM, gm_offset, gm_size },
         vk_subbuffer{ d_GV, gv_offset, gv_size },
         vk_subbuffer{ d_P, p_offset, p_size },
-    }, sizeof(vk_op_push_constants), &pc, elements);
+    }, pc, elements);
 }
 
 static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@@ -8087,7 +8142,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     ggml_vk_ctx_begin(ctx->device, subctx);
     const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
-    ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
+    ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1});
    ggml_vk_ctx_end(subctx);
 
    auto begin = std::chrono::high_resolution_clock::now();
@@ -10261,8 +10316,9 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
 static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
     switch (props.vendorID) {
         case VK_VENDOR_ID_INTEL:
-            // Intel drivers don't support coopmat properly yet
-            return false;
+            // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,
+            // while some older hardware (ex. Arc A770) has performance regressions
+            return arch == vk_device_architecture::INTEL_XE2;
         case VK_VENDOR_ID_AMD:
             if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
                 // Workaround for AMD proprietary driver reporting support on all GPUs
@@ -935,6 +935,9 @@ class GGUFWriter:
     def add_eom_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.EOM_ID, id)
 
+    def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
+        self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)
+
     # for vision models
 
     def add_clip_has_vision_encoder(self, value: bool) -> None:
109	include/llama.h
@@ -64,7 +64,10 @@ extern "C" {
     struct llama_model;
     struct llama_context;
     struct llama_sampler;
-    struct llama_kv_cache;
+
+    typedef struct llama_memory_i * llama_memory_t;
+
+    struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -496,9 +499,11 @@ extern "C" {
     DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
 
     LLAMA_API const struct llama_model * llama_get_model   (const struct llama_context * ctx);
-    LLAMA_API struct llama_kv_cache *    llama_get_kv_self (      struct llama_context * ctx);
+    LLAMA_API llama_memory_t             llama_get_memory  (const struct llama_context * ctx);
     LLAMA_API enum llama_pooling_type    llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
+    DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
+
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type       llama_model_rope_type(const struct llama_model * model);
 
@@ -512,6 +517,13 @@ extern "C" {
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
 
+    // Returns the number of classifier outputs (only valid for classifier models)
+    // Undefined behavior for non-classifier models
+    LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
+
+    // Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
+    LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
+
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
 
     LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
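The two new classifier accessors pair with the existing sequence-embeddings getter: with rank/classification pooling, llama_get_embeddings_seq now yields n_cls_out floats per sequence rather than exactly one (see the encode-path hunk further down). A hedged usage sketch, assuming a classifier model and a processed batch already exist elsewhere; this helper is illustrative, not part of the patch:

#include "llama.h"
#include <cstdio>

// Print every classifier head output for sequence 0, assuming the batch was
// processed with LLAMA_POOLING_TYPE_RANK.
static void print_cls_outputs(const llama_model * model, llama_context * ctx) {
    const uint32_t n_cls_out = llama_model_n_cls_out(model);      // number of classifier outputs
    const float  * scores    = llama_get_embeddings_seq(ctx, 0);  // float[n_cls_out] under RANK pooling

    if (scores == nullptr) {
        return; // pooling type was NONE, or sequence 0 produced no output
    }

    for (uint32_t i = 0; i < n_cls_out; ++i) {
        const char * label = llama_model_cls_label(model, i);     // may be nullptr if no label stored
        std::printf("%s: %.4f\n", label ? label : "(unlabeled)", scores[i]);
    }
}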
@@ -612,7 +624,78 @@ extern "C" {
             int32_t   il_end);
 
     //
-    // KV cache
+    // Memory
+    //
+
+    // Clear the memory contents
+    LLAMA_API void llama_memory_clear(llama_memory_t mem);
+
+    // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
+    //   seq_id < 0 : match any sequence
+    //   p0 < 0     : [0,  p1]
+    //   p1 < 0     : [p0, inf)
+    LLAMA_API bool llama_memory_seq_rm(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1);
+
+    // Copy all tokens that belong to the specified sequence to another sequence
+    //   p0 < 0 : [0,  p1]
+    //   p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_cp(
+            llama_memory_t mem,
+              llama_seq_id seq_id_src,
+              llama_seq_id seq_id_dst,
+                 llama_pos p0,
+                 llama_pos p1);
+
+    // Removes all tokens that do not belong to the specified sequence
+    LLAMA_API void llama_memory_seq_keep(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
+    //   p0 < 0 : [0,  p1]
+    //   p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_add(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1,
+                 llama_pos delta);
+
+    // Integer division of the positions by factor of `d > 1`
+    //   p0 < 0 : [0,  p1]
+    //   p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_div(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1,
+                       int d);
+
+    // Returns the smallest position present in the memory for the specified sequence
+    // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_memory_seq_pos_min(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Returns the largest position present in the memory for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_memory_seq_pos_max(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Check if the memory supports shifting
+    LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
+
+    //
+    // KV cache for self-attention (TODO: deprecate in favor of llama_memory)
     //
 
     // Returns the number of tokens in the KV cache (slow, use only for debug)
@@ -626,7 +709,7 @@ extern "C" {
 
     // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_self_clear(
             struct llama_context * ctx);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
     // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
@@ -697,14 +780,14 @@ extern "C" {
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
-    LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
+    DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
         "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
 
     // Check if the context supports KV cache shifting
     LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
 
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
+    DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
         "simply remove this call, updates are applied lazily on the next llama_decode()");
 
     //
@@ -712,7 +795,7 @@ extern "C" {
     //
 
     // Returns the *actual* size in bytes of the state
-    // (logits, embedding and kv_cache)
+    // (logits, embedding and memory)
     // Only use when saving the state, not when restoring it, otherwise the size may be too small.
     LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
     LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -768,12 +851,12 @@ extern "C" {
             size_t   n_token_count),
         "use llama_state_save_file instead");
 
-    // Get the exact size needed to copy the KV cache of a single sequence
+    // Get the exact size needed to copy the state of a single sequence
     LLAMA_API size_t llama_state_seq_get_size(
             struct llama_context * ctx,
                     llama_seq_id   seq_id);
 
-    // Copy the KV cache of a single sequence into the specified buffer
+    // Copy the state of a single sequence into the specified buffer
     LLAMA_API size_t llama_state_seq_get_data(
             struct llama_context * ctx,
                          uint8_t * dst,
@@ -839,16 +922,16 @@ extern "C" {
     // For encode-decoder contexts, processes the batch using the encoder.
     // Can store the encoder output internally for later use by the decoder's cross-attention layers.
     //   0 - success
-    // < 0 - error. the KV cache state is restored to the state before this call
+    // < 0 - error. the memory state is restored to the state before this call
     LLAMA_API int32_t llama_encode(
             struct llama_context * ctx,
               struct llama_batch   batch);
 
     // Process a batch of tokens.
-    // Requires KV cache.
+    // Requires the context to have a memory.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    // Upon non-zero return values, the KV cache state is restored to the state before this call
+    // Upon non-zero return values, the memory state is restored to the state before this call
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
     //   2 - aborted
@@ -919,7 +1002,7 @@ extern "C" {
 
     // Get the embeddings for a sequence id
     // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
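The llama_kv_self_* family is superseded by the handle-based llama_memory_* calls introduced above. A hedged migration sketch; the context setup is assumed to happen elsewhere and the helper name is illustrative, not part of the patch:

#include "llama.h"

// Sketch: trim and shift sequence 0 through the new memory handle instead of
// the deprecated llama_kv_self_* calls (assumes ctx was created elsewhere).
static void rewind_sequence(llama_context * ctx, llama_pos keep_until) {
    llama_memory_t mem = llama_get_memory(ctx);

    // drop everything in sequence 0 from position keep_until onwards
    llama_memory_seq_rm(mem, /*seq_id=*/0, /*p0=*/keep_until, /*p1=*/-1);

    if (llama_memory_can_shift(mem)) {
        // example: slide the tokens in [4, inf) back by 4 positions
        llama_memory_seq_add(mem, 0, 4, -1, -4);
    }

    const llama_pos pos_max = llama_memory_seq_pos_max(mem, 0); // -1 if the sequence is now empty
    (void) pos_max;
}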
|
|
|
@ -2,9 +2,9 @@
|
||||||
|
|
||||||
#include "llama-impl.h"
|
#include "llama-impl.h"
|
||||||
#include "llama-io.h"
|
#include "llama-io.h"
|
||||||
|
#include "llama-memory.h"
|
||||||
#include "llama-mmap.h"
|
#include "llama-mmap.h"
|
||||||
#include "llama-model.h"
|
#include "llama-model.h"
|
||||||
#include "llama-kv-cache.h"
|
|
||||||
|
|
||||||
#include <cinttypes>
|
#include <cinttypes>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
@ -277,10 +277,9 @@ llama_context::llama_context(
|
||||||
int n_nodes_tg = -1;
|
int n_nodes_tg = -1;
|
||||||
|
|
||||||
// simulate full KV cache
|
// simulate full KV cache
|
||||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
||||||
|
|
||||||
const auto kv_state = kv_self->init_full();
|
const auto mstate = memory->init_full();
|
||||||
if (!kv_state) {
|
if (!mstate) {
|
||||||
throw std::runtime_error("failed to initialize KV cache");
|
throw std::runtime_error("failed to initialize KV cache");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -288,7 +287,7 @@ llama_context::llama_context(
|
||||||
|
|
||||||
// reserve pp graph first so that buffers are only allocated once
|
// reserve pp graph first so that buffers are only allocated once
|
||||||
{
|
{
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||||
}
|
}
|
||||||
|
@ -299,7 +298,7 @@ llama_context::llama_context(
|
||||||
|
|
||||||
// reserve with tg graph to get the number of splits and nodes
|
// reserve with tg graph to get the number of splits and nodes
|
||||||
{
|
{
|
||||||
auto * gf = graph_reserve(1, 1, 1, kv_state.get());
|
auto * gf = graph_reserve(1, 1, 1, mstate.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||||
}
|
}
|
||||||
|
@ -310,7 +309,7 @@ llama_context::llama_context(
|
||||||
|
|
||||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||||
{
|
{
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||||
}
|
}
|
||||||
|
@ -419,14 +418,8 @@ uint32_t llama_context::n_threads_batch() const {
|
||||||
return cparams.n_threads_batch;
|
return cparams.n_threads_batch;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache * llama_context::get_kv_self() {
|
llama_memory_t llama_context::get_memory() const {
|
||||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
return memory.get();
|
||||||
return kv_self;
|
|
||||||
}
|
|
||||||
|
|
||||||
const llama_kv_cache * llama_context::get_kv_self() const {
|
|
||||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
||||||
return kv_self;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_context::kv_self_defrag_sched() {
|
void llama_context::kv_self_defrag_sched() {
|
||||||
|
@ -442,15 +435,13 @@ bool llama_context::kv_self_update(bool optimize) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
||||||
|
|
||||||
{
|
{
|
||||||
// TODO: remove in the future
|
// TODO: remove in the future
|
||||||
optimize |= memory_force_optimize;
|
optimize |= memory_force_optimize;
|
||||||
memory_force_optimize = false;
|
memory_force_optimize = false;
|
||||||
|
|
||||||
const auto kv_state = kv_self->init_update(this, optimize);
|
const auto mstate = memory->init_update(this, optimize);
|
||||||
switch (kv_state->get_status()) {
|
switch (mstate->get_status()) {
|
||||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||||
{
|
{
|
||||||
// noop
|
// noop
|
||||||
|
@ -468,23 +459,25 @@ bool llama_context::kv_self_update(bool optimize) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!kv_state->apply()) {
|
if (!mstate->apply()) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the KV cache did any computation, we have to reserve a new worst-case graph
|
// if the memory module did any computation, we have to reserve a new worst-case graph
|
||||||
const auto kv_state = kv_self->init_full();
|
{
|
||||||
if (!kv_state) {
|
const auto mstate = memory->init_full();
|
||||||
throw std::runtime_error("failed to initialize memory state");
|
if (!mstate) {
|
||||||
}
|
throw std::runtime_error("failed to initialize memory state");
|
||||||
|
}
|
||||||
|
|
||||||
const uint32_t n_seqs = cparams.n_seq_max;
|
const uint32_t n_seqs = cparams.n_seq_max;
|
||||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||||
|
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
@ -846,16 +839,17 @@ int llama_context::encode(llama_batch & inp_batch) {
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_POOLING_TYPE_RANK:
|
case LLAMA_POOLING_TYPE_RANK:
|
||||||
{
|
{
|
||||||
// extract the rerank score - a single float per sequence
|
// extract the rerank score - n_cls_out floats per sequence
|
||||||
auto & embd_seq_out = embd_seq;
|
auto & embd_seq_out = embd_seq;
|
||||||
|
const uint32_t n_cls_out = hparams.n_cls_out;
|
||||||
|
|
||||||
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
embd_seq_out[seq_id].resize(1);
|
embd_seq_out[seq_id].resize(n_cls_out);
|
||||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||||
|
@ -912,10 +906,8 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
||||||
|
|
||||||
// temporary allocate memory for the input batch if needed
|
// temporary allocate memory for the input batch if needed
|
||||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
|
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
|
||||||
|
|
||||||
const llama_batch & batch = batch_allocr.batch;
|
const llama_batch & batch = batch_allocr.batch;
|
||||||
|
|
||||||
|
@@ -977,21 +969,21 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update(false);
 
-    llama_memory_state_ptr kv_state;
+    llama_memory_state_ptr mstate;
 
     while (true) {
-        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-        if (!kv_state) {
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!mstate) {
             return -2;
         }
 
-        switch (kv_state->get_status()) {
+        switch (mstate->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
             case LLAMA_MEMORY_STATUS_NO_UPDATE:
                 {
-                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
 
                     return -2;
                 }
@@ -1031,7 +1023,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     int64_t n_outputs_prev = 0;
 
     do {
-        const auto & ubatch = kv_state->get_ubatch();
+        const auto & ubatch = mstate->get_ubatch();
 
         // count the outputs in this u_batch
         {
@@ -1054,11 +1046,14 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
-            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
+            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
+            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+                pos_min[s] = std::numeric_limits<llama_pos>::max();
+            }
 
             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                 const auto & seq_id = ubatch.seq_id[i][0];
@@ -1073,7 +1068,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
                 LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
 
-                llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+                memory->seq_rm(s, pos_min[s], -1);
             }
 
             switch (status) {
@@ -1167,7 +1162,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    } while (kv_state->next());
+    } while (mstate->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1176,7 +1171,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids = kv_state->out_ids();
+        auto & out_ids = mstate->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
@@ -1844,11 +1839,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-    if (kv_self != nullptr) {
+    if (memory != nullptr) {
         LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
-        kv_self->state_write(io);
+        memory->state_write(io);
     }
 
     return io.n_bytes();
@@ -1935,9 +1928,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     if (memory) {
         LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
 
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_read(io);
+        memory->state_read(io);
     }
 
     return io.n_bytes();
@@ -1947,9 +1938,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
     GGML_UNUSED(seq_id);
 
     if (memory) {
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_write(io, seq_id);
+        memory->state_write(io, seq_id);
     }
 
     return io.n_bytes();
@@ -1959,9 +1948,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
    GGML_UNUSED(seq_id);
 
     if (memory) {
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_read(io, seq_id);
+        memory->state_read(io, seq_id);
     }
 
     return io.n_bytes();
@@ -2066,9 +2053,7 @@ void llama_context::opt_epoch_iter(
     const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
     const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-    kv_self->clear();
+    memory->clear();
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
         batch.n_tokens = n_batch;
@@ -2091,8 +2076,8 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
-        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }
@@ -2105,17 +2090,17 @@ void llama_context::opt_epoch_iter(
 
         uint32_t pos_batch = 0;
         do {
-            const auto & ubatch = kv_state->get_ubatch();
+            const auto & ubatch = mstate->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
-            if (!kv_state->apply()) {
+            if (!mstate->apply()) {
                 LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
                 break;
             }
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2150,7 +2135,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
-        } while (kv_state->next());
+        } while (mstate->next());
     }
 }
 
@@ -2311,8 +2296,9 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }
 
+// deprecated
 llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return ctx->get_kv_self();
+    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
 }
 
 // deprecated
@@ -2432,13 +2418,82 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }
 
+//
+// memory
+//
+
+llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+    return ctx->get_memory();
+}
+
+void llama_memory_clear(llama_memory_t mem) {
+    mem->clear();
+}
+
+bool llama_memory_seq_rm(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1) {
+    return mem->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_seq_cp(
+        llama_memory_t mem,
+        llama_seq_id seq_id_src,
+        llama_seq_id seq_id_dst,
+        llama_pos p0,
+        llama_pos p1) {
+    mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_seq_keep(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    mem->seq_keep(seq_id);
+}
+
+void llama_memory_seq_add(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        llama_pos delta) {
+    mem->seq_add(seq_id, p0, p1, delta);
+}
+
+void llama_memory_seq_div(
+        llama_memory_t mem,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        int d) {
+    mem->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_seq_pos_min(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    return mem->seq_pos_min(seq_id);
+}
+
+llama_pos llama_memory_seq_pos_max(
+        llama_memory_t mem,
+        llama_seq_id seq_id) {
+    return mem->seq_pos_max(seq_id);
+}
+
+bool llama_memory_can_shift(llama_memory_t mem) {
+    return mem->get_can_shift();
+}
+
 //
 // kv cache
 //
 
 // deprecated
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }
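Note: the llama_memory_* functions added above supersede the llama_kv_self_* wrappers that follow. A minimal usage sketch against the new entry points (not part of this change; it assumes ctx is a valid llama_context and that sequence 0 is in use):

    llama_memory_t mem = llama_get_memory(ctx);

    if (mem != NULL) {
        if (llama_memory_can_shift(mem)) {
            // classic context shift: drop the oldest 8 tokens of sequence 0 ...
            llama_memory_seq_rm (mem, 0, 0, 8);
            // ... and slide the remaining tokens of sequence 0 down by 8 positions
            llama_memory_seq_add(mem, 0, 8, -1, -8);
        }

        // or start the sequence over completely
        llama_memory_clear(mem);
    }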
@@ -2460,7 +2515,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
 // deprecated
 // note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }
@@ -2480,12 +2535,12 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->clear();
+    llama_memory_clear(kv);
 }
 
 bool llama_kv_self_seq_rm(
@@ -2493,12 +2548,12 @@ bool llama_kv_self_seq_rm(
         llama_seq_id seq_id,
         llama_pos p0,
         llama_pos p1) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return true;
     }
 
-    return kv->seq_rm(seq_id, p0, p1);
+    return llama_memory_seq_rm(kv, seq_id, p0, p1);
 }
 
 void llama_kv_self_seq_cp(
@@ -2507,21 +2562,21 @@ void llama_kv_self_seq_cp(
         llama_seq_id seq_id_dst,
         llama_pos p0,
         llama_pos p1) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
 }
 
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_keep(seq_id);
+    llama_memory_seq_keep(kv, seq_id);
 }
 
 void llama_kv_self_seq_add(
@@ -2530,12 +2585,12 @@ void llama_kv_self_seq_add(
         llama_pos p0,
         llama_pos p1,
         llama_pos delta) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_add(seq_id, p0, p1, delta);
+    llama_memory_seq_add(kv, seq_id, p0, p1, delta);
 }
 
 void llama_kv_self_seq_div(
@@ -2544,30 +2599,30 @@ void llama_kv_self_seq_div(
         llama_pos p0,
         llama_pos p1,
         int d) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_div(seq_id, p0, p1, d);
+    llama_memory_seq_div(kv, seq_id, p0, p1, d);
 }
 
 llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
-    const auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return -1;
     }
 
-    return kv->seq_pos_min(seq_id);
+    return llama_memory_seq_pos_min(kv, seq_id);
 }
 
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-    const auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return -1;
     }
 
-    return kv->seq_pos_max(seq_id);
+    return llama_memory_seq_pos_max(kv, seq_id);
 }
 
 // deprecated
@@ -2577,12 +2632,12 @@ void llama_kv_self_defrag(llama_context * ctx) {
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return false;
     }
 
-    return kv->get_can_shift();
+    return llama_memory_can_shift(kv);
 }
 
 // llama state API
@@ -13,13 +13,12 @@
 #include <vector>
 
 struct llama_model;
-struct llama_kv_cache;
 
 class llama_io_read_i;
 class llama_io_write_i;
 
-class llama_memory_i;
-class llama_memory_state_i;
+struct llama_memory_i;
+struct llama_memory_state_i;
 
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
@@ -47,8 +46,7 @@ struct llama_context {
     uint32_t n_threads() const;
     uint32_t n_threads_batch() const;
 
-    llama_kv_cache * get_kv_self();
-    const llama_kv_cache * get_kv_self() const;
+    llama_memory_t get_memory() const;
 
     // return true of the KV cache was updated
     // TODO: remove
@@ -17,7 +17,7 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-class llama_memory_state_i;
+struct llama_memory_state_i;
 
 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
@@ -2,7 +2,7 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache.h"
+#include "llama-memory.h"
 
 #include <set>
 #include <vector>
@@ -13,7 +13,7 @@
 
 // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
 // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
-class llama_kv_cache_recurrent : public llama_kv_cache {
+class llama_kv_cache_recurrent : public llama_memory_i {
 public:
     llama_kv_cache_recurrent(
             const llama_model & model,
@@ -29,6 +29,16 @@ public:
     // llama_memory_i
     //
 
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
     void clear() override;
 
     bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
@@ -40,20 +50,6 @@ public:
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
-    //
-    // llama_kv_cache
-    //
-
-    llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
-            uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
-
-    llama_memory_state_ptr init_full() override;
-
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
-
     bool prepare(const std::vector<llama_ubatch> & ubatches);
 
     // find a contiguous slot of kv cells and emplace the ubatch there
@@ -11,7 +11,7 @@
 // utilizes two instances of llama_kv_cache_unified
 // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
 
-class llama_kv_cache_unified_iswa : public llama_kv_cache {
+class llama_kv_cache_unified_iswa : public llama_memory_i {
 public:
     llama_kv_cache_unified_iswa(
             const llama_model & model,
@@ -31,21 +31,6 @@ public:
     // llama_memory_i
     //
 
-    void clear() override;
-
-    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
-    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id) override;
-    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    //
-    // llama_kv_cache
-    //
-
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
@@ -58,6 +43,17 @@ public:
 
     bool get_can_shift() const override;
 
+    void clear() override;
+
+    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -2,8 +2,8 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache.h"
 #include "llama-kv-cells.h"
+#include "llama-memory.h"
 
 #include <unordered_map>
 #include <vector>
@@ -17,7 +17,7 @@ struct llama_context;
 // llama_kv_cache_unified
 //
 
-class llama_kv_cache_unified : public llama_kv_cache {
+class llama_kv_cache_unified : public llama_memory_i {
 public:
     static uint32_t get_padding(const llama_cparams & cparams);
 
@@ -56,21 +56,6 @@ public:
     // llama_memory_i
     //
 
-    void clear() override;
-
-    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
-    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id) override;
-    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    //
-    // llama_kv_cache
-    //
-
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
@@ -83,6 +68,17 @@ public:
 
     bool get_can_shift() const override;
 
+    void clear() override;
+
+    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -1 +0,0 @@
-#include "llama-kv-cache.h"
@@ -1,41 +0,0 @@
-#pragma once
-
-#include "llama.h"
-#include "llama-memory.h"
-
-class llama_io_write_i;
-class llama_io_read_i;
-
-struct llama_kv_cache : public llama_memory_i {
-    virtual ~llama_kv_cache() = default;
-
-    // TODO: move the init_ interfaces to llama_memory_i
-
-    // split the input batch into a set of ubatches and verify that they can fit into the cache
-    // return a state object containing the ubatches and KV cache state required to process them
-    // check the llama_memory_state_i::get_status() for the result
-    virtual llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
-            uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) = 0;
-
-    // simulate full cache, used for allocating worst-case compute buffers
-    virtual llama_memory_state_ptr init_full() = 0;
-
-    // prepare for any pending memory updates, such as shifts, defrags, etc.
-    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
-    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
-
-    // getters
-    virtual bool get_can_shift() const = 0;
-
-    bool get_can_edit() const override { return get_can_shift(); }
-
-    //
-    // state write/read
-    //
-
-    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
-    virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
-};
@@ -7,6 +7,9 @@
 
 struct llama_ubatch;
 
+class llama_io_write_i;
+class llama_io_read_i;
+
 struct llama_memory_params {
     // kv cache
     ggml_type type_k;
@@ -16,28 +19,6 @@ struct llama_memory_params {
     bool swa_full;
 };
 
-// general concept of LLM memory
-// the KV cache is a type of LLM memory, but there can be other types
-class llama_memory_i {
-public:
-    virtual ~llama_memory_i() = default;
-
-    virtual void clear() = 0;
-
-    virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
-    virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
-    virtual void seq_keep(llama_seq_id seq_id) = 0;
-    virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
-    virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
-
-    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
-    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
-
-    virtual bool get_can_edit() const = 0;
-};
-
-using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
-
 enum llama_memory_status {
     LLAMA_MEMORY_STATUS_SUCCESS = 0,
     LLAMA_MEMORY_STATUS_NO_UPDATE,
@@ -58,8 +39,7 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me
 // the only method that can mutate the memory and the memory state is llama_memory_i::apply()
 //
 // TODO: rename to llama_memory_context_i ?
-class llama_memory_state_i {
-public:
+struct llama_memory_state_i {
     virtual ~llama_memory_state_i() = default;
 
     // consume the current ubatch from the state and proceed to the next one
@@ -81,3 +61,57 @@ public:
 };
 
 using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
+
+// general concept of LLM memory
+// the KV cache is a type of LLM memory, but there can be other types
+struct llama_memory_i {
+    virtual ~llama_memory_i() = default;
+
+    // split the input batch into a set of ubatches and verify that they can fit into the cache
+    // return a state object containing the ubatches and KV cache state required to process them
+    // check the llama_memory_state_i::get_status() for the result
+    virtual llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) = 0;
+
+    // simulate full cache, used for allocating worst-case compute buffers
+    virtual llama_memory_state_ptr init_full() = 0;
+
+    // prepare for any pending memory updates, such as shifts, defrags, etc.
+    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
+    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
+
+    // getters
+    virtual bool get_can_shift() const = 0;
+
+    //
+    // ops
+    //
+
+    virtual void clear() = 0;
+
+    virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_keep(llama_seq_id seq_id) = 0;
+    virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
+    virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
+
+    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
+    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
+
+    //
+    // state write/read
+    //
+
+    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
+    virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
+};
+
+using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
+
+// TODO: temporary until the llama_kv_cache is removed from the public API
+struct llama_kv_cache : public llama_memory_i {
+    virtual ~llama_kv_cache() = default;
+};
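With init_batch/init_full/init_update and the state I/O folded into llama_memory_i, a caller only needs the base interface to drive any memory implementation. A condensed sketch of the loop llama_context runs, simplified from the decode/opt_epoch_iter hunks above (graph construction and error handling elided; the helper name is illustrative):

    // minimal sketch: drive a llama_memory_i implementation over one batch
    static void process_batch(llama_memory_i & mem, const llama_batch & batch,
                              uint32_t n_ubatch, bool embd_pooled) {
        llama_memory_state_ptr st = mem.init_batch(batch, n_ubatch, embd_pooled, /* logits_all */ true);
        if (!st || st->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            return; // the batch cannot be scheduled against the current memory contents
        }

        do {
            const auto & ubatch = st->get_ubatch();

            if (!st->apply()) { // commit the memory update for this ubatch
                break;
            }

            // ... build and evaluate the compute graph for `ubatch` here ...
            (void) ubatch;
        } while (st->next()); // advance to the next ubatch
    }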
@@ -402,7 +402,7 @@ struct llama_mmap::impl {
             }
         }
 #else
-        throw std::runtime_error("PrefetchVirtualMemory unavailable");
+        LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n");
 #endif
     }
 #else
@@ -292,9 +292,10 @@ namespace GGUFMeta {
 
     template<typename T>
     bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
+        const gguf_context * ctx = meta.get();
+        const int kid = gguf_find_key(ctx, key.c_str());
 
-        if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
+        if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
             if (required) {
                 throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
             }
@@ -302,28 +303,40 @@ namespace GGUFMeta {
         }
 
         struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
 
         switch (arr_info.gt) {
             case GGUF_TYPE_UINT32:
             case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T, int32_t>::value) ||
                                                 (std::is_same<T, uint32_t>::value)); break;
             case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
+            case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break;
             default:
-                throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
+                throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
         }
 
-        result.resize(arr_info.length);
-        result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+        if constexpr (std::is_same<T, std::string>::value) {
+            const size_t n_items = gguf_get_arr_n(ctx, kid);
+            result.clear();
+
+            for (size_t i = 0; i < n_items; i++) {
+                const T value = gguf_get_arr_str(ctx, kid, i);
+                result.emplace_back(value);
+            }
+        } else {
+            result.resize(arr_info.length);
+            result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+        }
 
         return true;
     }
 
     template<typename T, size_t N_MAX>
     bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
+        const gguf_context * ctx = meta.get();
+        const int kid = gguf_find_key(ctx, key.c_str());
 
-        if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
+        if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
             if (required) {
                 throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
             }
@@ -331,22 +344,32 @@ namespace GGUFMeta {
         }
 
         struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
 
         switch (arr_info.gt) {
             case GGUF_TYPE_UINT32:
             case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T, int32_t>::value) ||
                                                 (std::is_same<T, uint32_t>::value)); break;
             case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
+            case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break;
             default:
-                throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
+                throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
         }
 
         if (arr_info.length > N_MAX) {
             throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
         }
 
-        std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+        if constexpr (std::is_same<T, std::string>::value) {
+            const size_t n_items = gguf_get_arr_n(ctx, kid);
+
+            for (size_t i = 0; i < n_items; i++) {
+                const T value = gguf_get_arr_str(ctx, kid, i);
+                result[i] = value;
+            }
+        } else {
+            std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+        }
 
         return true;
     }
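The string branch above reads each element through gguf_get_arr_str() instead of bulk-copying the raw array data via ArrayInfo. For reference, the same pattern against the plain gguf C API looks roughly like this (the helper name and key handling are illustrative, not part of this change; it assumes the array actually holds strings):

    // sketch: read a GGUF string-array key into a std::vector<std::string>
    static std::vector<std::string> read_str_arr(const gguf_context * ctx, const char * key) {
        std::vector<std::string> out;

        const int kid = gguf_find_key(ctx, key);
        if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
            return out; // key missing or not an array
        }

        const size_t n = gguf_get_arr_n(ctx, kid);
        for (size_t i = 0; i < n; i++) {
            out.emplace_back(gguf_get_arr_str(ctx, kid, i));
        }

        return out;
    }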
@@ -356,6 +379,8 @@ namespace GGUFMeta {
         return get_arr(llm_kv(kid), result, required);
     }
 
+    template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
+
     template<typename T>
     bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
         auto it = kv_overrides.find(key);
@@ -548,6 +548,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     uint32_t n_vocab = 0;
     ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
 
+    // for classifier models
+    ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
+    if (!classifier_labels.empty()) {
+        hparams.n_cls_out = classifier_labels.size();
+    }
+
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:
@@ -691,7 +697,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                 ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
-                ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
 
                 switch (hparams.n_layer) {
                     case 3:
@@ -4459,6 +4464,15 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
         LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
         LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
+
+        if (!classifier_labels.empty()) {
+            LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
+
+            size_t i = 0;
+            for (auto label : classifier_labels) {
+                LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
+            }
+        }
     }
 
     LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
@@ -13702,6 +13716,18 @@ int32_t llama_model_n_swa(const llama_model * model) {
     return model->hparams.n_swa;
 }
 
+uint32_t llama_model_n_cls_out(const struct llama_model * model) {
+    return model->hparams.n_cls_out;
+}
+
+const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
+    if (i < model->classifier_labels.size()) {
+        return model->classifier_labels[i].c_str();
+    }
+
+    return nullptr;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
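A short usage sketch for the two new accessors (not part of this diff; it assumes `model` points to a loaded classifier or reranker model):

    // list the classification heads exposed by the model
    const uint32_t n_cls = llama_model_n_cls_out(model);
    for (uint32_t i = 0; i < n_cls; i++) {
        const char * label = llama_model_cls_label(model, i);
        printf("cls_label[%u] = %s\n", i, label ? label : "(unnamed)");
    }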
@@ -329,6 +329,9 @@ struct llama_model {
     llama_hparams hparams = {};
     llama_vocab vocab;
 
+    // for classifier models
+    std::vector<std::string> classifier_labels;
+
     struct ggml_tensor * tok_embd = nullptr;
     struct ggml_tensor * type_embd = nullptr;
     struct ggml_tensor * pos_embd = nullptr;
@@ -2337,7 +2337,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
             || _contains_any(general_arch, {"nomic-bert-moe"})
             ) {
-            _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
+            if (token_to_id.count("<mask>") == 0) {
+                LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
+            } else {
+                _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
+            }
         } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
             for (auto id : cache_special_tokens) {
                 _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);