diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h index 02d5f221c..67851ba6f 100644 --- a/ggml/include/gguf.h +++ b/ggml/include/gguf.h @@ -76,10 +76,16 @@ extern "C" { struct ggml_context ** ctx; }; + // callback to simulate or wrap a FILE pointer - read up to `len` bytes at `offset` into `output` and return the number of bytes read + typedef size_t (*gguf_reader_callback_t)(void * userdata, void * output, uint64_t offset, size_t len); + GGML_API struct gguf_context * gguf_init_empty(void); GGML_API struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params); GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); - //GGML_API struct gguf_context * gguf_init_from_buffer(..); + GGML_API struct gguf_context * gguf_init_from_buffer(const void * data, size_t size, struct gguf_init_params params); + + // max_chunk_read is the maximum number of bytes that the GGUF code will read at once from the callback, a value of 0 means no limit + GGML_API struct gguf_context * gguf_init_from_callback(gguf_reader_callback_t callback, void * userdata, size_t max_chunk_read, uint64_t max_expected_size, struct gguf_init_params params); GGML_API void gguf_free(struct gguf_context * ctx); @@ -87,7 +93,7 @@ extern "C" { GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx); GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); - GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); // padded to gguf_get_alignment if and only if the gguf_context contains at least one tensor GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx); GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index ab3cc9748..5e1986182 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -228,9 +228,18 @@ struct gguf_context { }; struct gguf_reader { - gguf_reader(FILE * file) : file(file) { - // read the remaining bytes once and update on each read - nbytes_remain = file_remain(file); + gguf_reader( + gguf_reader_callback_t callback, + void * userdata, + size_t max_chunk_read, + uint64_t data_offset = 0, + uint64_t nbytes_remain = 0) + : callback(callback), + userdata(userdata), + max_chunk_read(max_chunk_read), + data_offset(data_offset), + nbytes_remain(nbytes_remain) { + GGML_ASSERT(max_chunk_read > 0); } // helper for remaining bytes in a file @@ -257,12 +266,10 @@ struct gguf_reader { template bool read(T & dst) const { const size_t size = sizeof(dst); - if (nbytes_remain < size) { + if (size > nbytes_remain) { return false; } - const size_t nread = fread(&dst, 1, size, file); - nbytes_remain -= nread; - return nread == size; + return read_raw(&dst, size) == size; } template @@ -344,24 +351,71 @@ struct gguf_reader { return false; } dst.resize(static_cast(size)); - const size_t nread = fread(dst.data(), 1, size, file); - nbytes_remain -= nread; - return nread == size; + return read_raw(dst.data(), static_cast(size)) == size; } bool read(void * dst, const size_t size) const { if (size > nbytes_remain) { return false; } - const size_t nread = fread(dst, 1, size, file); - nbytes_remain -= nread; - return nread == size; + return read_raw(dst, size) == size; + } + + uint64_t tell() const { + return data_offset; + } + + bool seek(uint64_t absolute_offset) const { + const uint64_t end_offset = uint64_t(data_offset) + nbytes_remain; + if (absolute_offset > end_offset) { + return false; + } + + data_offset = absolute_offset; + nbytes_remain = end_offset - absolute_offset; + + return true; } private: - FILE * file; + size_t read_raw(void * dst, size_t size) const { + if (callback == nullptr || size == 0) { + return 0; + } - mutable uint64_t nbytes_remain; + uint8_t * data = static_cast(dst); + size_t total_nread = 0; + bool reached_eof = false; + + while (total_nread < size) { + const size_t chunk_size = std::min(max_chunk_read, size - total_nread); + if (data_offset + total_nread < data_offset) { + break; + } + const size_t nread = callback(userdata, static_cast(data + total_nread), data_offset + total_nread, chunk_size); + total_nread += nread; + if (nread != chunk_size) { + reached_eof = true; + break; + } + } + + data_offset += total_nread; + GGML_ASSERT(total_nread <= nbytes_remain); + nbytes_remain -= total_nread; + + if (reached_eof) { + nbytes_remain = 0; + } + + return total_nread; + } + + gguf_reader_callback_t callback = nullptr; + void * userdata = nullptr; + size_t max_chunk_read = 0; + mutable uint64_t data_offset = 0; + mutable uint64_t nbytes_remain = 0; }; struct gguf_context * gguf_init_empty(void) { @@ -394,12 +448,7 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vectorinfo.size()) == n_tensors); // we require the data section to be aligned, so take into account any padding - if (gguf_fseek(file, GGML_PAD(gguf_ftell(file), ctx->alignment), SEEK_SET) != 0) { + if (n_tensors > 0 && !gr.seek(GGML_PAD(gr.tell(), ctx->alignment))) { GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__); gguf_free(ctx); return nullptr; } // store the current file offset - this is where the data section starts - ctx->offset = gguf_ftell(file); + ctx->offset = gr.tell(); // compute the total size of the data section, taking into account the alignment { @@ -844,6 +893,89 @@ struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_para return ctx; } +struct gguf_context * gguf_init_from_callback(gguf_reader_callback_t callback, void * userdata, size_t max_chunk_read, uint64_t max_expected_size, struct gguf_init_params params) { + if (callback == nullptr) { + return nullptr; + } + + const struct gguf_reader gr(callback, userdata, max_chunk_read == 0 ? SIZE_MAX : max_chunk_read, 0, max_expected_size); + return gguf_init_from_reader(gr, params); +} + +struct gguf_file_reader { + FILE * file; + uint64_t offset; +}; + +static size_t gguf_file_reader_callback(void * userdata, void * output, uint64_t offset, size_t len) { + GGML_ASSERT(len > 0); + + gguf_file_reader & reader = *static_cast(userdata); + + if (reader.offset != offset) { + if (offset > INT64_MAX || gguf_fseek(reader.file, static_cast(offset), SEEK_SET) != 0) { + return 0; + } + + reader.offset = offset; + } + + const size_t nread = fread(static_cast(output), 1, len, reader.file); + reader.offset += nread; + return nread; +} + +struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params) { + if (!file) { + return nullptr; + } + + const int64_t cur = gguf_ftell(file); + if (cur < 0) { + return nullptr; + } + + gguf_file_reader reader = { + /*.file = */ file, + /*.offset = */ static_cast(cur), + }; + const struct gguf_reader gr(gguf_file_reader_callback, &reader, SIZE_MAX, reader.offset, gguf_reader::file_remain(file)); + return gguf_init_from_reader(gr, params); +} + +struct gguf_buffer_reader { + const uint8_t * data; + size_t size; +}; + +static size_t gguf_buffer_reader_callback(void * userdata, void * output, uint64_t offset, size_t len) { + GGML_ASSERT(len > 0); + + const gguf_buffer_reader & reader = *static_cast(userdata); + + if (offset > reader.size || len > reader.size - offset) { + return 0; + } + + const size_t data_offset = static_cast(offset); + const size_t nread = std::min(len, reader.size - data_offset); + memcpy(static_cast(output), reader.data + data_offset, nread); + return nread; +} + +struct gguf_context * gguf_init_from_buffer(const void * data, size_t size, struct gguf_init_params params) { + if (data == nullptr || size == 0) { + return nullptr; + } + + gguf_buffer_reader reader = { + /*.data = */ static_cast(data), + /*.size = */ size, + }; + const struct gguf_reader gr(gguf_buffer_reader_callback, &reader, SIZE_MAX, 0, size); + return gguf_init_from_reader(gr, params); +} + struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { FILE * file = ggml_fopen(fname, "rb"); diff --git a/tests/test-gguf.cpp b/tests/test-gguf.cpp index ed3070dc4..1ae468fbd 100644 --- a/tests/test-gguf.cpp +++ b/tests/test-gguf.cpp @@ -162,6 +162,42 @@ static void helper_write(FILE * file, const void * data, const size_t nbytes) { GGML_ASSERT(fwrite(data, 1, nbytes, file) == nbytes); } +static std::vector read_file_to_buffer(FILE * file) { + GGML_ASSERT(file != nullptr); + GGML_ASSERT(fseek(file, 0, SEEK_END) == 0); + + const long size = ftell(file); + GGML_ASSERT(size >= 0); + + rewind(file); + + std::vector data(static_cast(size)); + GGML_ASSERT(fread(data.data(), 1, data.size(), file) == data.size()); + + rewind(file); + return data; +} + +struct callback_reader_data { + const uint8_t * data; + size_t size; +}; + +static size_t read_buffer_callback(void * userdata, void * output, uint64_t offset, size_t len) { + GGML_ASSERT(len > 0); + + const callback_reader_data & reader = *static_cast(userdata); + + if (offset > reader.size || len > reader.size - offset) { + return 0; + } + + const size_t data_offset = static_cast(offset); + const size_t nread = std::min(len, reader.size - data_offset); + memcpy(static_cast(output), reader.data + data_offset, nread); + return nread; +} + static FILE * get_handcrafted_file(const unsigned int seed, const enum handcrafted_file_type hft, const int extra_bytes = 0) { FILE * file = tmpfile(); @@ -1095,10 +1131,29 @@ static bool same_tensor_data(const struct ggml_context * orig, const struct ggml return ok; } -static std::pair test_roundtrip(ggml_backend_dev_t dev, const unsigned int seed, const bool only_meta) { +enum roundtrip_read_mode { + ROUNDTRIP_READ_MODE_FILE, + ROUNDTRIP_READ_MODE_BUFFER, + ROUNDTRIP_READ_MODE_CALLBACK, +}; + +static const char * roundtrip_read_mode_name(const roundtrip_read_mode mode) { + switch (mode) { + case ROUNDTRIP_READ_MODE_FILE: return "file"; + case ROUNDTRIP_READ_MODE_BUFFER: return "buffer"; + case ROUNDTRIP_READ_MODE_CALLBACK: return "callback"; + } + + GGML_ABORT("fatal error"); +} + +static std::pair test_roundtrip( + ggml_backend_dev_t dev, const unsigned int seed, const bool only_meta, + const roundtrip_read_mode read_mode) { ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); - printf("%s: device=%s, backend=%s, only_meta=%s\n", - __func__, ggml_backend_dev_description(dev), ggml_backend_name(backend), only_meta ? "yes" : "no"); + printf("%s: device=%s, backend=%s, only_meta=%s, read_mode=%s\n", + __func__, ggml_backend_dev_description(dev), ggml_backend_name(backend), + only_meta ? "yes" : "no", roundtrip_read_mode_name(read_mode)); int npass = 0; int ntest = 0; @@ -1133,7 +1188,22 @@ static std::pair test_roundtrip(ggml_backend_dev_t dev, const unsigned /*no_alloc =*/ false, /*ctx =*/ only_meta ? nullptr : &ctx_1, }; - struct gguf_context * gguf_ctx_1 = gguf_init_from_file_ptr(file, gguf_params); + struct gguf_context * gguf_ctx_1 = nullptr; + const std::vector data = read_mode == ROUNDTRIP_READ_MODE_FILE + ? std::vector() + : read_file_to_buffer(file); + + if (read_mode == ROUNDTRIP_READ_MODE_BUFFER) { + gguf_ctx_1 = gguf_init_from_buffer(data.data(), data.size(), gguf_params); + } else if (read_mode == ROUNDTRIP_READ_MODE_CALLBACK) { + callback_reader_data reader = { + /*.data = */ data.data(), + /*.size = */ data.size(), + }; + gguf_ctx_1 = gguf_init_from_callback(read_buffer_callback, &reader, 4096, 4ull << 30 /* 4GB */, gguf_params); + } else { + gguf_ctx_1 = gguf_init_from_file_ptr(file, gguf_params); + } printf("%s: same_version: ", __func__); if (gguf_get_version(gguf_ctx_0) == gguf_get_version(gguf_ctx_1)) { @@ -1343,7 +1413,17 @@ int main(int argc, char ** argv) { ggml_backend_dev_t dev = ggml_backend_dev_get(i); for (bool only_meta : {true, false}) { - std::pair result = test_roundtrip(dev, seed, only_meta); + std::pair result = test_roundtrip(dev, seed, only_meta, ROUNDTRIP_READ_MODE_FILE); + npass += result.first; + ntest += result.second; + } + { + std::pair result = test_roundtrip(dev, seed, /*only_meta=*/false, ROUNDTRIP_READ_MODE_BUFFER); + npass += result.first; + ntest += result.second; + } + { + std::pair result = test_roundtrip(dev, seed, /*only_meta=*/false, ROUNDTRIP_READ_MODE_CALLBACK); npass += result.first; ntest += result.second; }