mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-17 04:09:19 +00:00
SYCL: reduce allocation overhead during flash attention (#22732)
* SYCL: reduce allocation overhead during flash attention * tidy up whitespace * add a note about the flag * move ggml_sycl_fattn_* into fattn-buffers.hpp * refactor implementation into fattn-buffers.cpp * move new_fattn_kv_buffers back into ggml-sycl.cpp
This commit is contained in:
parent
fd89556567
commit
e20b83930c
6 changed files with 188 additions and 2 deletions
|
|
@ -737,6 +737,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
|||
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
|
||||
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
|
||||
|
||||
## Compile-time Flags
|
||||
|
||||
Pass these via `CXXFLAGS` or add a one-off `#define` to enable a flag on the spot.
|
||||
|
||||
| Name | Function |
|
||||
|-----------------|----------------------------------------------------------------------------------|
|
||||
| DEBUG_SYCL_POOL | Enable device memory pool logging on teardown. Useful for profiling allocations. |
|
||||
|
||||
## Design Rule
|
||||
|
||||
- Open to all contributors.
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@
|
|||
#include "presets.hpp"
|
||||
#include "type.hpp"
|
||||
#include "sycl_hw.hpp"
|
||||
#include "fattn-buffers.hpp"
|
||||
|
||||
namespace syclexp = sycl::ext::oneapi::experimental;
|
||||
|
||||
|
|
@ -404,12 +405,16 @@ struct ggml_backend_sycl_context {
|
|||
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
|
||||
std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;
|
||||
|
||||
std::unique_ptr<ggml_sycl_fattn_kv_buffers> fattn_bufs[GGML_SYCL_MAX_DEVICES];
|
||||
|
||||
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
|
||||
|
||||
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
|
||||
|
||||
static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);
|
||||
|
||||
static std::unique_ptr<ggml_sycl_fattn_kv_buffers> new_fattn_kv_buffers(queue_ptr qptr, int device);
|
||||
|
||||
ggml_sycl_pool & pool(int device) {
|
||||
if (pools[device] == nullptr) {
|
||||
pools[device] = new_pool_for_device(stream(device,0), device);
|
||||
|
|
@ -421,6 +426,17 @@ struct ggml_backend_sycl_context {
|
|||
return pool(device);
|
||||
}
|
||||
|
||||
ggml_sycl_fattn_kv_buffers & fattn_buffers(int device) {
|
||||
if (fattn_bufs[device] == nullptr) {
|
||||
fattn_bufs[device] = new_fattn_kv_buffers(stream(device, 0), device);
|
||||
}
|
||||
return *fattn_bufs[device];
|
||||
}
|
||||
|
||||
ggml_sycl_fattn_kv_buffers & fattn_buffers() {
|
||||
return fattn_buffers(device);
|
||||
}
|
||||
|
||||
#ifdef GGML_SYCL_GRAPH
|
||||
std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr;
|
||||
#endif
|
||||
|
|
|
|||
56
ggml/src/ggml-sycl/fattn-buffers.cpp
Normal file
56
ggml/src/ggml-sycl/fattn-buffers.cpp
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
sycl::half * ggml_sycl_fattn_kv_buffers::kv_buffer::ensure_half(size_t n_elems) {
|
||||
const size_t need_bytes = n_elems * sizeof(sycl::half);
|
||||
|
||||
if (capacity >= need_bytes) {
|
||||
return ptr;
|
||||
}
|
||||
|
||||
if (ptr) {
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(qptr->wait()));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
|
||||
ptr = nullptr;
|
||||
capacity = 0;
|
||||
}
|
||||
|
||||
size_t cap = 0;
|
||||
while (cap < need_bytes) {
|
||||
cap += CHUNK_SIZE;
|
||||
}
|
||||
|
||||
void * dev_ptr;
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR(dev_ptr = sycl::malloc_device(
|
||||
cap, *qptr)));
|
||||
|
||||
if (!dev_ptr) {
|
||||
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, cap);
|
||||
GGML_ABORT("fattn buffer alloc failed");
|
||||
}
|
||||
|
||||
ptr = static_cast<sycl::half *>(dev_ptr);
|
||||
capacity = cap;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
ggml_sycl_fattn_kv_buffers::kv_buffer::~kv_buffer() {
|
||||
#ifdef DEBUG_SYCL_POOL
|
||||
GGML_LOG_INFO("ggml_sycl_fattn_kv_buffer[%d]: %.2f MiB\n", device, capacity / 1024.0 / 1024.0);
|
||||
#endif
|
||||
if (ptr) {
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
|
||||
}
|
||||
}
|
||||
63
ggml/src/ggml-sycl/fattn-buffers.hpp
Normal file
63
ggml/src/ggml-sycl/fattn-buffers.hpp
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#ifndef GGML_SYCL_FATTN_BUFFERS_HPP
|
||||
#define GGML_SYCL_FATTN_BUFFERS_HPP
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
typedef sycl::queue *queue_ptr;
|
||||
|
||||
struct ggml_sycl_fattn_kv_buffers {
|
||||
// buffers grow in chunks of this size
|
||||
static constexpr size_t CHUNK_SIZE = 16ull << 20; // 16 MiB
|
||||
|
||||
struct kv_buffer {
|
||||
kv_buffer(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
|
||||
~kv_buffer();
|
||||
|
||||
kv_buffer(const kv_buffer &) = delete;
|
||||
kv_buffer & operator=(const kv_buffer &) = delete;
|
||||
|
||||
sycl::half * ensure_half(size_t n_elems);
|
||||
|
||||
private:
|
||||
sycl::half * ptr = nullptr;
|
||||
size_t capacity = 0;
|
||||
queue_ptr qptr = nullptr;
|
||||
[[maybe_unused]] int device = 0;
|
||||
};
|
||||
|
||||
kv_buffer K;
|
||||
kv_buffer V;
|
||||
|
||||
ggml_sycl_fattn_kv_buffers(queue_ptr qptr, int device) : K(qptr, device), V(qptr, device) {}
|
||||
|
||||
ggml_sycl_fattn_kv_buffers(const ggml_sycl_fattn_kv_buffers &) = delete;
|
||||
ggml_sycl_fattn_kv_buffers & operator=(const ggml_sycl_fattn_kv_buffers &) = delete;
|
||||
};
|
||||
|
||||
/**
|
||||
* Imitates `ggml_sycl_pool_alloc` to keep the code calling alloc unchanged.
|
||||
*/
|
||||
struct ggml_sycl_fattn_alloc {
|
||||
ggml_sycl_fattn_kv_buffers::kv_buffer & buf;
|
||||
sycl::half * ptr = nullptr;
|
||||
|
||||
explicit ggml_sycl_fattn_alloc(ggml_sycl_fattn_kv_buffers::kv_buffer & buf_) : buf(buf_) {}
|
||||
|
||||
sycl::half * alloc(size_t n_elems) {
|
||||
ptr = buf.ensure_half(n_elems);
|
||||
return ptr;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
#include "common.hpp"
|
||||
#include "convert.hpp"
|
||||
#include "vecdotq.hpp"
|
||||
#include "fattn-buffers.hpp"
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
|
|
@ -918,12 +919,13 @@ void launch_fattn(
|
|||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
|
||||
ggml_sycl_pool & pool = ctx.pool();
|
||||
ggml_sycl_fattn_kv_buffers & fbuf = ctx.fattn_buffers();
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
const int id = ggml_sycl_get_device();
|
||||
const int nsm = ggml_sycl_info().devices[id].nsm;
|
||||
|
||||
ggml_sycl_pool_alloc<sycl::half> K_f16(pool);
|
||||
ggml_sycl_pool_alloc<sycl::half> V_f16(pool);
|
||||
ggml_sycl_fattn_alloc K_f16(fbuf.K);
|
||||
ggml_sycl_fattn_alloc V_f16(fbuf.V);
|
||||
ggml_sycl_pool_alloc<int> KV_max(pool);
|
||||
ggml_sycl_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_sycl_pool_alloc<sycl::float2> dst_tmp_meta(pool);
|
||||
|
|
|
|||
|
|
@ -1286,6 +1286,23 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
|
|||
explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {}
|
||||
|
||||
~ggml_sycl_pool_leg() {
|
||||
#ifdef DEBUG_SYCL_POOL
|
||||
int n_cached = 0;
|
||||
size_t bytes_cached = 0;
|
||||
for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
|
||||
if (buffer_pool[i].ptr != nullptr) {
|
||||
++n_cached;
|
||||
bytes_cached += buffer_pool[i].size;
|
||||
}
|
||||
}
|
||||
GGML_LOG_INFO("%s: %d buffers, cached = %.2f MiB\n", __func__,
|
||||
n_cached, bytes_cached / 1024.0 / 1024.0);
|
||||
const auto slots = format_slots_in_alloc_order();
|
||||
if (!slots.empty()) {
|
||||
GGML_LOG_INFO("%s: slots MiB: %s\n", __func__, slots.c_str());
|
||||
}
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
|
||||
ggml_sycl_buffer & b = buffer_pool[i];
|
||||
if (b.ptr != nullptr) {
|
||||
|
|
@ -1296,6 +1313,26 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
|
|||
GGML_ASSERT(pool_size == 0);
|
||||
}
|
||||
|
||||
#ifdef DEBUG_SYCL_POOL
|
||||
std::string format_slots_in_alloc_order() const {
|
||||
std::string line;
|
||||
char buf[32];
|
||||
bool first = true;
|
||||
for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
|
||||
if (buffer_pool[i].ptr == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (!first) {
|
||||
line += '/';
|
||||
}
|
||||
first = false;
|
||||
snprintf(buf, sizeof(buf), "%.2f", buffer_pool[i].size / 1024.0 / 1024.0);
|
||||
line += buf;
|
||||
}
|
||||
return line;
|
||||
}
|
||||
#endif
|
||||
|
||||
void * alloc(size_t size, size_t * actual_size) override {
|
||||
#ifdef DEBUG_sycl_MALLOC
|
||||
int nnz = 0;
|
||||
|
|
@ -1459,6 +1496,10 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(q
|
|||
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
|
||||
}
|
||||
|
||||
std::unique_ptr<ggml_sycl_fattn_kv_buffers> ggml_backend_sycl_context::new_fattn_kv_buffers(queue_ptr qptr, int device) {
|
||||
return std::unique_ptr<ggml_sycl_fattn_kv_buffers>(new ggml_sycl_fattn_kv_buffers(qptr, device));
|
||||
}
|
||||
|
||||
// TBD pool with virtual memory management
|
||||
// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue