Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	README.md
#	ci/run.sh
#	examples/embedding/embedding.cpp
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	src/CMakeLists.txt
This commit is contained in:
Concedo 2025-06-06 17:56:51 +08:00
commit d33c88b1f4
18 changed files with 508 additions and 276 deletions

View file

@@ -3709,8 +3709,7 @@ class BertModel(TextModel):
         self._try_set_pooling_type()
 
         if self.cls_out_labels:
-            key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
-            self.gguf_writer.add_array(key_name, [v for k, v in sorted(self.cls_out_labels.items())])
+            self.gguf_writer.add_classifier_output_labels([v for k, v in sorted(self.cls_out_labels.items())])
 
     def set_vocab(self):
         tokens, toktypes, tokpre = self.get_vocab_base()

View file

@@ -212,6 +212,7 @@ enum vk_device_architecture {
     AMD_RDNA1,
     AMD_RDNA2,
     AMD_RDNA3,
+    INTEL_XE2,
 };
 
 static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
@@ -262,6 +263,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
             }
             return vk_device_architecture::AMD_RDNA2;
         }
+    } else if (props.vendorID == VK_VENDOR_ID_INTEL) {
+        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
+
+        bool subgroup_size_control = false;
+
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
+                subgroup_size_control = true;
+            }
+        }
+
+        if (!subgroup_size_control) {
+            return vk_device_architecture::OTHER;
+        }
+
+        vk::PhysicalDeviceProperties2 props2;
+        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+
+        props2.pNext = &subgroup_size_control_props;
+        device.getProperties2(&props2);
+
+        if (subgroup_size_control_props.minSubgroupSize == 16) {
+            // Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8.
+            // Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value.
+            // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
+            // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
+            return vk_device_architecture::INTEL_XE2;
+        }
     }
     return vk_device_architecture::OTHER;
 }
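Note: the detection above rests entirely on VkPhysicalDeviceSubgroupSizeControlPropertiesEXT::minSubgroupSize. For reference, a minimal standalone sketch (not from this commit) of the same query, assuming Vulkan-Hpp headers, a Vulkan 1.1 loader, and a driver that exposes VK_EXT_subgroup_size_control:

// Standalone sketch: print the min/max subgroup size per physical device.
// On Intel, minSubgroupSize == 16 indicates SIMD16 (Xe2); older Xe/Gen report 8.
#include <vulkan/vulkan.hpp>
#include <cstdio>

int main() {
    vk::ApplicationInfo app_info("subgroup-query", 1, nullptr, 0, VK_API_VERSION_1_1);
    vk::InstanceCreateInfo instance_info({}, &app_info);
    vk::Instance instance = vk::createInstance(instance_info);

    for (const vk::PhysicalDevice & device : instance.enumeratePhysicalDevices()) {
        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_props;
        vk::PhysicalDeviceProperties2 props2;
        props2.pNext = &subgroup_props;
        device.getProperties2(&props2);

        printf("%s: minSubgroupSize = %u, maxSubgroupSize = %u\n",
               props2.properties.deviceName.data(),
               subgroup_props.minSubgroupSize,
               subgroup_props.maxSubgroupSize);
    }

    instance.destroy();
    return 0;
}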
@@ -4103,7 +4132,33 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
     return s;
 }
 
-static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
+template <typename T> size_t push_constant_size(const T &t) {
+    static_assert(std::is_class<T>::value, "T must be a struct/class");
+    GGML_UNUSED(t);
+    return sizeof(T);
+}
+template <typename T> size_t push_constant_size(const std::vector<T> &t) {
+    GGML_UNUSED(t);
+    return sizeof(T) * t.size();
+}
+template <typename T, uint32_t N> size_t push_constant_size(const std::array<T, N> &t) {
+    GGML_UNUSED(t);
+    return sizeof(T) * N;
+}
+
+template <typename T> const T *push_constant_data(const T &t) {
+    static_assert(std::is_class<T>::value, "T must be a struct/class");
+    return &t;
+}
+template <typename T> const T *push_constant_data(const std::vector<T> &t) {
+    return t.data();
+}
+template <typename T, uint32_t N> const T *push_constant_data(const std::array<T, N> &t) {
+    return t.data();
+}
+
+template <typename T>
+static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) {
     const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
     const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
     const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
@@ -4119,7 +4174,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
     vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
     ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
 
-    subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
+    subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
     subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
     subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
                         pipeline->layout,
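These helpers let every call site hand ggml_vk_dispatch_pipeline a struct, std::vector, or std::array directly instead of a separate (size, pointer) pair; overload partial ordering picks the right size and data pointer. A standalone sketch of that mechanism (not from this diff; the struct and names are illustrative):

// Sketch of the overload pattern: the most specialized push_constant_size wins.
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>
#include <type_traits>

template <typename T> size_t push_constant_size(const T & t) {
    static_assert(std::is_class<T>::value, "T must be a struct/class");
    (void) t;
    return sizeof(T);
}
template <typename T> size_t push_constant_size(const std::vector<T> & t) {
    return sizeof(T) * t.size();
}
template <typename T, size_t N> size_t push_constant_size(const std::array<T, N> & t) {
    (void) t;
    return sizeof(T) * N;
}

struct example_pc { uint32_t m, n, k; };   // illustrative push-constant struct

int main() {
    example_pc pc = { 1, 2, 3 };
    std::vector<uint32_t> vec_pc = { 1, 2, 3, 4, 5 };
    std::array<uint32_t, 2> arr_pc = { 6, 7 };

    // prints 12, 20 and 8 - callers no longer pass byte sizes by hand
    printf("%zu %zu %zu\n", push_constant_size(pc), push_constant_size(vec_pc), push_constant_size(arr_pc));
    return 0;
}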
@@ -4582,7 +4637,7 @@ static void ggml_vk_matmul(
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
         const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
         return;
     }
@@ -4590,10 +4645,10 @@ static void ggml_vk_matmul(
     const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
     // Make sure enough workgroups get assigned for split k to work
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
     ggml_vk_sync_buffers(subctx);
     const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
 }
 
 static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
@@ -4641,7 +4696,7 @@ static void ggml_vk_matmul_id(
     ggml_vk_sync_buffers(subctx);
     const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
                                               nei0, nei1, nbi1, ne11, padded_n };
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
 }
 
 static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
@@ -4762,7 +4817,7 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     };
     init_pushconst_fastdiv(pc);
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
 }
 
 static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
@@ -4781,7 +4836,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
     vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
 
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
 }
 
 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -4981,7 +5036,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     } else if (qx_needs_dequant) {
         const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
     }
     if (y_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
@@ -5197,7 +5252,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
         { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
-        sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
+        pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
 }
 
 static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5285,7 +5340,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     }
 
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z });
 }
 
 static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5368,7 +5423,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
-        { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+        { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
 }
 
 static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5584,7 +5639,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
-            { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+            { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
     }
     if (y_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
@@ -5804,7 +5859,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
         { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
         vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
-        sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
+        pc, { groups_x, (uint32_t)nei0, groups_z });
 }
 
 static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
@@ -6154,7 +6209,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             // there's no more than one tile of rows (i.e. workgroups_x would have been
             // one). We reuse workgroups_x to mean the number of splits, so we need to
             // cancel out the divide by wg_denoms[0].
-            sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
+            pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(subctx);
         const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
@@ -6163,7 +6218,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
-            pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
+            pc2, { (uint32_t)ne1, 1, 1 });
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             {
@@ -6173,7 +6228,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                 vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
-            sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
+            pc, { workgroups_x, workgroups_y, workgroups_z });
     }
 }
@@ -6851,7 +6906,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
 
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
         vk_subbuffer subbuf_z;
@@ -6862,26 +6917,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
 
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_IM2COL) {
         // im2col uses only src1 and dst buffers
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_COUNT_EQUAL) {
         ggml_vk_sync_buffers(subctx);
         // count_equal assumes that destination buffer is initialized with zeroes
         ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (use_src2) {
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (use_src1) {
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else {
         ggml_vk_sync_buffers(subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     }
 }
@@ -7050,7 +7105,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
             vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
             vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
             vk_subbuffer{ d_D, dst_offset, dst_size }
-        }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+        }, pc, elements);
     } else if (version == 7) {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
             vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
@@ -7061,7 +7116,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
             vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
             vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
             vk_subbuffer{ d_D, dst_offset, dst_size }
-        }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
+        }, pc, elements);
     } else {
         // shouldn't happen
         GGML_ASSERT(false);
@@ -7198,7 +7253,7 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
         vk_subbuffer{ d_GM, gm_offset, gm_size },
         vk_subbuffer{ d_GV, gv_offset, gv_size },
         vk_subbuffer{ d_P, p_offset, p_size },
-    }, sizeof(vk_op_push_constants), &pc, elements);
+    }, pc, elements);
 }
 
 static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@@ -8087,7 +8142,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     ggml_vk_ctx_begin(ctx->device, subctx);
     const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
-    ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
+    ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1});
     ggml_vk_ctx_end(subctx);
 
     auto begin = std::chrono::high_resolution_clock::now();
@@ -10261,8 +10316,9 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
 static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
     switch (props.vendorID) {
     case VK_VENDOR_ID_INTEL:
-        // Intel drivers don't support coopmat properly yet
-        return false;
+        // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,
+        // while some older hardware (ex. Arc A770) has performance regressions
+        return arch == vk_device_architecture::INTEL_XE2;
     case VK_VENDOR_ID_AMD:
         if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
             // Workaround for AMD proprietary driver reporting support on all GPUs

View file

@@ -935,6 +935,9 @@ class GGUFWriter:
     def add_eom_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.EOM_ID, id)
 
+    def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
+        self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)
+
     # for vision models
 
     def add_clip_has_vision_encoder(self, value: bool) -> None:
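The labels stored by this writer become queryable through the new C API declared further down in this diff (llama_model_n_cls_out / llama_model_cls_label). A hedged read-side sketch, assuming a classifier model loaded with LLAMA_POOLING_TYPE_RANK and a libllama build that already contains the new declarations:

// Sketch: print one score per classifier output, using the labels from the GGUF file.
#include "llama.h"
#include <cstdio>

static void print_classifier_scores(llama_context * ctx, const llama_model * model, llama_seq_id seq_id) {
    const uint32_t n_cls_out = llama_model_n_cls_out(model);

    // with LLAMA_POOLING_TYPE_RANK this now points at float[n_cls_out]
    const float * scores = llama_get_embeddings_seq(ctx, seq_id);
    if (!scores) {
        return;
    }

    for (uint32_t i = 0; i < n_cls_out; ++i) {
        const char * label = llama_model_cls_label(model, i); // may be nullptr if no label was stored
        printf("%s: %.4f\n", label ? label : "(unlabeled)", scores[i]);
    }
}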

View file

@@ -64,7 +64,10 @@ extern "C" {
     struct llama_model;
     struct llama_context;
     struct llama_sampler;
-    struct llama_kv_cache;
+
+    typedef struct llama_memory_i * llama_memory_t;
+
+    struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -496,9 +499,11 @@ extern "C" {
     DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
 
     LLAMA_API const struct llama_model * llama_get_model   (const struct llama_context * ctx);
-    LLAMA_API struct llama_kv_cache *    llama_get_kv_self (      struct llama_context * ctx);
+    LLAMA_API          llama_memory_t    llama_get_memory  (const struct llama_context * ctx);
     LLAMA_API enum llama_pooling_type    llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
+    DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
+
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type       llama_model_rope_type(const struct llama_model * model);
@@ -512,6 +517,13 @@ extern "C" {
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
 
+    // Returns the number of classifier outputs (only valid for classifier models)
+    // Undefined behavior for non-classifier models
+    LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
+
+    // Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
+    LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
+
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
 
     LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
@@ -612,7 +624,78 @@ extern "C" {
             int32_t   il_end);
 
     //
-    // KV cache
+    // Memory
+    //
+
+    // Clear the memory contents
+    LLAMA_API void llama_memory_clear(llama_memory_t mem);
+
+    // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
+    // seq_id < 0 : match any sequence
+    // p0 < 0     : [0,  p1]
+    // p1 < 0     : [p0, inf)
+    LLAMA_API bool llama_memory_seq_rm(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1);
+
+    // Copy all tokens that belong to the specified sequence to another sequence
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_cp(
+            llama_memory_t mem,
+              llama_seq_id seq_id_src,
+              llama_seq_id seq_id_dst,
+                 llama_pos p0,
+                 llama_pos p1);
+
+    // Removes all tokens that do not belong to the specified sequence
+    LLAMA_API void llama_memory_seq_keep(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_add(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1,
+                 llama_pos delta);
+
+    // Integer division of the positions by factor of `d > 1`
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_div(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1,
+                       int d);
+
+    // Returns the smallest position present in the memory for the specified sequence
+    // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_memory_seq_pos_min(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Returns the largest position present in the memory for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_memory_seq_pos_max(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Check if the memory supports shifting
+    LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
+
+    //
+    // KV cache for self-attention (TODO: deprecate in favor of llama_memory)
     //
 
     // Returns the number of tokens in the KV cache (slow, use only for debug)
@@ -626,7 +709,7 @@ extern "C" {
     // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_self_clear(
             struct llama_context * ctx);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
     // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
@@ -697,14 +780,14 @@ extern "C" {
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
-    LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
+    DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
         "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
 
     // Check if the context supports KV cache shifting
     LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
 
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
+    DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
         "simply remove this call, updates are applied lazily on the next llama_decode()");
 
     //
@@ -712,7 +795,7 @@ extern "C" {
     //
 
     // Returns the *actual* size in bytes of the state
-    // (logits, embedding and kv_cache)
+    // (logits, embedding and memory)
     // Only use when saving the state, not when restoring it, otherwise the size may be too small.
     LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
     LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -768,12 +851,12 @@ extern "C" {
                                size_t   n_token_count),
         "use llama_state_save_file instead");
 
-    // Get the exact size needed to copy the KV cache of a single sequence
+    // Get the exact size needed to copy the state of a single sequence
     LLAMA_API size_t llama_state_seq_get_size(
             struct llama_context * ctx,
                     llama_seq_id   seq_id);
 
-    // Copy the KV cache of a single sequence into the specified buffer
+    // Copy the state of a single sequence into the specified buffer
     LLAMA_API size_t llama_state_seq_get_data(
             struct llama_context * ctx,
                          uint8_t * dst,
@@ -839,16 +922,16 @@ extern "C" {
     // For encode-decoder contexts, processes the batch using the encoder.
     // Can store the encoder output internally for later use by the decoder's cross-attention layers.
     //   0 - success
-    //   < 0 - error. the KV cache state is restored to the state before this call
+    //   < 0 - error. the memory state is restored to the state before this call
     LLAMA_API int32_t llama_encode(
             struct llama_context * ctx,
               struct llama_batch   batch);
 
     // Process a batch of tokens.
-    // Requires KV cache.
+    // Requires the context to have a memory.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    // Upon non-zero return values, the KV cache state is restored to the state before this call
+    // Upon non-zero return values, the memory state is restored to the state before this call
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
     //   2 - aborted
@@ -919,7 +1002,7 @@ extern "C" {
     // Get the embeddings for a sequence id
     // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
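Taken together, llama_get_memory plus the llama_memory_* functions replace the per-context llama_kv_self_* calls. A hedged migration sketch (the helper name is illustrative, not from this diff):

// Sketch: drop part of a sequence and shift the remainder using the new memory API.
#include "llama.h"

static void forget_sequence_tail(llama_context * ctx, llama_seq_id seq_id, llama_pos keep_until) {
    llama_memory_t mem = llama_get_memory(ctx);

    // before: llama_kv_self_seq_rm(ctx, seq_id, keep_until, -1);
    llama_memory_seq_rm(mem, seq_id, keep_until, -1);

    // before: llama_kv_self_seq_pos_max(ctx, seq_id);
    const llama_pos pos_max = llama_memory_seq_pos_max(mem, seq_id);
    (void) pos_max;

    // before: llama_kv_self_can_shift(ctx);
    if (llama_memory_can_shift(mem)) {
        // before: llama_kv_self_seq_add(ctx, seq_id, 0, keep_until, -keep_until);
        llama_memory_seq_add(mem, seq_id, 0, keep_until, -keep_until);
    }
}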

View file

@@ -2,9 +2,9 @@
 
 #include "llama-impl.h"
 #include "llama-io.h"
+#include "llama-memory.h"
 #include "llama-mmap.h"
 #include "llama-model.h"
-#include "llama-kv-cache.h"
 
 #include <cinttypes>
 #include <cstring>
@@ -277,10 +277,9 @@ llama_context::llama_context(
         int n_nodes_tg = -1;
 
         // simulate full KV cache
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        const auto kv_state = kv_self->init_full();
-        if (!kv_state) {
+        const auto mstate = memory->init_full();
+        if (!mstate) {
             throw std::runtime_error("failed to initialize KV cache");
         }
 
@@ -288,7 +287,7 @@ llama_context::llama_context(
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -299,7 +298,7 @@ llama_context::llama_context(
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
+            auto * gf = graph_reserve(1, 1, 1, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -310,7 +309,7 @@ llama_context::llama_context(
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -419,14 +418,8 @@ uint32_t llama_context::n_threads_batch() const {
     return cparams.n_threads_batch;
 }
 
-llama_kv_cache * llama_context::get_kv_self() {
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-    return kv_self;
-}
-
-const llama_kv_cache * llama_context::get_kv_self() const {
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-    return kv_self;
+llama_memory_t llama_context::get_memory() const {
+    return memory.get();
 }
 
 void llama_context::kv_self_defrag_sched() {
@@ -442,15 +435,13 @@ bool llama_context::kv_self_update(bool optimize) {
         return false;
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
     {
         // TODO: remove in the future
         optimize |= memory_force_optimize;
         memory_force_optimize = false;
 
-        const auto kv_state = kv_self->init_update(this, optimize);
-        switch (kv_state->get_status()) {
+        const auto mstate = memory->init_update(this, optimize);
+        switch (mstate->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                     // noop
@@ -468,23 +459,25 @@ bool llama_context::kv_self_update(bool optimize) {
             }
         }
 
-        if (!kv_state->apply()) {
+        if (!mstate->apply()) {
             LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
         }
     }
 
-    // if the KV cache did any computation, we have to reserve a new worst-case graph
-    const auto kv_state = kv_self->init_full();
-    if (!kv_state) {
-        throw std::runtime_error("failed to initialize memory state");
-    }
+    // if the memory module did any computation, we have to reserve a new worst-case graph
+    {
+        const auto mstate = memory->init_full();
+        if (!mstate) {
+            throw std::runtime_error("failed to initialize memory state");
+        }
 
-    const uint32_t n_seqs = cparams.n_seq_max;
-    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
+        }
     }
 
     return true;
@@ -846,16 +839,17 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // extract the rerank score - a single float per sequence
+                    // extract the rerank score - n_cls_out floats per sequence
                     auto & embd_seq_out = embd_seq;
 
+                    const uint32_t n_cls_out = hparams.n_cls_out;
+
                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[s][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                             continue;
                         }
-                        embd_seq_out[seq_id].resize(1);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        embd_seq_out[seq_id].resize(n_cls_out);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -912,10 +906,8 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
 
     const llama_batch & batch = batch_allocr.batch;
@@ -977,21 +969,21 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update(false);
 
-    llama_memory_state_ptr kv_state;
+    llama_memory_state_ptr mstate;
 
     while (true) {
-        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-        if (!kv_state) {
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!mstate) {
             return -2;
         }
 
-        switch (kv_state->get_status()) {
+        switch (mstate->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
             case LLAMA_MEMORY_STATUS_NO_UPDATE:
                 {
-                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
 
                     return -2;
                 }
@@ -1031,7 +1023,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     int64_t n_outputs_prev = 0;
 
     do {
-        const auto & ubatch = kv_state->get_ubatch();
+        const auto & ubatch = mstate->get_ubatch();
 
         // count the outputs in this u_batch
         {
@@ -1054,11 +1046,14 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
-            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
+            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
+            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+                pos_min[s] = std::numeric_limits<llama_pos>::max();
+            }
             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                 const auto & seq_id = ubatch.seq_id[i][0];
 
@@ -1073,7 +1068,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
                 LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
 
-                llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+                memory->seq_rm(s, pos_min[s], -1);
             }
 
             switch (status) {
@@ -1167,7 +1162,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    } while (kv_state->next());
+    } while (mstate->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1176,7 +1171,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids = kv_state->out_ids();
+        auto & out_ids = mstate->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
@@ -1844,11 +1839,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-    if (kv_self != nullptr) {
+    if (memory != nullptr) {
         LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
-        kv_self->state_write(io);
+        memory->state_write(io);
     }
 
     return io.n_bytes();
@@ -1935,9 +1928,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     if (memory) {
         LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
 
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_read(io);
+        memory->state_read(io);
     }
 
     return io.n_bytes();
@@ -1947,9 +1938,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
     GGML_UNUSED(seq_id);
 
     if (memory) {
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_write(io, seq_id);
+        memory->state_write(io, seq_id);
     }
 
     return io.n_bytes();
@@ -1959,9 +1948,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
     GGML_UNUSED(seq_id);
 
     if (memory) {
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_read(io, seq_id);
+        memory->state_read(io, seq_id);
    }
 
     return io.n_bytes();
@@ -2066,9 +2053,7 @@ void llama_context::opt_epoch_iter(
     const uint32_t n_batch  = std::min(this->n_batch(),  n_ctx);
     const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-    kv_self->clear();
+    memory->clear();
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
         batch.n_tokens = n_batch;
@@ -2091,8 +2076,8 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
-        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }
@@ -2105,17 +2090,17 @@ void llama_context::opt_epoch_iter(
         uint32_t pos_batch = 0;
 
         do {
-            const auto & ubatch = kv_state->get_ubatch();
+            const auto & ubatch = mstate->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
-            if (!kv_state->apply()) {
+            if (!mstate->apply()) {
                 LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
                 break;
             }
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2150,7 +2135,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
-        } while (kv_state->next());
+        } while (mstate->next());
     }
 }
 
@@ -2311,8 +2296,9 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }
 
+// deprecated
 llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return ctx->get_kv_self();
+    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
 }
 
 // deprecated
@@ -2432,13 +2418,82 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }
 
+//
+// memory
+//
+
+llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+    return ctx->get_memory();
+}
+
+void llama_memory_clear(llama_memory_t mem) {
+    mem->clear();
+}
+
+bool llama_memory_seq_rm(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1) {
+    return mem->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_seq_cp(
+        llama_memory_t mem,
+          llama_seq_id seq_id_src,
+          llama_seq_id seq_id_dst,
+             llama_pos p0,
+             llama_pos p1) {
+    mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_seq_keep(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    mem->seq_keep(seq_id);
+}
+
+void llama_memory_seq_add(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1,
+             llama_pos delta) {
+    mem->seq_add(seq_id, p0, p1, delta);
+}
+
+void llama_memory_seq_div(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1,
+                   int d) {
+    mem->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_seq_pos_min(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    return mem->seq_pos_min(seq_id);
+}
+
+llama_pos llama_memory_seq_pos_max(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    return mem->seq_pos_max(seq_id);
+}
+
+bool llama_memory_can_shift(llama_memory_t mem) {
+    return mem->get_can_shift();
+}
+
 //
 // kv cache
 //
 
 // deprecated
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }
@ -2460,7 +2515,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
// deprecated // deprecated
// note: this is the same as above - will be removed anyway, so it's ok // note: this is the same as above - will be removed anyway, so it's ok
int32_t llama_kv_self_used_cells(const llama_context * ctx) { int32_t llama_kv_self_used_cells(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self(); const auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return 0; return 0;
} }
@ -2480,12 +2535,12 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
} }
void llama_kv_self_clear(llama_context * ctx) { void llama_kv_self_clear(llama_context * ctx) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->clear(); llama_memory_clear(kv);
} }
bool llama_kv_self_seq_rm( bool llama_kv_self_seq_rm(
@ -2493,12 +2548,12 @@ bool llama_kv_self_seq_rm(
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1) { llama_pos p1) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return true; return true;
} }
return kv->seq_rm(seq_id, p0, p1); return llama_memory_seq_rm(kv, seq_id, p0, p1);
} }
void llama_kv_self_seq_cp( void llama_kv_self_seq_cp(
@ -2507,21 +2562,21 @@ void llama_kv_self_seq_cp(
llama_seq_id seq_id_dst, llama_seq_id seq_id_dst,
llama_pos p0, llama_pos p0,
llama_pos p1) { llama_pos p1) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
} }
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_keep(seq_id); llama_memory_seq_keep(kv, seq_id);
} }
void llama_kv_self_seq_add( void llama_kv_self_seq_add(
@ -2530,12 +2585,12 @@ void llama_kv_self_seq_add(
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
llama_pos delta) { llama_pos delta) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_add(seq_id, p0, p1, delta); llama_memory_seq_add(kv, seq_id, p0, p1, delta);
} }
void llama_kv_self_seq_div( void llama_kv_self_seq_div(
@ -2544,30 +2599,30 @@ void llama_kv_self_seq_div(
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d) { int d) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_div(seq_id, p0, p1, d); llama_memory_seq_div(kv, seq_id, p0, p1, d);
} }
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
const auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return -1; return -1;
} }
return kv->seq_pos_min(seq_id); return llama_memory_seq_pos_min(kv, seq_id);
} }
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
const auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return -1; return -1;
} }
return kv->seq_pos_max(seq_id); return llama_memory_seq_pos_max(kv, seq_id);
} }
// deprecated // deprecated
@ -2577,12 +2632,12 @@ void llama_kv_self_defrag(llama_context * ctx) {
} }
bool llama_kv_self_can_shift(const llama_context * ctx) { bool llama_kv_self_can_shift(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return false; return false;
} }
return kv->get_can_shift(); return llama_memory_can_shift(kv);
} }
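The deprecated llama_kv_self_* shims above all reduce to one pattern; as a migration reference (derived directly from the wrappers in this hunk, shown as comments only):

// old (deprecated)                                 new memory API
// llama_kv_self_clear(ctx)                      -> llama_memory_clear   (llama_get_memory(ctx))
// llama_kv_self_seq_rm  (ctx, s, p0, p1)        -> llama_memory_seq_rm  (llama_get_memory(ctx), s, p0, p1)
// llama_kv_self_seq_cp  (ctx, s0, s1, p0, p1)   -> llama_memory_seq_cp  (llama_get_memory(ctx), s0, s1, p0, p1)
// llama_kv_self_seq_keep(ctx, s)                -> llama_memory_seq_keep(llama_get_memory(ctx), s)
// llama_kv_self_seq_add (ctx, s, p0, p1, delta) -> llama_memory_seq_add (llama_get_memory(ctx), s, p0, p1, delta)
// llama_kv_self_seq_div (ctx, s, p0, p1, d)     -> llama_memory_seq_div (llama_get_memory(ctx), s, p0, p1, d)
// llama_kv_self_can_shift(ctx)                  -> llama_memory_can_shift(llama_get_memory(ctx))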
// llama state API // llama state API


@ -13,13 +13,12 @@
#include <vector> #include <vector>
struct llama_model; struct llama_model;
struct llama_kv_cache;
class llama_io_read_i; class llama_io_read_i;
class llama_io_write_i; class llama_io_write_i;
class llama_memory_i; struct llama_memory_i;
class llama_memory_state_i; struct llama_memory_state_i;
struct llama_context { struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs // init scheduler and compute buffers, reserve worst-case graphs
@ -47,8 +46,7 @@ struct llama_context {
uint32_t n_threads() const; uint32_t n_threads() const;
uint32_t n_threads_batch() const; uint32_t n_threads_batch() const;
llama_kv_cache * get_kv_self(); llama_memory_t get_memory() const;
const llama_kv_cache * get_kv_self() const;
// return true if the KV cache was updated // return true if the KV cache was updated
// TODO: remove // TODO: remove


@ -17,7 +17,7 @@ struct ggml_tensor;
struct llama_ubatch; struct llama_ubatch;
struct llama_cparams; struct llama_cparams;
class llama_memory_state_i; struct llama_memory_state_i;
class llama_kv_cache_unified_state; class llama_kv_cache_unified_state;
class llama_kv_cache_unified_iswa_state; class llama_kv_cache_unified_iswa_state;


@ -2,7 +2,7 @@
#include "llama-batch.h" #include "llama-batch.h"
#include "llama-graph.h" #include "llama-graph.h"
#include "llama-kv-cache.h" #include "llama-memory.h"
#include <set> #include <set>
#include <vector> #include <vector>
@ -13,7 +13,7 @@
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example of how to do it // see the implementation of llama_kv_cache_unified_state_i for an example of how to do it
class llama_kv_cache_recurrent : public llama_kv_cache { class llama_kv_cache_recurrent : public llama_memory_i {
public: public:
llama_kv_cache_recurrent( llama_kv_cache_recurrent(
const llama_model & model, const llama_model & model,
@ -29,6 +29,16 @@ public:
// llama_memory_i // llama_memory_i
// //
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
void clear() override; void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
@ -40,20 +50,6 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
bool prepare(const std::vector<llama_ubatch> & ubatches); bool prepare(const std::vector<llama_ubatch> & ubatches);
// find a contiguous slot of kv cells and emplace the ubatch there // find a contiguous slot of kv cells and emplace the ubatch there


@ -11,7 +11,7 @@
// utilizes two instances of llama_kv_cache_unified // utilizes two instances of llama_kv_cache_unified
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
class llama_kv_cache_unified_iswa : public llama_kv_cache { class llama_kv_cache_unified_iswa : public llama_memory_i {
public: public:
llama_kv_cache_unified_iswa( llama_kv_cache_unified_iswa(
const llama_model & model, const llama_model & model,
@ -31,21 +31,6 @@ public:
// llama_memory_i // llama_memory_i
// //
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch( llama_memory_state_ptr init_batch(
const llama_batch & batch, const llama_batch & batch,
uint32_t n_ubatch, uint32_t n_ubatch,
@ -58,6 +43,17 @@ public:
bool get_can_shift() const override; bool get_can_shift() const override;
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
// state write/load // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;


@ -2,8 +2,8 @@
#include "llama-batch.h" #include "llama-batch.h"
#include "llama-graph.h" #include "llama-graph.h"
#include "llama-kv-cache.h"
#include "llama-kv-cells.h" #include "llama-kv-cells.h"
#include "llama-memory.h"
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
@ -17,7 +17,7 @@ struct llama_context;
// llama_kv_cache_unified // llama_kv_cache_unified
// //
class llama_kv_cache_unified : public llama_kv_cache { class llama_kv_cache_unified : public llama_memory_i {
public: public:
static uint32_t get_padding(const llama_cparams & cparams); static uint32_t get_padding(const llama_cparams & cparams);
@ -56,21 +56,6 @@ public:
// llama_memory_i // llama_memory_i
// //
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch( llama_memory_state_ptr init_batch(
const llama_batch & batch, const llama_batch & batch,
uint32_t n_ubatch, uint32_t n_ubatch,
@ -83,6 +68,17 @@ public:
bool get_can_shift() const override; bool get_can_shift() const override;
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
// state write/load // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;


@ -1 +0,0 @@
#include "llama-kv-cache.h"


@ -1,41 +0,0 @@
#pragma once
#include "llama.h"
#include "llama-memory.h"
class llama_io_write_i;
class llama_io_read_i;
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
// TODO: move the init_ interfaces to llama_memory_i
// split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them
// check the llama_memory_state_i::get_status() for the result
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
// prepare for any pending memory updates, such as shifts, defrags, etc.
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
// getters
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); }
//
// state write/read
//
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
};


@ -7,6 +7,9 @@
struct llama_ubatch; struct llama_ubatch;
class llama_io_write_i;
class llama_io_read_i;
struct llama_memory_params { struct llama_memory_params {
// kv cache // kv cache
ggml_type type_k; ggml_type type_k;
@ -16,28 +19,6 @@ struct llama_memory_params {
bool swa_full; bool swa_full;
}; };
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
class llama_memory_i {
public:
virtual ~llama_memory_i() = default;
virtual void clear() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
virtual bool get_can_edit() const = 0;
};
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
enum llama_memory_status { enum llama_memory_status {
LLAMA_MEMORY_STATUS_SUCCESS = 0, LLAMA_MEMORY_STATUS_SUCCESS = 0,
LLAMA_MEMORY_STATUS_NO_UPDATE, LLAMA_MEMORY_STATUS_NO_UPDATE,
@ -58,8 +39,7 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me
// the only method that can mutate the memory and the memory state is llama_memory_state_i::apply() // the only method that can mutate the memory and the memory state is llama_memory_state_i::apply()
// //
// TODO: rename to llama_memory_context_i ? // TODO: rename to llama_memory_context_i ?
class llama_memory_state_i { struct llama_memory_state_i {
public:
virtual ~llama_memory_state_i() = default; virtual ~llama_memory_state_i() = default;
// consume the current ubatch from the state and proceed to the next one // consume the current ubatch from the state and proceed to the next one
@ -81,3 +61,57 @@ public:
}; };
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>; using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
struct llama_memory_i {
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them
// check the llama_memory_state_i::get_status() for the result
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
// prepare for any pending memory updates, such as shifts, defrags, etc.
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
// getters
virtual bool get_can_shift() const = 0;
//
// ops
//
virtual void clear() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
//
// state write/read
//
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
};
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
// TODO: temporary until the llama_kv_cache is removed from the public API
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
};
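Taken together, the split between llama_memory_i and llama_memory_state_i implies the following processing pattern (a sketch; get_status() and next() are referenced above, while batch, n_ubatch and the graph step are assumed from the surrounding decode code):

// sketch of the intended flow inside the decode/opt paths
llama_memory_state_ptr mstate = memory->init_batch(batch, n_ubatch, embd_pooled, logits_all);
if (mstate && mstate->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
    do {
        // the state exposes the current ubatch and is the only object that
        // may mutate the memory while applying it
        // ... build and compute the graph for the current ubatch ...
    } while (mstate->next()); // advance to the next ubatch
}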


@ -402,7 +402,7 @@ struct llama_mmap::impl {
} }
} }
#else #else
throw std::runtime_error("PrefetchVirtualMemory unavailable"); LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n");
#endif #endif
} }
#else #else


@ -292,9 +292,10 @@ namespace GGUFMeta {
template<typename T> template<typename T>
bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) { bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
const int kid = gguf_find_key(meta.get(), key.c_str()); const gguf_context * ctx = meta.get();
const int kid = gguf_find_key(ctx, key.c_str());
if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
if (required) { if (required) {
throw std::runtime_error(format("array key not found in model: %s", key.c_str())); throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
} }
@ -302,28 +303,40 @@ namespace GGUFMeta {
} }
struct GGUFMeta::ArrayInfo arr_info = struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid); GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
switch (arr_info.gt) { switch (arr_info.gt) {
case GGUF_TYPE_UINT32: case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) || case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break; (std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break; case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same<T, std::string>::value)); break;
default: default:
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
} }
result.resize(arr_info.length); if constexpr (std::is_same<T, std::string>::value) {
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); const size_t n_items = gguf_get_arr_n(ctx, kid);
result.clear();
for (size_t i = 0; i < n_items; i++) {
const T value = gguf_get_arr_str(ctx, kid, i);
result.emplace_back(value);
}
} else {
result.resize(arr_info.length);
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
}
return true; return true;
} }
template<typename T, size_t N_MAX> template<typename T, size_t N_MAX>
bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) { bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
const int kid = gguf_find_key(meta.get(), key.c_str()); const gguf_context * ctx = meta.get();
const int kid = gguf_find_key(ctx, key.c_str());
if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
if (required) { if (required) {
throw std::runtime_error(format("array key not found in model: %s", key.c_str())); throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
} }
@ -331,22 +344,32 @@ namespace GGUFMeta {
} }
struct GGUFMeta::ArrayInfo arr_info = struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid); GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
switch (arr_info.gt) { switch (arr_info.gt) {
case GGUF_TYPE_UINT32: case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) || case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break; (std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break; case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same<T, std::string>::value)); break;
default: default:
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
} }
if (arr_info.length > N_MAX) { if (arr_info.length > N_MAX) {
throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX)); throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
} }
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); if constexpr (std::is_same<T, std::string>::value) {
const size_t n_items = gguf_get_arr_n(ctx, kid);
for (size_t i = 0; i < n_items; i++) {
const T value = gguf_get_arr_str(ctx, kid, i);
result[i] = value;
}
} else {
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
}
return true; return true;
} }
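With the string branch in place, a GGUF string array can be read straight into std::vector<std::string>; a sketch of a call site (the real one is the classifier-labels read in llama-model.cpp further below), assuming the usual behavior of returning false when a non-required key is absent:

std::vector<std::string> labels;
if (ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, labels, /*required=*/false)) {
    for (const auto & label : labels) {
        LLAMA_LOG_INFO("%s: label = %s\n", __func__, label.c_str());
    }
}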
@ -356,6 +379,8 @@ namespace GGUFMeta {
return get_arr(llm_kv(kid), result, required); return get_arr(llm_kv(kid), result, required);
} }
template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
template<typename T> template<typename T>
bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
auto it = kv_overrides.find(key); auto it = kv_overrides.find(key);


@ -548,6 +548,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
uint32_t n_vocab = 0; uint32_t n_vocab = 0;
ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
// for classifier models
ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
if (!classifier_labels.empty()) {
hparams.n_cls_out = classifier_labels.size();
}
// arch-specific KVs // arch-specific KVs
switch (arch) { switch (arch) {
case LLM_ARCH_LLAMA: case LLM_ARCH_LLAMA:
@ -691,7 +697,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
switch (hparams.n_layer) { switch (hparams.n_layer) {
case 3: case 3:
@ -4459,6 +4464,15 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
if (!classifier_labels.empty()) {
LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
size_t i = 0;
for (auto label : classifier_labels) {
LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
}
}
} }
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
@ -13702,6 +13716,18 @@ int32_t llama_model_n_swa(const llama_model * model) {
return model->hparams.n_swa; return model->hparams.n_swa;
} }
uint32_t llama_model_n_cls_out(const struct llama_model * model) {
return model->hparams.n_cls_out;
}
const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
if (i < model->classifier_labels.size()) {
return model->classifier_labels[i].c_str();
}
return nullptr;
}
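A possible caller-side use of the two new accessors (sketch; print_classifier_labels is an illustrative helper and model is assumed to be a loaded classifier model):

#include <cstdio>
#include "llama.h"

static void print_classifier_labels(const llama_model * model) {
    const uint32_t n_cls = llama_model_n_cls_out(model);
    for (uint32_t i = 0; i < n_cls; ++i) {
        const char * label = llama_model_cls_label(model, i); // nullptr when i is out of range
        printf("class %u: %s\n", i, label ? label : "(unnamed)");
    }
}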
// deprecated // deprecated
int32_t llama_n_ctx_train(const llama_model * model) { int32_t llama_n_ctx_train(const llama_model * model) {
return llama_model_n_ctx_train(model); return llama_model_n_ctx_train(model);


@ -329,6 +329,9 @@ struct llama_model {
llama_hparams hparams = {}; llama_hparams hparams = {};
llama_vocab vocab; llama_vocab vocab;
// for classifier models
std::vector<std::string> classifier_labels;
struct ggml_tensor * tok_embd = nullptr; struct ggml_tensor * tok_embd = nullptr;
struct ggml_tensor * type_embd = nullptr; struct ggml_tensor * type_embd = nullptr;
struct ggml_tensor * pos_embd = nullptr; struct ggml_tensor * pos_embd = nullptr;


@ -2337,7 +2337,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"}) || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
|| _contains_any(general_arch, {"nomic-bert-moe"}) || _contains_any(general_arch, {"nomic-bert-moe"})
) { ) {
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true); if (token_to_id.count("<mask>") == 0) {
LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
} else {
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
}
} else if (_contains_any(model_name, {"phi-3", "phi3"})) { } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
for (auto id : cache_special_tokens) { for (auto id : cache_special_tokens) {
_set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true); _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);