From e2bdd6d7aa7cca1af2cce92c2de2b87bf2726a78 Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Fri, 1 May 2026 05:33:28 -0300 Subject: [PATCH] sd: sync to master-591-331cfa5 (#2155) * sd: sync to master-585-44cca3d * sd: sync to master-587-b8bdffc * sd: sync to master-591-331cfa5 --- Makefile | 4 +- otherarch/sdcpp/common/common.cpp | 438 ++++++- otherarch/sdcpp/common/common.h | 28 +- otherarch/sdcpp/convert.cpp | 138 +++ otherarch/sdcpp/denoiser.hpp | 19 +- otherarch/sdcpp/ggml_extend.hpp | 18 +- otherarch/sdcpp/main.cpp | 20 +- otherarch/sdcpp/model.cpp | 911 ++------------ otherarch/sdcpp/model.h | 132 +- otherarch/sdcpp/model_io/binary_io.h | 57 + otherarch/sdcpp/model_io/gguf_io.cpp | 123 ++ otherarch/sdcpp/model_io/gguf_io.h | 17 + .../gguf_reader_ext.h} | 6 +- otherarch/sdcpp/model_io/pickle_io.cpp | 1064 +++++++++++++++++ otherarch/sdcpp/model_io/pickle_io.h | 21 + otherarch/sdcpp/model_io/safetensors_io.cpp | 316 +++++ otherarch/sdcpp/model_io/safetensors_io.h | 17 + otherarch/sdcpp/model_io/tensor_storage.h | 132 ++ otherarch/sdcpp/model_io/torch_legacy_io.cpp | 252 ++++ otherarch/sdcpp/model_io/torch_legacy_io.h | 13 + otherarch/sdcpp/model_io/torch_zip_io.cpp | 140 +++ otherarch/sdcpp/model_io/torch_zip_io.h | 14 + otherarch/sdcpp/sdtype_adapter.cpp | 25 +- otherarch/sdcpp/stable-diffusion.cpp | 396 +++++- otherarch/sdcpp/stable-diffusion.h | 30 + otherarch/sdcpp/tensor.hpp | 265 +++- otherarch/sdcpp/tokenizers/clip_tokenizer.cpp | 2 +- .../sdcpp/tokenizers/mistral_tokenizer.cpp | 2 +- .../sdcpp/tokenizers/qwen2_tokenizer.cpp | 2 +- otherarch/sdcpp/upscaler.cpp | 190 ++- otherarch/sdcpp/upscaler.h | 31 + otherarch/sdcpp/util.cpp | 12 +- otherarch/sdcpp/vae.hpp | 3 +- 33 files changed, 3703 insertions(+), 1135 deletions(-) create mode 100644 otherarch/sdcpp/convert.cpp create mode 100644 otherarch/sdcpp/model_io/binary_io.h create mode 100644 otherarch/sdcpp/model_io/gguf_io.cpp create mode 100644 otherarch/sdcpp/model_io/gguf_io.h rename otherarch/sdcpp/{gguf_reader.hpp => model_io/gguf_reader_ext.h} (98%) create mode 100644 otherarch/sdcpp/model_io/pickle_io.cpp create mode 100644 otherarch/sdcpp/model_io/pickle_io.h create mode 100644 otherarch/sdcpp/model_io/safetensors_io.cpp create mode 100644 otherarch/sdcpp/model_io/safetensors_io.h create mode 100644 otherarch/sdcpp/model_io/tensor_storage.h create mode 100644 otherarch/sdcpp/model_io/torch_legacy_io.cpp create mode 100644 otherarch/sdcpp/model_io/torch_legacy_io.h create mode 100644 otherarch/sdcpp/model_io/torch_zip_io.cpp create mode 100644 otherarch/sdcpp/model_io/torch_zip_io.h create mode 100644 otherarch/sdcpp/upscaler.h diff --git a/Makefile b/Makefile index e7438c57b..9befa1bc8 100644 --- a/Makefile +++ b/Makefile @@ -679,7 +679,7 @@ llama-impl.o: src/llama-impl.cpp src/llama-impl.h budget.o: common/reasoning-budget.cpp common/reasoning-budget.h $(CXX) $(CXXFLAGS) -c $< -o $@ -SDCPP_COMMON_BASENAMES := stable-diffusion.h stable-diffusion.cpp sample-cache.h sample-cache.cpp util.cpp upscaler.cpp model.cpp name_conversion.cpp tokenizers/bpe_tokenizer.cpp tokenizers/bpe_tokenizer.h tokenizers/clip_tokenizer.cpp tokenizers/clip_tokenizer.h tokenizers/mistral_tokenizer.cpp tokenizers/mistral_tokenizer.h tokenizers/qwen2_tokenizer.cpp tokenizers/qwen2_tokenizer.h tokenizers/t5_unigram_tokenizer.cpp tokenizers/t5_unigram_tokenizer.h tokenizers/tokenizer.cpp tokenizers/tokenizer.h tokenizers/tokenize_util.cpp tokenizers/tokenize_util.h thirdparty/zip.c +SDCPP_COMMON_BASENAMES := stable-diffusion.h 
stable-diffusion.cpp sample-cache.h sample-cache.cpp util.cpp upscaler.h upscaler.cpp model.cpp name_conversion.cpp model_io/gguf_io.cpp model_io/gguf_io.h model_io/gguf_reader_ext.h model_io/pickle_io.cpp model_io/safetensors_io.cpp model_io/safetensors_io.h model_io/tensor_storage.h model_io/torch_legacy_io.cpp model_io/torch_zip_io.cpp tokenizers/bpe_tokenizer.cpp tokenizers/bpe_tokenizer.h tokenizers/clip_tokenizer.cpp tokenizers/clip_tokenizer.h tokenizers/mistral_tokenizer.cpp tokenizers/mistral_tokenizer.h tokenizers/qwen2_tokenizer.cpp tokenizers/qwen2_tokenizer.h tokenizers/t5_unigram_tokenizer.cpp tokenizers/t5_unigram_tokenizer.h tokenizers/tokenizer.cpp tokenizers/tokenizer.h tokenizers/tokenize_util.cpp tokenizers/tokenize_util.h thirdparty/zip.c SDCPP_COMMON_SOURCES := $(foreach f,$(SDCPP_COMMON_BASENAMES),otherarch/sdcpp/$(f)) SDCPP_FLAGS := -I./vendor/nlohmann @@ -736,7 +736,7 @@ mainvk: tools/completion/completion.cpp common/arg.cpp common/speculative.cpp co $(CXX) $(CXXFLAGS) -DGGML_USE_VULKAN -DSD_USE_VULKAN $(filter-out %.h,$^) -o $@ $(LDFLAGS) fitparams: tools/fit-params/fit-params.cpp common/arg.cpp common/speculative.cpp common/ngram-cache.cpp common/ngram-map.cpp common/ngram-mod.cpp common/chat.cpp common/preset.cpp common/download.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend.o ggml-backend-meta.o ggml-backend-reg_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o ggml-repack.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib $(CXX) $(CXXFLAGS) -DGGML_USE_VULKAN -DSD_USE_VULKAN $(filter-out %.h,$^) -o $@ $(LDFLAGS) -sdmain: $(SDCPP_COMMON_SOURCES) otherarch/sdcpp/main.cpp otherarch/sdcpp/image_metadata.cpp otherarch/sdcpp/common/log.cpp otherarch/sdcpp/common/media_io.cpp otherarch/sdcpp/common/common.cpp otherarch/sdcpp/version.cpp otherarch/sdcpp/tokenizers/vocab/vocab.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend.o ggml-backend-meta.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) +sdmain: $(SDCPP_COMMON_SOURCES) otherarch/sdcpp/main.cpp otherarch/sdcpp/image_metadata.cpp otherarch/sdcpp/convert.cpp otherarch/sdcpp/common/log.cpp otherarch/sdcpp/common/media_io.cpp otherarch/sdcpp/common/common.cpp otherarch/sdcpp/version.cpp otherarch/sdcpp/tokenizers/vocab/vocab.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend.o ggml-backend-meta.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) $(CXX) $(CXXFLAGS) $(SDCPP_FLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) whispermain: otherarch/whispercpp/main.cpp otherarch/whispercpp/whisper.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend.o ggml-backend-meta.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) diff --git a/otherarch/sdcpp/common/common.cpp b/otherarch/sdcpp/common/common.cpp index 0235c53de..1a5399b82 100644 --- a/otherarch/sdcpp/common/common.cpp +++ b/otherarch/sdcpp/common/common.cpp @@ -107,47 +107,60 @@ static bool is_absolute_path(const std::string& p) { std::string ArgOptions::wrap_text(const std::string& text, size_t width, size_t indent) { std::ostringstream oss; - size_t line_len = 0; size_t pos = 0; + size_t line_len = 0; while (pos < 
text.size()) { - // Preserve manual newlines if (text[pos] == '\n') { oss << '\n' << std::string(indent, ' '); - line_len = indent; + line_len = 0; ++pos; continue; } - // Add the character - oss << text[pos]; - ++line_len; - ++pos; + if (std::isspace(static_cast<unsigned char>(text[pos]))) { + ++pos; + continue; + } - // If the current line exceeds width, try to break at the last space - if (line_len >= width) { - std::string current = oss.str(); - size_t back = current.size(); + size_t word_start = pos; + while (pos < text.size() && + text[pos] != '\n' && + !std::isspace(static_cast<unsigned char>(text[pos]))) { + ++pos; + } - // Find the last space (for a clean break) - while (back > 0 && current[back - 1] != ' ' && current[back - 1] != '\n') - --back; - - // If found a space to break on - if (back > 0 && current[back - 1] != '\n') { - std::string before = current.substr(0, back - 1); - std::string after = current.substr(back); - oss.str(""); - oss.clear(); - oss << before << "\n" - << std::string(indent, ' ') << after; - } else { - // If no space found, just break at width - oss << "\n" - << std::string(indent, ' '); + std::string word = text.substr(word_start, pos - word_start); + while (!word.empty()) { + size_t separator_len = line_len == 0 ? 0 : 1; + if (line_len + separator_len + word.size() <= width) { + if (separator_len > 0) { + oss << ' '; + ++line_len; + } + oss << word; + line_len += word.size(); + word.clear(); + continue; + } + + if (line_len > 0) { + oss << '\n' + << std::string(indent, ' '); + line_len = 0; + continue; + } + + size_t chunk_len = std::min(width, word.size()); + oss << word.substr(0, chunk_len); + line_len = chunk_len; + word.erase(0, chunk_len); + if (!word.empty()) { + oss << '\n' + << std::string(indent, ' '); + line_len = 0; + } - line_len = indent; } } @@ -351,7 +364,10 @@ ArgOptions SDContextParams::get_options() { "--lora-model-dir", "lora model directory", &lora_model_dir}, - + {"", + "--hires-upscalers-dir", + "highres fix upscaler model directory", + &hires_upscalers_dir}, {"", "--tensor-type-rules", "weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")", @@ -649,6 +665,7 @@ std::string SDContextParams::to_string() const { << " wtype: " << sd_type_name(wtype) << ",\n" << " tensor_type_rules: \"" << tensor_type_rules << "\",\n" << " lora_model_dir: \"" << lora_model_dir << "\",\n" + << " hires_upscalers_dir: \"" << hires_upscalers_dir << "\",\n" << " photo_maker_path: \"" << photo_maker_path << "\",\n" << " rng_type: " << sd_rng_type_name(rng_type) << ",\n" << " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n" @@ -777,6 +794,12 @@ ArgOptions SDGenerationParams::get_options() { "--pm-id-embed-path", "path to PHOTOMAKER v2 id embed", &pm_id_embed_path}, + {"", + "--hires-upscaler", + "highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent (nearest-exact), " + "Latent (antialiased), Latent (bicubic), Latent (bicubic antialiased), or a model name " + "under --hires-upscalers-dir (default: Latent)", + &hires_upscaler}, }; options.int_options = { @@ -826,6 +849,22 @@ ArgOptions SDGenerationParams::get_options() { "--upscale-tile-size", "tile size for ESRGAN upscaling (default: 128)", &upscale_tile_size}, + {"", + "--hires-width", + "highres fix target width, 0 to use --hires-scale (default: 0)", + &hires_width}, + {"", + "--hires-height", + "highres fix target height, 0 to use --hires-scale (default: 0)", + &hires_height}, + {"", + "--hires-steps", + "highres fix second pass sample steps, 0 to reuse --steps (default: 0)", 
&hires_steps}, + {"", + "--hires-upscale-tile-size", + "highres fix upscaler tile size, reserved for model-backed upscalers (default: 128)", + &hires_upscale_tile_size}, }; options.float_options = { @@ -913,6 +952,14 @@ ArgOptions SDGenerationParams::get_options() { "--vae-tile-overlap", "tile overlap for vae tiling, in fraction of tile size (default: 0.5)", &vae_tiling_params.target_overlap}, + {"", + "--hires-scale", + "highres fix scale when target size is not set (default: 2.0)", + &hires_scale}, + {"", + "--hires-denoising-strength", + "highres fix second pass denoising strength (default: 0.7)", + &hires_denoising_strength}, }; options.bool_options = { @@ -936,6 +983,11 @@ ArgOptions SDGenerationParams::get_options() { "process vae in tiles to reduce memory usage", true, &vae_tiling_params.enabled}, + {"", + "--hires", + "enable highres fix", + true, + &hires_enabled}, }; auto on_seed_arg = [&](int argc, const char** argv, int index) { @@ -1424,6 +1476,37 @@ static bool parse_lora_json_field(const json& parent, return true; } +static bool resolve_model_file_from_dir(const std::string& model_name, + const std::string& model_dir, + const std::vector<std::string>& valid_ext, + const char* label, + std::string& resolved_path) { + if (model_dir.empty()) { + LOG_ERROR("%s directory is empty", label); + return false; + } + if (model_name.empty() || + model_name.find('/') != std::string::npos || + model_name.find('\\') != std::string::npos || + fs::path(model_name).has_root_path() || + fs::path(model_name).has_extension()) { + LOG_ERROR("%s must be a model name without path or extension: %s", label, model_name.c_str()); + return false; + } + + fs::path model_dir_path = model_dir; + for (const auto& ext : valid_ext) { + fs::path try_path = model_dir_path / (model_name + ext); + if (fs::exists(try_path) && fs::is_regular_file(try_path)) { + resolved_path = try_path.lexically_normal().string(); + return true; + } + } + + LOG_ERROR("can not find %s %s in %s", label, model_name.c_str(), model_dir_path.lexically_normal().string().c_str()); + return false; +} + bool SDGenerationParams::from_json_str( const std::string& json_str, const std::function<std::string(const std::string&)>& lora_path_resolver) { @@ -1487,6 +1570,34 @@ bool SDGenerationParams::from_json_str( load_if_exists("increase_ref_index", increase_ref_index); load_if_exists("embed_image_metadata", embed_image_metadata); + if (j.contains("hires") && j["hires"].is_object()) { + const json& hires_json = j["hires"]; + if (hires_json.contains("enabled") && hires_json["enabled"].is_boolean()) { + hires_enabled = hires_json["enabled"]; + } + if (hires_json.contains("upscaler") && hires_json["upscaler"].is_string()) { + hires_upscaler = hires_json["upscaler"]; + } + if (hires_json.contains("scale") && hires_json["scale"].is_number()) { + hires_scale = hires_json["scale"]; + } + if (hires_json.contains("target_width") && hires_json["target_width"].is_number_integer()) { + hires_width = hires_json["target_width"]; + } + if (hires_json.contains("target_height") && hires_json["target_height"].is_number_integer()) { + hires_height = hires_json["target_height"]; + } + if (hires_json.contains("steps") && hires_json["steps"].is_number_integer()) { + hires_steps = hires_json["steps"]; + } + if (hires_json.contains("denoising_strength") && hires_json["denoising_strength"].is_number()) { + hires_denoising_strength = hires_json["denoising_strength"]; + } + if (hires_json.contains("upscale_tile_size") && hires_json["upscale_tile_size"].is_number_integer()) { + hires_upscale_tile_size = 
hires_json["upscale_tile_size"]; + } + } + auto parse_sample_params_json = [&](const json& sample_json, sd_sample_params_t& target_params, std::vector& target_skip_layers, @@ -1800,7 +1911,7 @@ bool SDGenerationParams::initialize_cache_params() { return true; } -bool SDGenerationParams::resolve(const std::string& lora_model_dir, bool strict) { +bool SDGenerationParams::resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict) { if (high_noise_sample_params.sample_steps <= 0) { high_noise_sample_params.sample_steps = -1; } @@ -1819,6 +1930,27 @@ bool SDGenerationParams::resolve(const std::string& lora_model_dir, bool strict) sample_params.sample_steps = std::clamp(sample_params.sample_steps, 1, 100); } + hires_upscaler_model_path.clear(); + if (hires_enabled) { + if (hires_upscaler.empty()) { + hires_upscaler = "Latent"; + } + resolved_hires_upscaler = str_to_sd_hires_upscaler(hires_upscaler.c_str()); + if (resolved_hires_upscaler == SD_HIRES_UPSCALER_NONE) { + hires_enabled = false; + } else if (resolved_hires_upscaler == SD_HIRES_UPSCALER_COUNT) { + static const std::vector valid_ext = {".gguf", ".safetensors", ".pt", ".pth"}; + if (!resolve_model_file_from_dir(hires_upscaler, + hires_upscalers_dir, + valid_ext, + "hires upscaler", + hires_upscaler_model_path)) { + return false; + } + resolved_hires_upscaler = SD_HIRES_UPSCALER_MODEL; + } + } + prompt_with_lora = prompt; if (!lora_model_dir.empty()) { extract_and_remove_lora(lora_model_dir); @@ -1883,6 +2015,29 @@ bool SDGenerationParams::validate(SDMode mode) { return false; } + if (hires_enabled) { + if (hires_width < 0 || hires_height < 0) { + LOG_ERROR("error: hires target width and height must be >= 0"); + return false; + } + if (hires_scale <= 0.f && hires_width <= 0 && hires_height <= 0) { + LOG_ERROR("error: hires scale must be positive when target size is not set"); + return false; + } + if (hires_steps < 0) { + LOG_ERROR("error: hires steps must be >= 0"); + return false; + } + if (hires_denoising_strength <= 0.f || hires_denoising_strength > 1.f) { + LOG_ERROR("error: hires denoising strength must be in (0.0, 1.0]"); + return false; + } + if (hires_upscale_tile_size < 1) { + LOG_ERROR("error: hires upscale tile size must be positive"); + return false; + } + } + if (mode == UPSCALE) { if (init_image_path.length() == 0) { LOG_ERROR("error: upscale mode needs an init image (--init-img)\n"); @@ -1893,8 +2048,11 @@ bool SDGenerationParams::validate(SDMode mode) { return true; } -bool SDGenerationParams::resolve_and_validate(SDMode mode, const std::string& lora_model_dir, bool strict) { - if (!resolve(lora_model_dir, strict)) { +bool SDGenerationParams::resolve_and_validate(SDMode mode, + const std::string& lora_model_dir, + const std::string& hires_upscalers_dir, + bool strict) { + if (!resolve(lora_model_dir, hires_upscalers_dir, strict)) { return false; } if (!validate(mode)) { @@ -1965,6 +2123,16 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() { params.pm_params = pm_params; params.vae_tiling_params = vae_tiling_params; params.cache = cache_params; + + params.hires.enabled = hires_enabled; + params.hires.upscaler = resolved_hires_upscaler; + params.hires.model_path = hires_upscaler_model_path.empty() ? 
nullptr : hires_upscaler_model_path.c_str(); + params.hires.scale = hires_scale; + params.hires.target_width = hires_width; + params.hires.target_height = hires_height; + params.hires.steps = hires_steps; + params.hires.denoising_strength = hires_denoising_strength; + params.hires.upscale_tile_size = hires_upscale_tile_size; return params; } @@ -2089,6 +2257,15 @@ std::string SDGenerationParams::to_string() const { << " seed: " << seed << ",\n" << " upscale_repeats: " << upscale_repeats << ",\n" << " upscale_tile_size: " << upscale_tile_size << ",\n" + << " hires: { enabled: " << (hires_enabled ? "true" : "false") + << ", upscaler: \"" << hires_upscaler << "\"" + << ", model_path: \"" << hires_upscaler_model_path << "\"" + << ", scale: " << hires_scale + << ", target_width: " << hires_width + << ", target_height: " << hires_height + << ", steps: " << hires_steps + << ", denoising_strength: " << hires_denoising_strength + << ", upscale_tile_size: " << hires_upscale_tile_size << " },\n" << " vae_tiling_params: { " << vae_tiling_params.enabled << ", " << vae_tiling_params.tile_size_x << ", " @@ -2104,7 +2281,192 @@ std::string version_string() { return std::string("stable-diffusion.cpp version ") + sd_version() + ", commit " + sd_commit(); } -std::string get_image_params(const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) { +static std::string safe_json_string(const char* value) { + return value ? value : ""; +} + +static void set_json_basename_if_not_empty(json& target, const char* key, const std::string& path) { + if (!path.empty()) { + target[key] = sd_basename(path); + } +} + +static json build_sampling_metadata_json(const sd_sample_params_t& sample_params, + const std::vector<int>& skip_layers, + const std::vector<float>* custom_sigmas = nullptr) { + json sampling = { + {"steps", sample_params.sample_steps}, + {"eta", sample_params.eta}, + {"shifted_timestep", sample_params.shifted_timestep}, + {"flow_shift", sample_params.flow_shift}, + {"guidance", + { + {"txt_cfg", sample_params.guidance.txt_cfg}, + {"img_cfg", sample_params.guidance.img_cfg}, + {"distilled_guidance", sample_params.guidance.distilled_guidance}, + {"slg", + { + {"scale", sample_params.guidance.slg.scale}, + {"layers", skip_layers}, + {"start", sample_params.guidance.slg.layer_start}, + {"end", sample_params.guidance.slg.layer_end}, + }}, + }}, + }; + if (sample_params.sample_method != SAMPLE_METHOD_COUNT) { + sampling["method"] = safe_json_string(sd_sample_method_name(sample_params.sample_method)); + } + if (sample_params.scheduler != SCHEDULER_COUNT) { + sampling["scheduler"] = safe_json_string(sd_scheduler_name(sample_params.scheduler)); + } + if (custom_sigmas != nullptr) { + sampling["custom_sigmas"] = *custom_sigmas; + } + return sampling; +} + +std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params, + const SDGenerationParams& gen_params, + int64_t seed, + SDMode mode) { + json root; + root["schema"] = "sdcpp.image.params/v1"; + root["mode"] = mode == VID_GEN ? 
"vid_gen" : "img_gen"; + root["generator"] = { + {"name", "stable-diffusion.cpp"}, + {"version", safe_json_string(sd_version())}, + {"commit", safe_json_string(sd_commit())}, + }; + root["seed"] = seed; + root["width"] = gen_params.get_resolved_width(); + root["height"] = gen_params.get_resolved_height(); + + root["prompt"] = { + {"positive", gen_params.prompt}, + {"negative", gen_params.negative_prompt}, + }; + root["sampling"] = build_sampling_metadata_json(gen_params.sample_params, + gen_params.skip_layers, + &gen_params.custom_sigmas); + + json models; + set_json_basename_if_not_empty(models, "model", ctx_params.model_path); + set_json_basename_if_not_empty(models, "clip_l", ctx_params.clip_l_path); + set_json_basename_if_not_empty(models, "clip_g", ctx_params.clip_g_path); + set_json_basename_if_not_empty(models, "clip_vision", ctx_params.clip_vision_path); + set_json_basename_if_not_empty(models, "t5xxl", ctx_params.t5xxl_path); + set_json_basename_if_not_empty(models, "llm", ctx_params.llm_path); + set_json_basename_if_not_empty(models, "llm_vision", ctx_params.llm_vision_path); + set_json_basename_if_not_empty(models, "diffusion_model", ctx_params.diffusion_model_path); + set_json_basename_if_not_empty(models, "high_noise_diffusion_model", ctx_params.high_noise_diffusion_model_path); + set_json_basename_if_not_empty(models, "vae", ctx_params.vae_path); + set_json_basename_if_not_empty(models, "taesd", ctx_params.taesd_path); + set_json_basename_if_not_empty(models, "control_net", ctx_params.control_net_path); + root["models"] = std::move(models); + + root["clip_skip"] = gen_params.clip_skip; + root["strength"] = gen_params.strength; + root["control_strength"] = gen_params.control_strength; + root["auto_resize_ref_image"] = gen_params.auto_resize_ref_image; + root["increase_ref_index"] = gen_params.increase_ref_index; + if (mode == VID_GEN) { + root["video"] = { + {"frame_count", gen_params.video_frames}, + {"fps", gen_params.fps}, + }; + root["moe_boundary"] = gen_params.moe_boundary; + root["vace_strength"] = gen_params.vace_strength; + root["high_noise_sampling"] = build_sampling_metadata_json(gen_params.high_noise_sample_params, + gen_params.high_noise_skip_layers); + } + + root["rng"] = safe_json_string(sd_rng_type_name(ctx_params.rng_type)); + if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) { + root["sampler_rng"] = safe_json_string(sd_rng_type_name(ctx_params.sampler_rng_type)); + } + + json loras = json::array(); + for (const auto& entry : gen_params.lora_map) { + loras.push_back({ + {"name", sd_basename(entry.first)}, + {"multiplier", entry.second}, + {"is_high_noise", false}, + }); + } + for (const auto& entry : gen_params.high_noise_lora_map) { + loras.push_back({ + {"name", sd_basename(entry.first)}, + {"multiplier", entry.second}, + {"is_high_noise", true}, + }); + } + if (!loras.empty()) { + root["loras"] = std::move(loras); + } + + if (gen_params.hires_enabled) { + root["hires"] = { + {"enabled", gen_params.hires_enabled}, + {"upscaler", gen_params.hires_upscaler}, + {"model", gen_params.hires_upscaler_model_path.empty() ? 
"" : sd_basename(gen_params.hires_upscaler_model_path)}, + {"scale", gen_params.hires_scale}, + {"target_width", gen_params.hires_width}, + {"target_height", gen_params.hires_height}, + {"steps", gen_params.hires_steps}, + {"denoising_strength", gen_params.hires_denoising_strength}, + {"upscale_tile_size", gen_params.hires_upscale_tile_size}, + }; + } + + if (gen_params.cache_params.mode != SD_CACHE_DISABLED) { + root["cache"] = { + {"requested_mode", gen_params.cache_mode}, + {"requested_option", gen_params.cache_option}, + {"mode", gen_params.cache_params.mode}, + {"scm_mask", gen_params.scm_mask}, + {"scm_policy_dynamic", gen_params.scm_policy_dynamic}, + {"reuse_threshold", gen_params.cache_params.reuse_threshold}, + {"start_percent", gen_params.cache_params.start_percent}, + {"end_percent", gen_params.cache_params.end_percent}, + {"error_decay_rate", gen_params.cache_params.error_decay_rate}, + {"use_relative_threshold", gen_params.cache_params.use_relative_threshold}, + {"reset_error_on_compute", gen_params.cache_params.reset_error_on_compute}, + {"Fn_compute_blocks", gen_params.cache_params.Fn_compute_blocks}, + {"Bn_compute_blocks", gen_params.cache_params.Bn_compute_blocks}, + {"residual_diff_threshold", gen_params.cache_params.residual_diff_threshold}, + {"max_warmup_steps", gen_params.cache_params.max_warmup_steps}, + {"max_cached_steps", gen_params.cache_params.max_cached_steps}, + {"max_continuous_cached_steps", gen_params.cache_params.max_continuous_cached_steps}, + {"taylorseer_n_derivatives", gen_params.cache_params.taylorseer_n_derivatives}, + {"taylorseer_skip_interval", gen_params.cache_params.taylorseer_skip_interval}, + {"spectrum_w", gen_params.cache_params.spectrum_w}, + {"spectrum_m", gen_params.cache_params.spectrum_m}, + {"spectrum_lam", gen_params.cache_params.spectrum_lam}, + {"spectrum_window_size", gen_params.cache_params.spectrum_window_size}, + {"spectrum_flex_window", gen_params.cache_params.spectrum_flex_window}, + {"spectrum_warmup_steps", gen_params.cache_params.spectrum_warmup_steps}, + {"spectrum_stop_percent", gen_params.cache_params.spectrum_stop_percent}, + }; + } + + if (gen_params.vae_tiling_params.enabled) { + root["vae_tiling"] = { + {"enabled", gen_params.vae_tiling_params.enabled}, + {"tile_size_x", gen_params.vae_tiling_params.tile_size_x}, + {"tile_size_y", gen_params.vae_tiling_params.tile_size_y}, + {"target_overlap", gen_params.vae_tiling_params.target_overlap}, + {"rel_size_x", gen_params.vae_tiling_params.rel_size_x}, + {"rel_size_y", gen_params.vae_tiling_params.rel_size_y}, + }; + } + + return root.dump(); +} + +std::string get_image_params(const SDContextParams& ctx_params, + const SDGenerationParams& gen_params, + int64_t seed, + SDMode mode) { std::string parameter_string; if (gen_params.prompt_with_lora.size() != 0) { parameter_string += gen_params.prompt_with_lora + "\n"; @@ -2117,7 +2479,7 @@ std::string get_image_params(const SDContextParams& ctx_params, const SDGenerati parameter_string += "Steps: " + std::to_string(gen_params.sample_params.sample_steps) + ", "; parameter_string += "CFG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", "; if (gen_params.sample_params.guidance.slg.scale != 0 && gen_params.skip_layers.size() != 0) { - parameter_string += "SLG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", "; + parameter_string += "SLG scale: " + std::to_string(gen_params.sample_params.guidance.slg.scale) + ", "; parameter_string += "Skip layers: ["; for (const auto& layer : 
gen_params.skip_layers) { parameter_string += std::to_string(layer) + ", "; @@ -2162,6 +2524,14 @@ std::string get_image_params(const SDContextParams& ctx_params, const SDGenerati if (gen_params.clip_skip != -1) { parameter_string += "Clip skip: " + std::to_string(gen_params.clip_skip) + ", "; } + if (gen_params.hires_enabled) { + parameter_string += "Hires upscale: " + gen_params.hires_upscaler + ", "; + parameter_string += "Hires scale: " + std::to_string(gen_params.hires_scale) + ", "; + parameter_string += "Hires resize: " + std::to_string(gen_params.hires_width) + "x" + std::to_string(gen_params.hires_height) + ", "; + parameter_string += "Hires steps: " + std::to_string(gen_params.hires_steps) + ", "; + parameter_string += "Denoising strength: " + std::to_string(gen_params.hires_denoising_strength) + ", "; + } parameter_string += "Version: stable-diffusion.cpp"; + parameter_string += ", SDCPP: " + build_sdcpp_image_metadata_json(ctx_params, gen_params, seed, mode); return parameter_string; } diff --git a/otherarch/sdcpp/common/common.h b/otherarch/sdcpp/common/common.h index 5afe89b34..c4498c352 100644 --- a/otherarch/sdcpp/common/common.h +++ b/otherarch/sdcpp/common/common.h @@ -101,6 +101,7 @@ struct SDContextParams { sd_type_t wtype = SD_TYPE_COUNT; std::string tensor_type_rules; std::string lora_model_dir = "."; + std::string hires_upscalers_dir; std::map embedding_map; std::vector embedding_vec @@ -190,12 +191,23 @@ struct SDGenerationParams { int upscale_repeats = 1; int upscale_tile_size = 128; + bool hires_enabled = false; + std::string hires_upscaler = "Latent"; + std::string hires_upscaler_model_path; + float hires_scale = 2.f; + int hires_width = 0; + int hires_height = 0; + int hires_steps = 0; + float hires_denoising_strength = 0.7f; + int hires_upscale_tile_size = 128; + std::map<std::string, float> lora_map; std::map<std::string, float> high_noise_lora_map; // Derived and normalized fields. std::string prompt_with_lora; // for metadata record only std::vector lora_vec; + sd_hires_upscaler_t resolved_hires_upscaler; // Owned execution payload. 
SDImageOwner init_image; @@ -225,15 +237,25 @@ void set_width_and_height_if_unset(int w, int h); int get_resolved_width() const; int get_resolved_height() const; - bool resolve(const std::string& lora_model_dir, bool strict = false); + bool resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict = false); bool validate(SDMode mode); - bool resolve_and_validate(SDMode mode, const std::string& lora_model_dir, bool strict = false); + bool resolve_and_validate(SDMode mode, + const std::string& lora_model_dir, + const std::string& hires_upscalers_dir, + bool strict = false); sd_img_gen_params_t to_sd_img_gen_params_t(); sd_vid_gen_params_t to_sd_vid_gen_params_t(); std::string to_string() const; }; std::string version_string(); -std::string get_image_params(const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed); +std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params, + const SDGenerationParams& gen_params, + int64_t seed, + SDMode mode = IMG_GEN); +std::string get_image_params(const SDContextParams& ctx_params, + const SDGenerationParams& gen_params, + int64_t seed, + SDMode mode = IMG_GEN); #endif // __EXAMPLES_COMMON_COMMON_H__ diff --git a/otherarch/sdcpp/convert.cpp b/otherarch/sdcpp/convert.cpp new file mode 100644 index 000000000..7cae8df0f --- /dev/null +++ b/otherarch/sdcpp/convert.cpp @@ -0,0 +1,138 @@ +#include +#include +#include +#include + +#include "model.h" +#include "model_io/gguf_io.h" +#include "model_io/safetensors_io.h" +#include "util.h" + +#include "ggml-cpu.h" + +static ggml_type get_export_tensor_type(ModelLoader& model_loader, + const TensorStorage& tensor_storage, + ggml_type type, + const TensorTypeRules& tensor_type_rules) { + const std::string& name = tensor_storage.name; + ggml_type tensor_type = tensor_storage.type; + ggml_type dst_type = type; + + for (const auto& tensor_type_rule : tensor_type_rules) { + std::regex pattern(tensor_type_rule.first); + if (std::regex_search(name, pattern)) { + dst_type = tensor_type_rule.second; + break; + } + } + + if (model_loader.tensor_should_be_converted(tensor_storage, dst_type)) { + tensor_type = dst_type; + } + + return tensor_type; +} + +static bool load_tensors_for_export(ModelLoader& model_loader, + ggml_context* ggml_ctx, + ggml_type type, + const TensorTypeRules& tensor_type_rules, + std::vector<TensorWriteInfo>& tensors) { + std::mutex tensor_mutex; + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { + const std::string& name = tensor_storage.name; + ggml_type tensor_type = get_export_tensor_type(model_loader, tensor_storage, type, tensor_type_rules); + + std::lock_guard<std::mutex> lock(tensor_mutex); + ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne); + if (tensor == nullptr) { + LOG_ERROR("ggml_new_tensor failed"); + return false; + } + ggml_set_name(tensor, name.c_str()); + + if (!tensor->data) { + GGML_ASSERT(ggml_nelements(tensor) == 0); + // Avoid crashing writers by setting a dummy pointer for zero-sized tensors. 
+ LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str()); + tensor->data = ggml_get_mem_buffer(ggml_ctx); + } + + TensorWriteInfo write_info; + write_info.tensor = tensor; + write_info.n_dims = tensor_storage.n_dims; + for (int i = 0; i < tensor_storage.n_dims; ++i) { + write_info.ne[i] = tensor_storage.ne[i]; + } + + *dst_tensor = tensor; + tensors.push_back(std::move(write_info)); + + return true; + }; + + bool success = model_loader.load_tensors(on_new_tensor_cb); + LOG_INFO("load tensors done"); + return success; +} + +bool convert(const char* input_path, + const char* vae_path, + const char* output_path, + sd_type_t output_type, + const char* tensor_type_rules, + bool convert_name) { + ModelLoader model_loader; + + if (!model_loader.init_from_file(input_path)) { + LOG_ERROR("init model loader from file failed: '%s'", input_path); + return false; + } + + if (vae_path != nullptr && strlen(vae_path) > 0) { + if (!model_loader.init_from_file(vae_path, "vae.")) { + LOG_ERROR("init model loader from file failed: '%s'", vae_path); + return false; + } + } + if (convert_name) { + model_loader.convert_tensors_name(); + } + + ggml_type type = (ggml_type)output_type; + bool output_is_safetensors = ends_with(output_path, ".safetensors"); + TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules); + + auto backend = ggml_backend_cpu_init(); + size_t mem_size = 1 * 1024 * 1024; // for padding + mem_size += model_loader.get_tensor_storage_map().size() * ggml_tensor_overhead(); + mem_size += model_loader.get_params_mem_size(backend, type); + LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f); + ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false}); + + if (ggml_ctx == nullptr) { + LOG_ERROR("ggml_init failed for converter"); + ggml_backend_free(backend); + return false; + } + + std::vector tensors; + bool success = load_tensors_for_export(model_loader, ggml_ctx, type, type_rules, tensors); + ggml_backend_free(backend); + + std::string error; + if (success) { + if (output_is_safetensors) { + success = write_safetensors_file(output_path, tensors, &error); + } else { + success = write_gguf_file(output_path, tensors, &error); + } + } + + if (!success && !error.empty()) { + LOG_ERROR("%s", error.c_str()); + } + + ggml_free(ggml_ctx); + return success; +} diff --git a/otherarch/sdcpp/denoiser.hpp b/otherarch/sdcpp/denoiser.hpp index 14b6d3beb..a6e81d597 100644 --- a/otherarch/sdcpp/denoiser.hpp +++ b/otherarch/sdcpp/denoiser.hpp @@ -1523,12 +1523,10 @@ static sd::Tensor sample_ddim_trailing(denoise_cb_t model, const std::vector& sigmas, std::shared_ptr rng, float eta) { - int steps = static_cast(sigmas.size()) - 1; for (int i = 0; i < steps; i++) { - - float sigma = sigmas[i]; - float sigma_to = sigmas[i + 1]; + float sigma = sigmas[i]; + float sigma_to = sigmas[i + 1]; auto model_output_opt = model(x, sigma, i + 1); if (model_output_opt.empty()) { @@ -1551,12 +1549,11 @@ static sd::Tensor sample_ddim_trailing(denoise_cb_t model, float std_dev_t = eta * std::sqrt(variance); x = pred_original_sample + - std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2))/ alpha_prod_t_prev) * model_output; + std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) / alpha_prod_t_prev) * model_output; if (eta > 0) { - x+= std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor::randn_like(x, rng); + x += std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor::randn_like(x, rng); } - } return x; } @@ -1584,8 +1581,10 @@ static sd::Tensor 
sample_tcd(denoise_cb_t model, auto get_timestep_from_sigma = [&](float s) -> int { auto it = std::lower_bound(compvis_sigmas.begin(), compvis_sigmas.end(), s); - if (it == compvis_sigmas.begin()) return 0; - if (it == compvis_sigmas.end()) return TIMESTEPS - 1; + if (it == compvis_sigmas.begin()) + return 0; + if (it == compvis_sigmas.end()) + return TIMESTEPS - 1; int idx_high = static_cast<int>(std::distance(compvis_sigmas.begin(), it)); int idx_low = idx_high - 1; if (std::abs(compvis_sigmas[idx_high] - s) < std::abs(compvis_sigmas[idx_low] - s)) { @@ -1596,7 +1595,6 @@ static sd::Tensor sample_tcd(denoise_cb_t model, int steps = static_cast<int>(sigmas.size()) - 1; for (int i = 0; i < steps; i++) { - float sigma_to = sigmas[i + 1]; int prev_timestep = get_timestep_from_sigma(sigma_to); int timestep_s = (int)floor((1 - eta) * prev_timestep); @@ -1626,7 +1624,6 @@ static sd::Tensor sample_tcd(denoise_cb_t model, x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x + std::sqrt(1.0f / alpha_prod_t_prev - 1.0f / alpha_prod_s) * sd::Tensor::randn_like(x, rng); } - } return x; } diff --git a/otherarch/sdcpp/ggml_extend.hpp b/otherarch/sdcpp/ggml_extend.hpp index 7810dec7e..b559f58bb 100644 --- a/otherarch/sdcpp/ggml_extend.hpp +++ b/otherarch/sdcpp/ggml_extend.hpp @@ -2758,16 +2758,16 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward( bool is_conv, WeightAdapter::ForwardParams::conv2d_params_t conv_params, float scale) { - GGML_ASSERT((w1 != NULL || (w1a != NULL && w1b != NULL))); - GGML_ASSERT((w2 != NULL || (w2a != NULL && w2b != NULL))); + GGML_ASSERT((w1 != nullptr || (w1a != nullptr && w1b != nullptr))); + GGML_ASSERT((w2 != nullptr || (w2a != nullptr && w2b != nullptr))); - int uq = (w1 != NULL) ? (int)w1->ne[0] : (int)w1a->ne[0]; - int up = (w1 != NULL) ? (int)w1->ne[1] : (int)w1b->ne[1]; + int uq = (w1 != nullptr) ? (int)w1->ne[0] : (int)w1a->ne[0]; + int up = (w1 != nullptr) ? (int)w1->ne[1] : (int)w1b->ne[1]; int q_actual = is_conv ? (int)h->ne[2] : (int)h->ne[0]; int vq = q_actual / uq; - int vp = (w2 != NULL) ? (is_conv ? (int)w2->ne[3] : (int)w2->ne[1]) + int vp = (w2 != nullptr) ? (is_conv ? (int)w2->ne[3] : (int)w2->ne[1]) : (int)w2a->ne[1]; GGML_ASSERT(q_actual == (uq * vq) && "Input dimension mismatch for LoKR split"); @@ -2803,7 +2803,7 @@ #endif ggml_tensor* h_split = ggml_reshape_3d(ctx, h, vq, uq * merge_batch_uq, batch / merge_batch_uq); - if (w2 != NULL) { + if (w2 != nullptr) { hb = ggml_mul_mat(ctx, w2, h_split); } else { hb = ggml_mul_mat(ctx, w2b, ggml_mul_mat(ctx, w2a, h_split)); @@ -2816,7 +2816,7 @@ hb_t = ggml_reshape_3d(ctx, hb_t, uq, vp * merge_batch_vp, batch / merge_batch_vp); ggml_tensor* hc_t; - if (w1 != NULL) { + if (w1 != nullptr) { hc_t = ggml_mul_mat(ctx, w1, hb_t); } else { hc_t = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_t)); @@ -2834,7 +2834,7 @@ // 1. 
Reshape input: [W, H, vq*uq, batch] -> [W, H, vq, uq * batch] ggml_tensor* h_split = ggml_reshape_4d(ctx, h, h->ne[0], h->ne[1], vq, uq * batch); - if (w2 != NULL) { + if (w2 != nullptr) { hb = ggml_ext_conv_2d(ctx, h_split, w2, nullptr, conv_params.s0, conv_params.s1, @@ -2902,7 +2902,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward( ggml_tensor* hb_merged = ggml_reshape_2d(ctx, hb, w_out * h_out * vp, uq * batch); ggml_tensor* hc_t; ggml_tensor* hb_merged_t = ggml_cont(ctx, ggml_transpose(ctx, hb_merged)); - if (w1 != NULL) { + if (w1 != nullptr) { // Would be great to be able to transpose w1 instead to avoid transposing both hb and hc hc_t = ggml_mul_mat(ctx, w1, hb_merged_t); } else { diff --git a/otherarch/sdcpp/main.cpp b/otherarch/sdcpp/main.cpp index a5b0037b6..27513f475 100644 --- a/otherarch/sdcpp/main.cpp +++ b/otherarch/sdcpp/main.cpp @@ -278,7 +278,9 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP bool valid = cli_params.resolve_and_validate(); if (valid && cli_params.mode != METADATA) { valid = ctx_params.resolve_and_validate(cli_params.mode) && - gen_params.resolve_and_validate(cli_params.mode, ctx_params.lora_model_dir); + gen_params.resolve_and_validate(cli_params.mode, + ctx_params.lora_model_dir, + ctx_params.hires_upscalers_dir); } if (!valid) { @@ -431,10 +433,11 @@ bool save_results(const SDCliParams& cli_params, if (!img.data) return false; - std::string params = gen_params.embed_image_metadata - ? get_image_params(ctx_params, gen_params, gen_params.seed + idx) - : ""; - const bool ok = write_image_to_file(path.string(), img.data, img.width, img.height, img.channel, params, 90); + const int64_t metadata_seed = cli_params.mode == VID_GEN ? gen_params.seed : gen_params.seed + idx; + std::string params = gen_params.embed_image_metadata + ? get_image_params(ctx_params, gen_params, metadata_seed, cli_params.mode) + : ""; + const bool ok = write_image_to_file(path.string(), img.data, img.width, img.height, img.channel, params, 90); LOG_INFO("save result image %d to '%s' (%s)", idx, path.string().c_str(), ok ? 
"success" : "failure"); return ok; }; @@ -688,6 +691,13 @@ int main(int argc, const char* argv[]) { vae_decode_only = false; } + if (gen_params.hires_enabled && + (gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_MODEL || + gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_LANCZOS || + gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_NEAREST)) { + vae_decode_only = false; + } + sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview); SDImageVec results; diff --git a/otherarch/sdcpp/model.cpp b/otherarch/sdcpp/model.cpp index 62a8191c2..d1bb0c90d 100644 --- a/otherarch/sdcpp/model.cpp +++ b/otherarch/sdcpp/model.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -12,8 +13,11 @@ #include #include -#include "gguf_reader.hpp" #include "model.h" +#include "model_io/gguf_io.h" +#include "model_io/safetensors_io.h" +#include "model_io/torch_legacy_io.h" +#include "model_io/torch_zip_io.h" #include "stable-diffusion.h" #include "util.h" @@ -21,6 +25,7 @@ #include "ggml-backend.h" #include "ggml-cpu.h" #include "ggml.h" +#include "zip.h" #include "name_conversion.h" #include "stable-diffusion.h" @@ -37,40 +42,6 @@ #include "ggml-opencl.h" #endif -#define ST_HEADER_SIZE_LEN 8 - -uint64_t read_u64(uint8_t* buffer) { - // little endian - uint64_t value = 0; - value |= static_cast(buffer[7]) << 56; - value |= static_cast(buffer[6]) << 48; - value |= static_cast(buffer[5]) << 40; - value |= static_cast(buffer[4]) << 32; - value |= static_cast(buffer[3]) << 24; - value |= static_cast(buffer[2]) << 16; - value |= static_cast(buffer[1]) << 8; - value |= static_cast(buffer[0]); - return value; -} - -int32_t read_int(uint8_t* buffer) { - // little endian - int value = 0; - value |= buffer[3] << 24; - value |= buffer[2] << 16; - value |= buffer[1] << 8; - value |= buffer[0]; - return value; -} - -uint16_t read_short(uint8_t* buffer) { - // little endian - uint16_t value = 0; - value |= buffer[1] << 8; - value |= buffer[0]; - return value; -} - /*================================================= Preprocess ==================================================*/ const char* unused_tensors[] = { @@ -110,7 +81,7 @@ const char* unused_tensors[] = { "first_stage_model.bn.", }; -bool is_unused_tensor(std::string name) { +bool is_unused_tensor(const std::string& name) { for (size_t i = 0; i < sizeof(unused_tensors) / sizeof(const char*); i++) { if (starts_with(name, unused_tensors[i])) { return true; @@ -264,80 +235,8 @@ void ModelLoader::add_tensor_storage(const TensorStorage& tensor_storage) { tensor_storage_map[tensor_storage.name] = tensor_storage; } -bool is_zip_file(const std::string& file_path) { - zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); - if (zip == nullptr) { - return false; - } - zip_close(zip); - return true; -} - -bool is_gguf_file(const std::string& file_path) { - std::ifstream file(sd_get_u8path(file_path), std::ios::binary); - if (!file.is_open()) { - return false; - } - - char magic[4]; - - file.read(magic, sizeof(magic)); - if (!file) { - return false; - } - for (uint32_t i = 0; i < sizeof(magic); i++) { - if (magic[i] != GGUF_MAGIC[i]) { - return false; - } - } - - return true; -} - -bool is_safetensors_file(const std::string& file_path) { - std::ifstream file(sd_get_u8path(file_path), std::ios::binary); - if (!file.is_open()) { - return false; - } - - // get file size - file.seekg(0, file.end); - size_t file_size_ = file.tellg(); - file.seekg(0, file.beg); - - // read header size - 
if (file_size_ <= ST_HEADER_SIZE_LEN) { - return false; - } - - uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; - file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); - if (!file) { - return false; - } - - size_t header_size_ = read_u64(header_size_buf); - if (header_size_ >= file_size_ || header_size_ <= 2) { - return false; - } - - // read header - std::vector<char> header_buf; - header_buf.resize(header_size_ + 1); - header_buf[header_size_] = '\0'; - file.read(header_buf.data(), header_size_); - if (!file) { - return false; - } - try { - nlohmann::json header_ = nlohmann::json::parse(header_buf.data()); - } catch (const std::exception&) { - return false; - } - return true; -} - bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) { + return [&](const std::string& file_path) { // kcpp u8 file path if (is_directory(file_path)) { LOG_INFO("load %s using diffusers format", file_path.c_str()); return init_from_diffusers_file(file_path, prefix); @@ -347,9 +246,12 @@ } else if (is_safetensors_file(file_path)) { LOG_INFO("load %s using safetensors format", file_path.c_str()); return init_from_safetensors_file(file_path, prefix); - } else if (is_zip_file(file_path)) { - LOG_INFO("load %s using checkpoint format", file_path.c_str()); - return init_from_ckpt_file(file_path, prefix); + } else if (is_torch_zip_file(file_path)) { + LOG_INFO("load %s using torch zip format", file_path.c_str()); + return init_from_torch_zip_file(file_path, prefix); + } else if (init_from_torch_legacy_file(file_path, prefix)) { + LOG_INFO("load %s using torch legacy format", file_path.c_str()); + return true; } else { if (file_exists(file_path)) { LOG_WARN("unknown format %s", file_path.c_str()); @@ -358,6 +260,7 @@ } return false; } + }(sd_get_u8path(file_path)); // kcpp u8 file path } void ModelLoader::convert_tensors_name() { @@ -389,249 +292,138 @@ bool ModelLoader::init_from_file_and_convert_name(const std::string& file_path, bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::string& prefix) { LOG_DEBUG("init from '%s'", file_path.c_str()); + + std::vector<TensorStorage> tensor_storages; + std::string error; + if (!read_gguf_file(file_path, tensor_storages, &error)) { + LOG_ERROR("%s", error.c_str()); + return false; + } + file_paths_.push_back(file_path); size_t file_index = file_paths_.size() - 1; - gguf_context* ctx_gguf_ = nullptr; - ggml_context* ctx_meta_ = nullptr; + for (auto& tensor_storage : tensor_storages) { + // LOG_DEBUG("%s", tensor_storage.name.c_str()); - ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_}); - if (!ctx_gguf_) { - LOG_ERROR("failed to open '%s' with gguf_init_from_file. 
Try to open it with GGUFReader.", file_path.c_str()); - GGUFReader gguf_reader; - if (!gguf_reader.load(file_path)) { - LOG_ERROR("failed to open '%s' with GGUFReader.", file_path.c_str()); - return false; + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; } - - size_t data_offset = gguf_reader.data_offset(); - for (const auto& gguf_tensor_info : gguf_reader.tensors()) { - std::string name = gguf_tensor_info.name; - if (!starts_with(name, prefix)) { - name = prefix + name; - } - - TensorStorage tensor_storage( - name, - gguf_tensor_info.type, - gguf_tensor_info.shape.data(), - static_cast<int>(gguf_tensor_info.shape.size()), - file_index, - data_offset + gguf_tensor_info.offset); - - // LOG_DEBUG("%s %s", name.c_str(), tensor_storage.to_string().c_str()); - - add_tensor_storage(tensor_storage); - } - - return true; - } - - int n_tensors = static_cast<int>(gguf_get_n_tensors(ctx_gguf_)); - - size_t total_size = 0; - size_t data_offset = gguf_get_data_offset(ctx_gguf_); - for (int i = 0; i < n_tensors; i++) { - std::string name = gguf_get_tensor_name(ctx_gguf_, i); - ggml_tensor* dummy = ggml_get_tensor(ctx_meta_, name.c_str()); - size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf_, i); - - // LOG_DEBUG("%s", name.c_str()); - - if (!starts_with(name, prefix)) { - name = prefix + name; - } - - TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset); - - GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes()); + tensor_storage.file_index = file_index; add_tensor_storage(tensor_storage); } - gguf_free(ctx_gguf_); - ggml_free(ctx_meta_); - return true; } /*================================================= SafeTensorsModelLoader ==================================================*/ -ggml_type str_to_ggml_type(const std::string& dtype) { - ggml_type ttype = GGML_TYPE_COUNT; - if (dtype == "F16") { - ttype = GGML_TYPE_F16; - } else if (dtype == "BF16") { - ttype = GGML_TYPE_BF16; - } else if (dtype == "F32") { - ttype = GGML_TYPE_F32; - } else if (dtype == "F64") { - ttype = GGML_TYPE_F32; - } else if (dtype == "F8_E4M3") { - ttype = GGML_TYPE_F16; - } else if (dtype == "F8_E5M2") { - ttype = GGML_TYPE_F16; - } else if (dtype == "I64") { - ttype = GGML_TYPE_I32; - } - return ttype; -} - -// https://huggingface.co/docs/safetensors/index bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) { LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str()); + + std::vector<TensorStorage> tensor_storages; + std::string error; + if (!read_safetensors_file(file_path, tensor_storages, &error)) { + LOG_ERROR("%s", error.c_str()); + return false; + } + file_paths_.push_back(file_path); size_t file_index = file_paths_.size() - 1; - std::ifstream file(sd_get_u8path(file_path), std::ios::binary); - if (!file.is_open()) { - LOG_ERROR("failed to open '%s'", file_path.c_str()); - file_paths_.pop_back(); - return false; - } - // get file size - file.seekg(0, file.end); - size_t file_size_ = file.tellg(); - file.seekg(0, file.beg); - - // read header size - if (file_size_ <= ST_HEADER_SIZE_LEN) { - LOG_ERROR("invalid safetensor file '%s'", file_path.c_str()); - file_paths_.pop_back(); - return false; - } - - uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; - file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); - if (!file) { - LOG_ERROR("read safetensors header size failed: '%s'", file_path.c_str()); - return false; - } - - size_t header_size_ = read_u64(header_size_buf); 
- if (header_size_ >= file_size_) { - LOG_ERROR("invalid safetensor file '%s'", file_path.c_str()); - file_paths_.pop_back(); - return false; - } - - // read header - std::vector<char> header_buf; - header_buf.resize(header_size_ + 1); - header_buf[header_size_] = '\0'; - file.read(header_buf.data(), header_size_); - if (!file) { - LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str()); - file_paths_.pop_back(); - return false; - } - - nlohmann::json header_; - try { - header_ = nlohmann::json::parse(header_buf.data()); - } catch (const std::exception&) { - LOG_ERROR("parsing safetensors header failed", file_path.c_str()); - file_paths_.pop_back(); - return false; - } - - for (auto& item : header_.items()) { - std::string name = item.key(); - nlohmann::json tensor_info = item.value(); - // LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str()); - - if (name == "__metadata__") { + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { continue; } - if (is_unused_tensor(name)) { - continue; + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; } + tensor_storage.file_index = file_index; - std::string dtype = tensor_info["dtype"]; - nlohmann::json shape = tensor_info["shape"]; - - if (dtype == "U8") { - continue; - } - - size_t begin = tensor_info["data_offsets"][0].get<size_t>(); - size_t end = tensor_info["data_offsets"][1].get<size_t>(); - - ggml_type type = str_to_ggml_type(dtype); - if (type == GGML_TYPE_COUNT) { - LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str()); - return false; - } - - if (shape.size() > SD_MAX_DIMS) { - LOG_ERROR("invalid tensor '%s'", name.c_str()); - return false; - } - - int n_dims = (int)shape.size(); - int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; - for (int i = 0; i < n_dims; i++) { - ne[i] = shape[i].get<int64_t>(); - } - - if (n_dims == 5) { - n_dims = 4; - ne[0] = ne[0] * ne[1]; - ne[1] = ne[2]; - ne[2] = ne[3]; - ne[3] = ne[4]; - } - - // ggml_n_dims returns 1 for scalars - if (n_dims == 0) { - n_dims = 1; - } - - if (!starts_with(name, prefix)) { - name = prefix + name; - } - - name = kcpp_fix_wrong_img_tensor_name(name); - - TensorStorage tensor_storage(name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin); - tensor_storage.reverse_ne(); - - size_t tensor_data_size = end - begin; - - bool tensor_size_ok; - if (dtype == "F8_E4M3") { - tensor_storage.is_f8_e4m3 = true; - // f8 -> f16 - tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2); - } else if (dtype == "F8_E5M2") { - tensor_storage.is_f8_e5m2 = true; - // f8 -> f16 - tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2); - } else if (dtype == "F64") { - tensor_storage.is_f64 = true; - // f64 -> f32 - tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size); - } else if (dtype == "I64") { - tensor_storage.is_i64 = true; - // i64 -> i32 - tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size); - } else { - tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size); - } - if (!tensor_size_ok) { - LOG_ERROR("size mismatch for tensor '%s' (%s)\n", name.c_str(), dtype.c_str()); - return false; - } + tensor_storage.name = kcpp_fix_wrong_img_tensor_name(tensor_storage.name); add_tensor_storage(tensor_storage); - // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str()); + // LOG_DEBUG("%s", tensor_storage.to_string().c_str()); } return true; } +/*================================================= 
TorchLegacyModelLoader ==================================================*/ + +bool ModelLoader::init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix) { + LOG_DEBUG("init from torch legacy '%s'", file_path.c_str()); + + std::vector<TensorStorage> tensor_storages; + std::string error; + if (!read_torch_legacy_file(file_path, tensor_storages, &error)) { + if ((!error.empty()) && (ends_with(file_path, ".pt") || ends_with(file_path, ".pth"))) { + LOG_WARN("%s", error.c_str()); + } + return false; + } + + file_paths_.push_back(file_path); + size_t file_index = file_paths_.size() - 1; + + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { + continue; + } + + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; + } + tensor_storage.file_index = file_index; + + add_tensor_storage(tensor_storage); + } + + return true; +} + +/*================================================= TorchZipModelLoader ==================================================*/ + +bool ModelLoader::init_from_torch_zip_file(const std::string& file_path, const std::string& prefix) { + LOG_DEBUG("init from '%s'", file_path.c_str()); + + std::vector<TensorStorage> tensor_storages; + std::string error; + if (!read_torch_zip_file(file_path, tensor_storages, &error)) { + LOG_ERROR("%s", error.c_str()); + return false; + } + + file_paths_.push_back(file_path); + size_t file_index = file_paths_.size() - 1; + + for (auto& tensor_storage : tensor_storages) { + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; + } + tensor_storage.file_index = file_index; + + add_tensor_storage(tensor_storage); + + // LOG_DEBUG("%s", tensor_storage.to_string().c_str()); + } + + return true; +} + +bool ModelLoader::has_diffusion_model_tensors() +{ + for (auto& [name, tensor_storage] : tensor_storage_map) { + if (tensor_storage.name.find("model.diffusion_model.") != std::string::npos) { + return true; + } + } + return false; +} + /*================================================= DiffusersModelLoader ==================================================*/ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) { @@ -658,377 +450,6 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s return true; } -/*================================================= CkptModelLoader ==================================================*/ - -// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100 -// 0: \x80 PROTO 2 -// 2: } EMPTY_DICT -// 3: q BINPUT 0 -// 5: ( MARK -// 6: X BINUNICODE 'epoch' -// 16: q BINPUT 1 -// 18: K BININT1 6 -// 20: X BINUNICODE 'global_step' -// 36: q BINPUT 2 -// 38: J BININT 470000 -// 43: X BINUNICODE 'pytorch-lightning_version' -// 73: q BINPUT 3 -// 75: X BINUNICODE '1.4.2' -// 85: q BINPUT 4 -// 87: X BINUNICODE 'state_dict' -// 102: q BINPUT 5 -// 104: } EMPTY_DICT -// 105: q BINPUT 6 -// 107: ( MARK -// 108: X BINUNICODE 'betas' -// 118: q BINPUT 7 -// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2' -// 153: q BINPUT 8 -// 155: ( MARK -// 156: ( MARK -// 157: X BINUNICODE 'storage' -// 169: q BINPUT 9 -// 171: c GLOBAL 'torch FloatStorage' -// 191: q BINPUT 10 -// 193: X BINUNICODE '0' -// 199: q BINPUT 11 -// 201: X BINUNICODE 'cpu' -// 209: q BINPUT 12 -// 211: M BININT2 1000 -// 214: t TUPLE (MARK at 156) -// 215: q BINPUT 13 -// 217: Q BINPERSID -// 218: K BININT1 0 -// 220: M BININT2 1000 -// 
............................... -// 3201: q BINPUT 250 -// 3203: R REDUCE -// 3204: q BINPUT 251 -// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight' -// 3264: q BINPUT 252 -// 3266: h BINGET 8 -// 3268: ( MARK -// 3269: ( MARK -// 3270: h BINGET 9 -// 3272: h BINGET 10 -// 3274: X BINUNICODE '30' -// 3281: q BINPUT 253 -// 3283: h BINGET 12 -// 3285: J BININT 102400 -// 3290: t TUPLE (MARK at 3269) -// 3291: q BINPUT 254 -// 3293: Q BINPERSID -// 3294: K BININT1 0 -// 3296: ( MARK -// 3297: M BININT2 320 -// 3300: M BININT2 320 -// 3303: K BININT1 1 -// 3305: K BININT1 1 -// 3307: t TUPLE (MARK at 3296) -// 3308: q BINPUT 255 -// 3310: ( MARK -// 3311: M BININT2 320 -// 3314: K BININT1 1 -// 3316: K BININT1 1 -// 3318: K BININT1 1 -// 3320: t TUPLE (MARK at 3310) -// 3321: r LONG_BINPUT 256 -// 3326: \x89 NEWFALSE -// 3327: h BINGET 16 -// 3329: ) EMPTY_TUPLE -// 3330: R REDUCE -// 3331: r LONG_BINPUT 257 -// 3336: t TUPLE (MARK at 3268) -// 3337: r LONG_BINPUT 258 -// 3342: R REDUCE -// 3343: r LONG_BINPUT 259 -// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias' -// 3404: r LONG_BINPUT 260 -// 3409: h BINGET 8 -// 3411: ( MARK -// 3412: ( MARK -// 3413: h BINGET 9 -// 3415: h BINGET 10 -// 3417: X BINUNICODE '31' - -struct PickleTensorReader { - enum ReadPhase { - READ_NAME, - READ_DATA, - CHECK_SIZE, - READ_DIMENS - }; - ReadPhase phase = READ_NAME; - size_t entry_size = 0; - int32_t nelements = 0; - - TensorStorage tensor_storage; - - static ggml_type global_type; // all pickle_tensors data type - static bool read_global_type; - - bool read_int_value(uint32_t value) { - if (phase == CHECK_SIZE) { - if (entry_size == value * ggml_type_size(tensor_storage.type)) { - nelements = value; - phase = READ_DIMENS; - return true; - } else { - phase = READ_NAME; - } - } else if (phase == READ_DIMENS) { - if (tensor_storage.n_dims + 1 > SD_MAX_DIMS) { // too many dimens - phase = READ_NAME; - tensor_storage.n_dims = 0; - } - if (nelements % value == 0) { - tensor_storage.ne[tensor_storage.n_dims] = value; - tensor_storage.n_dims++; - } - } - return false; - } - - void read_global(const std::string& str) { - if (str == "FloatStorage") { - if (read_global_type) { - global_type = GGML_TYPE_F32; - read_global_type = false; - } - tensor_storage.type = GGML_TYPE_F32; - } else if (str == "HalfStorage") { - if (read_global_type) { - global_type = GGML_TYPE_F16; - read_global_type = false; - } - tensor_storage.type = GGML_TYPE_F16; - } - } - - void read_string(const std::string& str, zip_t* zip, std::string dir) { - if (str == "storage") { - read_global_type = true; - } else if (str != "state_dict") { - if (phase == READ_DATA) { - std::string entry_name = dir + "data/" + std::string(str); - - size_t i, n = zip_entries_total(zip); - for (i = 0; i < n; ++i) { - zip_entry_openbyindex(zip, i); - { - std::string name = zip_entry_name(zip); - if (name == entry_name) { - tensor_storage.index_in_zip = (int)i; - entry_size = zip_entry_size(zip); - zip_entry_close(zip); - break; - } - } - zip_entry_close(zip); - } - - phase = entry_size > 0 ? 
CHECK_SIZE : READ_NAME; - } - if (!read_global_type && phase == READ_NAME) { - tensor_storage.name = str; - phase = READ_DATA; - tensor_storage.type = global_type; - } - } - } -}; - -ggml_type PickleTensorReader::global_type = GGML_TYPE_F32; // all pickle_tensors data type -bool PickleTensorReader::read_global_type = false; - -int find_char(uint8_t* buffer, int len, char c) { - for (int pos = 0; pos < len; pos++) { - if (buffer[pos] == c) { - return pos; - } - } - return -1; -} - -#define MAX_STRING_BUFFER 512 - -bool ModelLoader::parse_data_pkl(uint8_t* buffer, - size_t buffer_size, - zip_t* zip, - std::string dir, - size_t file_index, - const std::string prefix) { - uint8_t* buffer_end = buffer + buffer_size; - if (buffer[0] == 0x80) { // proto - if (buffer[1] != 2) { - LOG_ERROR("Unsupported protocol\n"); - return false; - } - buffer += 2; // 0x80 and version - char string_buffer[MAX_STRING_BUFFER]; - bool finish = false; - PickleTensorReader reader; - // read pickle binary file - while (!finish && buffer < buffer_end) { - uint8_t opcode = *buffer; - buffer++; - // https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048 - // https://github.com/python/cpython/blob/main/Lib/pickle.py#L105 - switch (opcode) { - case '}': // EMPTY_DICT = b'}' # push empty dict - break; - case ']': // EMPTY_LIST = b']' # push empty list - break; - // skip unused sections - case 'h': // BINGET = b'h' # " " " " " " ; " " 1-byte arg - case 'q': // BINPUT = b'q' # " " " " " ; " " 1-byte arg - case 'Q': // BINPERSID = b'Q' # " " " ; " " " " stack - buffer++; - break; - case 'r': // LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg - buffer += 4; - break; - case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame - buffer += 8; - break; - case 0x94: // MEMOIZE = b'\x94' # store top of the stack in memo - break; - case '(': // MARK = b'(' # push special markobject on stack - break; - case 'K': // BININT1 = b'K' # push 1-byte unsigned int - { - uint8_t value = *buffer; - if (reader.read_int_value(value)) { - buffer++; - } - buffer++; - } break; - case 'M': // BININT2 = b'M' # push 2-byte unsigned int - { - uint16_t value = read_short(buffer); - if (reader.read_int_value(value)) { - buffer++; - } - buffer += 2; - } break; - case 'J': // BININT = b'J' # push four-byte signed int - { - const int32_t value = read_int(buffer); - if (reader.read_int_value(value)) { - buffer++; // skip tuple after read num_elements - } - buffer += 4; - } break; - case 'X': // BINUNICODE = b'X' # " " " ; counted UTF-8 string argument - { - const int32_t len = read_int(buffer); - buffer += 4; - memset(string_buffer, 0, MAX_STRING_BUFFER); - if (len > MAX_STRING_BUFFER) { - LOG_WARN("tensor name very large"); - } - memcpy(string_buffer, buffer, len < MAX_STRING_BUFFER ? 
len : (MAX_STRING_BUFFER - 1)); - buffer += len; - reader.read_string(string_buffer, zip, dir); - } break; - case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes - { - const int8_t len = *buffer; - buffer++; - memset(string_buffer, 0, MAX_STRING_BUFFER); - memcpy(string_buffer, buffer, len); - buffer += len; - // printf("String: '%s'\n", string_buffer); - } break; - case 'c': // GLOBAL = b'c' # push self.find_class(modname, name); 2 string args - { - int len = find_char(buffer, MAX_STRING_BUFFER, '\n'); - - buffer += len + 1; - len = find_char(buffer, MAX_STRING_BUFFER, '\n'); - - memset(string_buffer, 0, MAX_STRING_BUFFER); - memcpy(string_buffer, buffer, len); - buffer += len + 1; - reader.read_global(string_buffer); - } break; - case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items - case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack top - case 't': // TUPLE = b't' # build tuple from topmost stack items - if (reader.phase == PickleTensorReader::READ_DIMENS) { - reader.tensor_storage.reverse_ne(); - reader.tensor_storage.file_index = file_index; - // if(strcmp(prefix.c_str(), "scarlett") == 0) - // printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str()); - std::string name = reader.tensor_storage.name; - if (!starts_with(name, prefix)) { - name = prefix + name; - } - reader.tensor_storage.name = name; - add_tensor_storage(reader.tensor_storage); - - // LOG_DEBUG("%s", reader.tensor_storage.name.c_str()); - // reset - reader = PickleTensorReader(); - } - break; - case '.': // STOP = b'.' # every pickle ends with STOP - finish = true; - break; - default: - break; - } - } - } - return true; -} - -bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) { - LOG_DEBUG("init from '%s'", file_path.c_str()); - file_paths_.push_back(file_path); - size_t file_index = file_paths_.size() - 1; - - zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); - if (zip == nullptr) { - LOG_ERROR("failed to open '%s'", file_path.c_str()); - return false; - } - int n = (int)zip_entries_total(zip); - for (int i = 0; i < n; ++i) { - zip_entry_openbyindex(zip, i); - { - std::string name = zip_entry_name(zip); - size_t pos = name.find("data.pkl"); - if (pos != std::string::npos) { - std::string dir = name.substr(0, pos); - printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str()); - void* pkl_data = nullptr; - size_t pkl_size; - zip_entry_read(zip, &pkl_data, &pkl_size); - - // LOG_DEBUG("%lld", pkl_size); - - parse_data_pkl((uint8_t*)pkl_data, pkl_size, zip, dir, file_index, prefix); - - free(pkl_data); - } - } - zip_entry_close(zip); - } - zip_close(zip); - return true; -} - -bool ModelLoader::has_diffusion_model_tensors() -{ - for (auto& [name, tensor_storage] : tensor_storage_map) { - if (tensor_storage.name.find("model.diffusion_model.") != std::string::npos) { - return true; - } - } - return false; -} - SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight, input_block_weight; @@ -1294,8 +715,8 @@ std::map ModelLoader::get_vae_wtype_stat() { return wtype_stat; } -static std::vector> parse_tensor_type_rules(const std::string& tensor_type_rules) { - std::vector> result; +TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules) { + TensorTypeRules result; for (const auto& item : split_string(tensor_type_rules, ',')) { if (item.size() == 0) continue; @@ -1432,7 +853,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int 
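// Sketch of consuming TensorTypeRules (regex-pattern/ggml_type pairs): this
// mirrors the matching loop from the removed save_to_gguf_file() below, where
// the first matching rule overrides the caller's default type (requires <regex>).
static ggml_type apply_tensor_type_rules(const TensorTypeRules& rules,
                                         const std::string& tensor_name,
                                         ggml_type default_type) {
    for (const auto& rule : rules) {
        std::regex pattern(rule.first);
        if (std::regex_search(tensor_name, pattern)) {
            return rule.second;  // first matching rule wins
        }
    }
    return default_type;
}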
n_thread return; } } else if (!mmapped) { - file.open(sd_get_u8path(file_path), std::ios::binary); + file.open(file_path, std::ios::binary); if (!file.is_open()) { LOG_ERROR("failed to open '%s'", file_path.c_str()); failed = true; @@ -1728,76 +1149,6 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage return false; } -bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) { - auto backend = ggml_backend_cpu_init(); - size_t mem_size = 1 * 1024 * 1024; // for padding - mem_size += tensor_storage_map.size() * ggml_tensor_overhead(); - mem_size += get_params_mem_size(backend, type); - LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f); - ggml_context* ggml_ctx = ggml_init({mem_size, nullptr, false}); - - gguf_context* gguf_ctx = gguf_init_empty(); - - auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str); - - std::mutex tensor_mutex; - auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { - const std::string& name = tensor_storage.name; - ggml_type tensor_type = tensor_storage.type; - ggml_type dst_type = type; - - for (const auto& tensor_type_rule : tensor_type_rules) { - std::regex pattern(tensor_type_rule.first); - if (std::regex_search(name, pattern)) { - dst_type = tensor_type_rule.second; - break; - } - } - - if (tensor_should_be_converted(tensor_storage, dst_type)) { - tensor_type = dst_type; - } - - std::lock_guard lock(tensor_mutex); - ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne); - if (tensor == nullptr) { - LOG_ERROR("ggml_new_tensor failed"); - return false; - } - ggml_set_name(tensor, name.c_str()); - - // LOG_DEBUG("%s %d %s %d[%d %d %d %d] %d[%d %d %d %d]", name.c_str(), - // ggml_nbytes(tensor), ggml_type_name(tensor_type), - // tensor_storage.n_dims, - // tensor_storage.ne[0], tensor_storage.ne[1], tensor_storage.ne[2], tensor_storage.ne[3], - // tensor->n_dims, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); - - if (!tensor->data) { - GGML_ASSERT(ggml_nelements(tensor) == 0); - // avoid crashing the gguf writer by setting a dummy pointer for zero-sized tensors - LOG_DEBUG("setting dummy pointer for zero-sized tensor %s", name.c_str()); - tensor->data = ggml_get_mem_buffer(ggml_ctx); - } - - *dst_tensor = tensor; - - gguf_add_tensor(gguf_ctx, tensor); - - return true; - }; - - bool success = load_tensors(on_new_tensor_cb); - ggml_backend_free(backend); - LOG_INFO("load tensors done"); - LOG_INFO("trying to save tensors to %s", file_path.c_str()); - if (success) { - gguf_write_to_file(gguf_ctx, file_path.c_str(), false); - } - ggml_free(ggml_ctx); - gguf_free(gguf_ctx); - return success; -} - int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) { size_t alignment = 128; if (backend != nullptr) { @@ -1817,29 +1168,3 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) return mem_size; } - -bool convert(const char* input_path, - const char* vae_path, - const char* output_path, - sd_type_t output_type, - const char* tensor_type_rules, - bool convert_name) { - ModelLoader model_loader; - - if (!model_loader.init_from_file(input_path)) { - LOG_ERROR("init model loader from file failed: '%s'", input_path); - return false; - } - - if (vae_path != nullptr && strlen(vae_path) > 0) { - if (!model_loader.init_from_file(vae_path, "vae.")) { - LOG_ERROR("init model loader from 
file failed: '%s'", vae_path); - return false; - } - } - if (convert_name) { - model_loader.convert_tensors_name(); - } - bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules); - return success; -} diff --git a/otherarch/sdcpp/model.h b/otherarch/sdcpp/model.h index dcc232aaa..2689f63bd 100644 --- a/otherarch/sdcpp/model.h +++ b/otherarch/sdcpp/model.h @@ -5,20 +5,13 @@ #include #include #include -#include #include -#include -#include #include #include "ggml-backend.h" #include "ggml.h" -#include "gguf.h" -#include "json.hpp" +#include "model_io/tensor_storage.h" #include "ordered_map.hpp" -#include "zip.h" - -#define SD_MAX_DIMS 5 enum SDVersion { VERSION_SD1, @@ -195,116 +188,10 @@ enum PMVersion { PM_VERSION_2, }; -struct TensorStorage { - std::string name; - ggml_type type = GGML_TYPE_F32; - ggml_type expected_type = GGML_TYPE_COUNT; - bool is_f8_e4m3 = false; - bool is_f8_e5m2 = false; - bool is_f64 = false; - bool is_i64 = false; - int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; - int n_dims = 0; - - size_t file_index = 0; - int index_in_zip = -1; // >= means stored in a zip file - uint64_t offset = 0; // offset in file - - TensorStorage() = default; - - TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0) - : name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) { - for (int i = 0; i < n_dims; i++) { - this->ne[i] = ne[i]; - } - } - - int64_t nelements() const { - int64_t n = 1; - for (int i = 0; i < SD_MAX_DIMS; i++) { - n *= ne[i]; - } - return n; - } - - int64_t nbytes() const { - return nelements() * ggml_type_size(type) / ggml_blck_size(type); - } - - int64_t nbytes_to_read() const { - if (is_f8_e4m3 || is_f8_e5m2) { - return nbytes() / 2; - } else if (is_f64 || is_i64) { - return nbytes() * 2; - } else { - return nbytes(); - } - } - - void unsqueeze() { - if (n_dims == 2) { - n_dims = 4; - ne[3] = ne[1]; - ne[2] = ne[0]; - ne[1] = 1; - ne[0] = 1; - } - } - - std::vector chunk(size_t n) { - std::vector chunks; - uint64_t chunk_size = nbytes_to_read() / n; - // printf("%d/%d\n", chunk_size, nbytes_to_read()); - reverse_ne(); - for (size_t i = 0; i < n; i++) { - TensorStorage chunk_i = *this; - chunk_i.ne[0] = ne[0] / n; - chunk_i.offset = offset + i * chunk_size; - chunk_i.reverse_ne(); - chunks.push_back(chunk_i); - } - reverse_ne(); - return chunks; - } - - void reverse_ne() { - int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; - for (int i = 0; i < n_dims; i++) { - new_ne[i] = ne[n_dims - 1 - i]; - } - for (int i = 0; i < n_dims; i++) { - ne[i] = new_ne[i]; - } - } - - std::string to_string() const { - std::stringstream ss; - const char* type_name = ggml_type_name(type); - if (is_f8_e4m3) { - type_name = "f8_e4m3"; - } else if (is_f8_e5m2) { - type_name = "f8_e5m2"; - } else if (is_f64) { - type_name = "f64"; - } else if (is_i64) { - type_name = "i64"; - } - ss << name << " | " << type_name << " | "; - ss << n_dims << " ["; - for (int i = 0; i < SD_MAX_DIMS; i++) { - ss << ne[i]; - if (i != SD_MAX_DIMS - 1) { - ss << ", "; - } - } - ss << "]"; - return ss.str(); - } -}; - -typedef std::function on_new_tensor_cb_t; - typedef OrderedMap String2TensorStorage; +using TensorTypeRules = std::vector>; + +TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules); class ModelLoader { protected: @@ -314,16 +201,10 @@ protected: void add_tensor_storage(const TensorStorage& tensor_storage); - bool parse_data_pkl(uint8_t* buffer, - 
size_t buffer_size, - zip_t* zip, - std::string dir, - size_t file_index, - const std::string prefix); - bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = ""); - bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); public: @@ -354,7 +235,6 @@ public: return names; } - bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); ~ModelLoader() = default; diff --git a/otherarch/sdcpp/model_io/binary_io.h b/otherarch/sdcpp/model_io/binary_io.h new file mode 100644 index 000000000..9093eeaf9 --- /dev/null +++ b/otherarch/sdcpp/model_io/binary_io.h @@ -0,0 +1,57 @@ +#ifndef __SD_MODEL_IO_BINARY_IO_H__ +#define __SD_MODEL_IO_BINARY_IO_H__ + +#include +#include + +namespace model_io { + + inline int32_t read_int(const uint8_t* buffer) { + uint32_t value = 0; + value |= static_cast(buffer[3]) << 24; + value |= static_cast(buffer[2]) << 16; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return static_cast(value); + } + + inline uint16_t read_short(const uint8_t* buffer) { + uint16_t value = 0; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return value; + } + + inline uint64_t read_u64(const uint8_t* buffer) { + uint64_t value = 0; + value |= static_cast(buffer[7]) << 56; + value |= static_cast(buffer[6]) << 48; + value |= static_cast(buffer[5]) << 40; + value |= static_cast(buffer[4]) << 32; + value |= static_cast(buffer[3]) << 24; + value |= static_cast(buffer[2]) << 16; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return value; + } + + inline void write_u64(std::ostream& stream, uint64_t value) { + uint8_t buffer[8]; + for (int i = 0; i < 8; ++i) { + buffer[i] = static_cast((value >> (8 * i)) & 0xFF); + } + stream.write((const char*)buffer, sizeof(buffer)); + } + + inline int find_char(const uint8_t* buffer, int len, char c) { + for (int pos = 0; pos < len; pos++) { + if (buffer[pos] == (uint8_t)c) { + return pos; + } + } + return -1; + } + +} // namespace model_io + +#endif // __SD_MODEL_IO_BINARY_IO_H__ diff --git a/otherarch/sdcpp/model_io/gguf_io.cpp b/otherarch/sdcpp/model_io/gguf_io.cpp new file mode 100644 index 000000000..378694d8e --- /dev/null +++ b/otherarch/sdcpp/model_io/gguf_io.cpp @@ -0,0 +1,123 @@ +#include "gguf_io.h" + +#include +#include +#include +#include + +#include "gguf.h" +#include "gguf_reader_ext.h" +#include "util.h" + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +bool is_gguf_file(const std::string& file_path) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + return false; + } + + char magic[4]; + + file.read(magic, sizeof(magic)); + if (!file) { + return false; + } + for (uint32_t i = 0; i < sizeof(magic); i++) { + if (magic[i] != GGUF_MAGIC[i]) { + return false; + } + } + + return true; +} + +bool 
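// Worked example for the little-endian helpers in binary_io.h above: given the
// bytes {0x10, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, model_io::read_u64()
// returns 0x2710 == 10000, and write_u64(stream, 10000) emits those same eight
// bytes, so the pair round-trips independently of host endianness.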
read_gguf_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + tensor_storages.clear(); + + gguf_context* ctx_gguf_ = nullptr; + ggml_context* ctx_meta_ = nullptr; + + ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_}); + if (!ctx_gguf_) { + GGUFReader gguf_reader; + if (!gguf_reader.load(file_path)) { + set_error(error, "failed to open '" + file_path + "' with GGUFReader"); + return false; + } + + size_t data_offset = gguf_reader.data_offset(); + for (const auto& gguf_tensor_info : gguf_reader.tensors()) { + TensorStorage tensor_storage( + gguf_tensor_info.name, + gguf_tensor_info.type, + gguf_tensor_info.shape.data(), + static_cast(gguf_tensor_info.shape.size()), + 0, + data_offset + gguf_tensor_info.offset); + + tensor_storages.push_back(tensor_storage); + } + + return true; + } + + int n_tensors = static_cast(gguf_get_n_tensors(ctx_gguf_)); + + size_t data_offset = gguf_get_data_offset(ctx_gguf_); + for (int i = 0; i < n_tensors; i++) { + std::string name = gguf_get_tensor_name(ctx_gguf_, i); + ggml_tensor* dummy = ggml_get_tensor(ctx_meta_, name.c_str()); + size_t offset = data_offset + gguf_get_tensor_offset(ctx_gguf_, i); + + TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), 0, offset); + + if (ggml_nbytes(dummy) != tensor_storage.nbytes()) { + gguf_free(ctx_gguf_); + ggml_free(ctx_meta_); + set_error(error, "size mismatch for tensor '" + name + "'"); + return false; + } + + tensor_storages.push_back(tensor_storage); + } + + gguf_free(ctx_gguf_); + ggml_free(ctx_meta_); + + return true; +} + +bool write_gguf_file(const std::string& file_path, + const std::vector& tensors, + std::string* error) { + gguf_context* gguf_ctx = gguf_init_empty(); + if (gguf_ctx == nullptr) { + set_error(error, "gguf_init_empty failed"); + return false; + } + + for (const TensorWriteInfo& write_tensor : tensors) { + ggml_tensor* tensor = write_tensor.tensor; + if (tensor == nullptr) { + set_error(error, "null tensor cannot be written to GGUF"); + gguf_free(gguf_ctx); + return false; + } + gguf_add_tensor(gguf_ctx, tensor); + } + + LOG_INFO("trying to save tensors to %s", file_path.c_str()); + bool success = gguf_write_to_file(gguf_ctx, file_path.c_str(), false); + if (!success) { + set_error(error, "failed to write GGUF file '" + file_path + "'"); + } + gguf_free(gguf_ctx); + return success; +} diff --git a/otherarch/sdcpp/model_io/gguf_io.h b/otherarch/sdcpp/model_io/gguf_io.h new file mode 100644 index 000000000..81c981145 --- /dev/null +++ b/otherarch/sdcpp/model_io/gguf_io.h @@ -0,0 +1,17 @@ +#ifndef __SD_MODEL_IO_GGUF_IO_H__ +#define __SD_MODEL_IO_GGUF_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool is_gguf_file(const std::string& file_path); +bool read_gguf_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); +bool write_gguf_file(const std::string& file_path, + const std::vector& tensors, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_GGUF_IO_H__ diff --git a/otherarch/sdcpp/gguf_reader.hpp b/otherarch/sdcpp/model_io/gguf_reader_ext.h similarity index 98% rename from otherarch/sdcpp/gguf_reader.hpp rename to otherarch/sdcpp/model_io/gguf_reader_ext.h index 21ca17e3e..60cde37cc 100644 --- a/otherarch/sdcpp/gguf_reader.hpp +++ b/otherarch/sdcpp/model_io/gguf_reader_ext.h @@ -1,5 +1,5 @@ -#ifndef __GGUF_READER_HPP__ -#define __GGUF_READER_HPP__ +#ifndef __SD_MODEL_IO_GGUF_READER_EXT_H__ +#define __SD_MODEL_IO_GGUF_READER_EXT_H__ 
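// Minimal sketch of listing tensors through read_gguf_file() above; it first
// tries ggml's gguf_init_from_file() and only falls back to the bundled
// GGUFReader when ggml rejects the file, so callers see a single result.
static void example_dump_gguf_tensors(const std::string& path) {
    std::vector<TensorStorage> tensor_storages;
    std::string error;
    if (!read_gguf_file(path, tensor_storages, &error)) {
        LOG_ERROR("%s", error.c_str());
        return;
    }
    for (const auto& ts : tensor_storages) {
        LOG_DEBUG("%s", ts.to_string().c_str());
    }
}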
#include #include @@ -231,4 +231,4 @@ public: size_t data_offset() const { return data_offset_; } }; -#endif // __GGUF_READER_HPP__ +#endif // __SD_MODEL_IO_GGUF_READER_EXT_H__ diff --git a/otherarch/sdcpp/model_io/pickle_io.cpp b/otherarch/sdcpp/model_io/pickle_io.cpp new file mode 100644 index 000000000..3a978178a --- /dev/null +++ b/otherarch/sdcpp/model_io/pickle_io.cpp @@ -0,0 +1,1064 @@ +#include "pickle_io.h" + +#include +#include +#include +#include +#include +#include + +#include "binary_io.h" +#include "util.h" + +// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100 +// 0: \x80 PROTO 2 +// 2: } EMPTY_DICT +// 3: q BINPUT 0 +// 5: ( MARK +// 6: X BINUNICODE 'epoch' +// 16: q BINPUT 1 +// 18: K BININT1 6 +// 20: X BINUNICODE 'global_step' +// 36: q BINPUT 2 +// 38: J BININT 470000 +// 43: X BINUNICODE 'pytorch-lightning_version' +// 73: q BINPUT 3 +// 75: X BINUNICODE '1.4.2' +// 85: q BINPUT 4 +// 87: X BINUNICODE 'state_dict' +// 102: q BINPUT 5 +// 104: } EMPTY_DICT +// 105: q BINPUT 6 +// 107: ( MARK +// 108: X BINUNICODE 'betas' +// 118: q BINPUT 7 +// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2' +// 153: q BINPUT 8 +// 155: ( MARK +// 156: ( MARK +// 157: X BINUNICODE 'storage' +// 169: q BINPUT 9 +// 171: c GLOBAL 'torch FloatStorage' +// 191: q BINPUT 10 +// 193: X BINUNICODE '0' +// 199: q BINPUT 11 +// 201: X BINUNICODE 'cpu' +// 209: q BINPUT 12 +// 211: M BININT2 1000 +// 214: t TUPLE (MARK at 156) +// 215: q BINPUT 13 +// 217: Q BINPERSID +// 218: K BININT1 0 +// 220: M BININT2 1000 +// ............................... +// 3201: q BINPUT 250 +// 3203: R REDUCE +// 3204: q BINPUT 251 +// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight' +// 3264: q BINPUT 252 +// 3266: h BINGET 8 +// 3268: ( MARK +// 3269: ( MARK +// 3270: h BINGET 9 +// 3272: h BINGET 10 +// 3274: X BINUNICODE '30' +// 3281: q BINPUT 253 +// 3283: h BINGET 12 +// 3285: J BININT 102400 +// 3290: t TUPLE (MARK at 3269) +// 3291: q BINPUT 254 +// 3293: Q BINPERSID +// 3294: K BININT1 0 +// 3296: ( MARK +// 3297: M BININT2 320 +// 3300: M BININT2 320 +// 3303: K BININT1 1 +// 3305: K BININT1 1 +// 3307: t TUPLE (MARK at 3296) +// 3308: q BINPUT 255 +// 3310: ( MARK +// 3311: M BININT2 320 +// 3314: K BININT1 1 +// 3316: K BININT1 1 +// 3318: K BININT1 1 +// 3320: t TUPLE (MARK at 3310) +// 3321: r LONG_BINPUT 256 +// 3326: \x89 NEWFALSE +// 3327: h BINGET 16 +// 3329: ) EMPTY_TUPLE +// 3330: R REDUCE +// 3331: r LONG_BINPUT 257 +// 3336: t TUPLE (MARK at 3268) +// 3337: r LONG_BINPUT 258 +// 3342: R REDUCE +// 3343: r LONG_BINPUT 259 +// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias' +// 3404: r LONG_BINPUT 260 +// 3409: h BINGET 8 +// 3411: ( MARK +// 3412: ( MARK +// 3413: h BINGET 9 +// 3415: h BINGET 10 +// 3417: X BINUNICODE '31' +// https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048 +// https://github.com/python/cpython/blob/main/Lib/pickle.py#L105 + +using model_io::find_char; +using model_io::read_int; +using model_io::read_short; +using model_io::read_u64; + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size) { + const uint8_t* p = buffer; + const uint8_t* end = buffer + buffer_size; + + while (p < end) { + uint8_t opcode = *p++; + switch (opcode) { + case '.': // STOP = b'.' 
# every pickle ends with STOP + *object_size = (size_t)(p - buffer); + return true; + case 0x80: // PROTO = b'\x80' # protocol version indicator + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + case 'h': // BINGET = b'h' # read memo index, 1-byte arg + case 'q': // BINPUT = b'q' # write memo index, 1-byte arg + case 'C': // SHORT_BINBYTES = b'C' # push bytes; length < 256 + case 0x82: // EXT1 = b'\x82' # extension code, 1-byte arg + p += 1; + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + case 0x83: // EXT2 = b'\x83' # extension code, 2-byte arg + p += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + case 'j': // LONG_BINGET = b'j' # read memo index, 4-byte arg + case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg + case 0x84: // EXT4 = b'\x84' # extension code, 4-byte arg + p += 4; + break; + case 'I': // INT = b'I' # push decimal integer line + case 'L': // LONG = b'L' # push decimal long integer line + case 'F': // FLOAT = b'F' # push decimal float line + case 'S': // STRING = b'S' # push quoted string line + case 'V': { // UNICODE = b'V' # push raw-unicode string line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case 'G': // BINFLOAT = b'G' # push 8-byte binary float + p += 8; + break; + case 0x8A: // LONG1 = b'\x8a' # push long integer; 1-byte length + if (p >= end) { + return false; + } + p += 1 + p[0]; + break; + case 0x8B: { // LONG4 = b'\x8b' # push long integer; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = (uint32_t)read_int(p); + p += 4 + n; + } break; + case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = (uint32_t)read_int(p); + p += 4 + n; + } break; + case 'T': // BINSTRING = b'T' # push string; 4-byte length + case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = (uint32_t)read_int(p); + p += 4 + n; + } break; + case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length + case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length + case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length + if (p + 8 > end) { + return false; + } + uint64_t n = read_u64(p); + p += 8; + if (n > (uint64_t)(end - p)) { + return false; + } + p += n; + } break; + case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256 + case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256 + if (p >= end) { + return false; + } + p += 1 + p[0]; + break; + case 'P': { // PERSID = b'P' # persistent id, newline-terminated + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame + p += 8; + break; + case 'c': { // GLOBAL = b'c' # push module/name global reference + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case '}': // EMPTY_DICT = b'}' # push empty dict + case ']': // EMPTY_LIST = b']' # push empty list + case '(': // MARK = b'(' # push markobject + case 't': // TUPLE = b't' # build tuple from mark + case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack + case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack + case 0x87: // TUPLE3 = b'\x87' # build 3-tuple from stack + case ')': // 
EMPTY_TUPLE = b')' # push empty tuple + case 'l': // LIST = b'l' # build list from mark + case 'Q': // BINPERSID = b'Q' # persistent id from stack + case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo + case 0x88: // NEWTRUE = b'\x88' # push True + case 0x89: // NEWFALSE = b'\x89' # push False + case 'R': // REDUCE = b'R' # apply callable to args + case 'u': // SETITEMS = b'u' # add mark-delimited items to dict + case 's': // SETITEM = b's' # add key/value to dict + case 'e': // APPENDS = b'e' # extend list with mark-delimited items + case 'a': // APPEND = b'a' # append item to list + case 'b': // BUILD = b'b' # build object state + case 0x81: // NEWOBJ = b'\x81' # build object via __new__ + case 0x8F: // EMPTY_SET = b'\x8f' # push empty set + case 0x90: // ADDITEMS = b'\x90' # add mark-delimited items to set + case 0x91: // FROZENSET = b'\x91' # build frozenset from mark + case 0x92: // NEWOBJ_EX = b'\x92' # build object with kwargs + case 0x93: // STACK_GLOBAL = b'\x93' # build global from module/name strings + case 0x97: // NEXT_BUFFER = b'\x97' # out-of-band buffer marker + case 0x98: // READONLY_BUFFER = b'\x98' # mark buffer readonly + case 'N': // NONE = b'N' # push None + case '0': // POP = b'0' # discard top stack item + case '1': // POP_MARK = b'1' # discard stack through topmost mark + case '2': // DUP = b'2' # duplicate top stack item + case 'o': // OBJ = b'o' # build class instance from mark + break; + case 'i': { // INST = b'i' # build class instance from module/name + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + default: + return false; + } + if (p > end) { + return false; + } + } + + return false; +} + +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size) { + static const uint8_t torch_magic_bytes[] = {0x6C, 0xFC, 0x9C, 0x46, 0xF9, 0x20, 0x6A, 0xA8, 0x50, 0x19}; + + if (buffer_size < 5 || buffer[0] != 0x80) { + return false; + } + + size_t pos = 2; + if (pos >= buffer_size) { + return false; + } + + uint8_t opcode = buffer[pos++]; + if (opcode != 0x8A || pos >= buffer_size) { + return false; + } + + uint8_t len = buffer[pos++]; + if (len != sizeof(torch_magic_bytes) || pos + len >= buffer_size) { + return false; + } + + if (memcmp(buffer + pos, torch_magic_bytes, sizeof(torch_magic_bytes)) != 0) { + return false; + } + pos += len; + + return pos < buffer_size && buffer[pos] == '.'; +} + +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value) { + if (buffer_size < 4 || buffer[0] != 0x80) { + return false; + } + + size_t pos = 2; + if (pos >= buffer_size) { + return false; + } + + uint8_t opcode = buffer[pos++]; + switch (opcode) { + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + if (pos + 1 >= buffer_size) { + return false; + } + *value = buffer[pos]; + pos += 1; + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + if (pos + 2 >= buffer_size) { + return false; + } + *value = read_short(buffer + pos); + pos += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + if (pos + 4 >= buffer_size) { + return false; + } + *value = (uint32_t)read_int(buffer + pos); + pos += 4; + break; + default: + return false; + } + + return pos < buffer_size && buffer[pos] == '.'; +} + +struct PickleStorageInfo { + std::string key; + ggml_type type = GGML_TYPE_COUNT; + bool is_f64 = false; + bool is_i64 = false; + uint64_t 
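// The three helpers above are building blocks for scanning a legacy torch
// file, which conventionally begins with several small back-to-back pickles
// ahead of the state_dict: skip_pickle_object() measures each object,
// pickle_object_is_torch_magic_number() validates the first, and
// parse_pickle_uint32_object() extracts the protocol-version integer; the
// exact sequencing is presumed to live in torch_legacy_io.cpp.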
raw_element_nbytes = 0; + uint64_t nbytes = 0; +}; + +struct PickleTensorInfo { + TensorStorage tensor_storage; + int stride_n_dims = 0; + int64_t stride[SD_MAX_DIMS]{1, 1, 1, 1, 1}; +}; + +struct PickleValue { + enum Kind { + MARK, + NONE, + BOOL, + INT, + STRING, + GLOBAL, + TUPLE, + LIST, + DICT, + ORDERED_DICT, + STORAGE, + TENSOR, + }; + + Kind kind = NONE; + int64_t int_value = 0; + bool bool_value = false; + std::string str_value; + std::vector items; + std::vector> dict_items; + PickleStorageInfo storage; + PickleTensorInfo tensor; +}; + +static PickleValue make_mark_value() { + PickleValue value; + value.kind = PickleValue::MARK; + return value; +} + +static PickleValue make_none_value() { + PickleValue value; + value.kind = PickleValue::NONE; + return value; +} + +static PickleValue make_bool_value(bool b) { + PickleValue value; + value.kind = PickleValue::BOOL; + value.bool_value = b; + return value; +} + +static PickleValue make_int_value(int64_t x) { + PickleValue value; + value.kind = PickleValue::INT; + value.int_value = x; + return value; +} + +static PickleValue make_string_value(const std::string& s) { + PickleValue value; + value.kind = PickleValue::STRING; + value.str_value = s; + return value; +} + +static PickleValue make_global_value(const std::string& s) { + PickleValue value; + value.kind = PickleValue::GLOBAL; + value.str_value = s; + return value; +} + +static PickleValue make_tuple_value(std::vector items) { + PickleValue value; + value.kind = PickleValue::TUPLE; + value.items = std::move(items); + return value; +} + +static PickleValue make_list_value() { + PickleValue value; + value.kind = PickleValue::LIST; + return value; +} + +static PickleValue make_dict_value(bool ordered) { + PickleValue value; + value.kind = ordered ? 
PickleValue::ORDERED_DICT : PickleValue::DICT; + return value; +} + +static PickleValue make_storage_value(const PickleStorageInfo& storage) { + PickleValue value; + value.kind = PickleValue::STORAGE; + value.storage = storage; + return value; +} + +static PickleValue make_tensor_value(const PickleTensorInfo& tensor) { + PickleValue value; + value.kind = PickleValue::TENSOR; + value.tensor = tensor; + return value; +} + +static std::string pickle_value_to_string(const PickleValue& value) { + if (value.kind == PickleValue::STRING) { + return value.str_value; + } + if (value.kind == PickleValue::INT) { + return std::to_string(value.int_value); + } + return ""; +} + +static bool parse_storage_type(const std::string& global_name, PickleStorageInfo* storage) { + if (global_name == "torch.FloatStorage") { + storage->type = GGML_TYPE_F32; + storage->raw_element_nbytes = 4; + return true; + } + if (global_name == "torch.DoubleStorage") { + storage->type = GGML_TYPE_F32; + storage->is_f64 = true; + storage->raw_element_nbytes = 8; + return true; + } + if (global_name == "torch.HalfStorage") { + storage->type = GGML_TYPE_F16; + storage->raw_element_nbytes = 2; + return true; + } + if (global_name == "torch.BFloat16Storage") { + storage->type = GGML_TYPE_BF16; + storage->raw_element_nbytes = 2; + return true; + } + if (global_name == "torch.IntStorage") { + storage->type = GGML_TYPE_I32; + storage->raw_element_nbytes = 4; + return true; + } + if (global_name == "torch.LongStorage") { + storage->type = GGML_TYPE_I32; + storage->is_i64 = true; + storage->raw_element_nbytes = 8; + return true; + } + return false; +} + +static bool tensor_is_contiguous(const PickleTensorInfo& tensor) { + if (tensor.tensor_storage.nelements() == 0) { + return true; + } + if (tensor.stride_n_dims != tensor.tensor_storage.n_dims) { + return false; + } + + int64_t expected_stride = 1; + for (int i = tensor.tensor_storage.n_dims - 1; i >= 0; --i) { + if (tensor.stride[i] != expected_stride) { + return false; + } + expected_stride *= tensor.tensor_storage.ne[i]; + } + return true; +} + +static void collect_tensors_from_pickle_value(const PickleValue& value, + std::vector& tensor_storages) { + if (value.kind != PickleValue::DICT && value.kind != PickleValue::ORDERED_DICT) { + return; + } + + for (const auto& item : value.dict_items) { + if (item.first.kind == PickleValue::STRING && item.second.kind == PickleValue::TENSOR) { + TensorStorage tensor_storage = item.second.tensor.tensor_storage; + tensor_storage.name = item.first.str_value; + tensor_storage.reverse_ne(); + tensor_storages.push_back(tensor_storage); + } else if (item.second.kind == PickleValue::DICT || item.second.kind == PickleValue::ORDERED_DICT) { + collect_tensors_from_pickle_value(item.second, tensor_storages); + } + } +} + +bool parse_torch_state_dict_pickle(const uint8_t* buffer, + size_t buffer_size, + std::vector& tensor_storages, + std::unordered_map& storage_nbytes, + std::string* error) { + if (buffer_size < 2 || buffer[0] != 0x80 || buffer[1] < 2 || buffer[1] > 5) { + set_error(error, "unsupported torch pickle protocol"); + return false; + } + + const uint8_t* p = buffer + 2; + const uint8_t* end = buffer + buffer_size; + std::vector stack; + std::unordered_map memo; + + while (p < end) { + uint8_t opcode = *p++; + switch (opcode) { + case '.': { // STOP = b'.' 
# every pickle ends with STOP + if (stack.empty()) { + set_error(error, "empty torch pickle stack"); + return false; + } + size_t old_tensor_count = tensor_storages.size(); + collect_tensors_from_pickle_value(stack.back(), tensor_storages); + if (tensor_storages.size() == old_tensor_count) { + set_error(error, "torch pickle does not contain a supported state_dict"); + return false; + } + return true; + } + case '}': // EMPTY_DICT = b'}' # push empty dict + stack.push_back(make_dict_value(false)); + break; + case ']': // EMPTY_LIST = b']' # push empty list + stack.push_back(make_list_value()); + break; + case 'l': { // LIST = b'l' # build list from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + set_error(error, "torch pickle list without mark"); + return false; + } + std::vector items(stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + PickleValue list_value = make_list_value(); + list_value.items = std::move(items); + stack.push_back(std::move(list_value)); + } break; + case '(': // MARK = b'(' # push markobject + stack.push_back(make_mark_value()); + break; + case ')': // EMPTY_TUPLE = b')' # push empty tuple + stack.push_back(make_tuple_value({})); + break; + case 'N': // NONE = b'N' # push None + stack.push_back(make_none_value()); + break; + case 0x88: // NEWTRUE = b'\x88' # push True + stack.push_back(make_bool_value(true)); + break; + case 0x89: // NEWFALSE = b'\x89' # push False + stack.push_back(make_bool_value(false)); + break; + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + if (p >= end) { + return false; + } + stack.push_back(make_int_value(*p++)); + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + if (p + 2 > end) { + return false; + } + stack.push_back(make_int_value(read_short(p))); + p += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + if (p + 4 > end) { + return false; + } + stack.push_back(make_int_value(read_int(p))); + p += 4; + break; + case 'I': { // INT = b'I' # push decimal integer line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (s == "01") { + stack.push_back(make_bool_value(true)); + } else if (s == "00") { + stack.push_back(make_bool_value(false)); + } else { + stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10))); + } + } break; + case 'L': { // LONG = b'L' # push decimal long integer line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (!s.empty() && s.back() == 'L') { + s.pop_back(); + } + stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10))); + } break; + case 'F': { // FLOAT = b'F' # push decimal float line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + stack.push_back(make_none_value()); + } break; + case 'G': // BINFLOAT = b'G' # push 8-byte binary float + if (p + 8 > end) { + return false; + } + p += 8; + stack.push_back(make_none_value()); + break; + case 0x8A: { // LONG1 = b'\x8a' # push long integer; 1-byte length + if (p >= end) { + return false; + } + uint8_t n = *p++; + if (p + n > end || n > 8) { + return false; + } + int64_t value = 0; + for (uint8_t i = 0; i < n; ++i) { + value |= (int64_t)p[i] << (i * 8); + } + p += n; + stack.push_back(make_int_value(value)); + } 
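// Worked example for the LONG1 branch above: the bytes 0x8a 0x02 0x2c 0x01
// decode as a 2-byte little-endian integer, 0x2c | (0x01 << 8) == 300;
// payloads longer than 8 bytes are rejected, which suffices for the sizes
// and offsets found in torch checkpoints.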
break; + case 'C': { // SHORT_BINBYTES = b'C' # push bytes; length < 256 + if (p >= end) { + return false; + } + uint8_t len = *p++; + if (p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length + if (p + 4 > end) { + return false; + } + int32_t len = read_int(p); + p += 4; + if (len < 0 || p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'T': // BINSTRING = b'T' # push string; 4-byte length + case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length + if (p + 4 > end) { + return false; + } + int32_t len = read_int(p); + p += 4; + if (len < 0 || p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length + case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length + case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length + if (p + 8 > end) { + return false; + } + uint64_t len = read_u64(p); + p += 8; + if (len > (uint64_t)(end - p)) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, (size_t)len))); + p += len; + } break; + case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256 + case 0x8C: { // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256 + if (p >= end) { + return false; + } + uint8_t len = *p++; + if (p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'S': { // STRING = b'S' # push quoted string line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (s.size() >= 2 && (s[0] == '\'' || s[0] == '"') && s.back() == s[0]) { + s = s.substr(1, s.size() - 2); + } + stack.push_back(make_string_value(s)); + } break; + case 'V': { // UNICODE = b'V' # push raw-unicode string line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len + 1; + } break; + case 'c': { // GLOBAL = b'c' # push module/name global reference + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string module((const char*)p, len); + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string name((const char*)p, len); + p += len + 1; + stack.push_back(make_global_value(module + "." + name)); + } break; + case 0x93: { // STACK_GLOBAL = b'\x93' # build global from module/name strings + if (stack.size() < 2 || stack[stack.size() - 2].kind != PickleValue::STRING || + stack.back().kind != PickleValue::STRING) { + return false; + } + std::string name = stack.back().str_value; + stack.pop_back(); + std::string module = stack.back().str_value; + stack.pop_back(); + stack.push_back(make_global_value(module + "." 
+ name)); + } break; + case 'h': // BINGET = b'h' # read memo index, 1-byte arg + if (p >= end || !memo.count(*p)) { + return false; + } + stack.push_back(memo[*p++]); + break; + case 'j': { // LONG_BINGET = b'j' # read memo index, 4-byte arg + if (p + 4 > end) { + return false; + } + int32_t memo_idx = read_int(p); + if (!memo.count(memo_idx)) { + return false; + } + stack.push_back(memo[memo_idx]); + p += 4; + } break; + case 'q': // BINPUT = b'q' # write memo index, 1-byte arg + if (p >= end || stack.empty()) { + return false; + } + memo[*p++] = stack.back(); + break; + case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg + if (p + 4 > end || stack.empty()) { + return false; + } + memo[read_int(p)] = stack.back(); + p += 4; + break; + case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo + if (stack.empty()) { + return false; + } + memo[(int32_t)memo.size()] = stack.back(); + break; + case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame + if (p + 8 > end) { + return false; + } + p += 8; + break; + case '0': // POP = b'0' # discard top stack item + if (stack.empty()) { + return false; + } + stack.pop_back(); + break; + case '1': { // POP_MARK = b'1' # discard stack through topmost mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case '2': // DUP = b'2' # duplicate top stack item + if (stack.empty()) { + return false; + } + stack.push_back(stack.back()); + break; + case 0x8F: // EMPTY_SET = b'\x8f' # push empty set + stack.push_back(make_list_value()); + break; + case 0x90: { // ADDITEMS = b'\x90' # add mark-delimited items to set + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) { + return false; + } + PickleValue& set_value = stack[mark_idx - 1]; + set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 0x91: { // FROZENSET = b'\x91' # build frozenset from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + PickleValue set_value = make_list_value(); + set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + stack.push_back(std::move(set_value)); + } break; + case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack + case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack + case 0x87: { // TUPLE3 = b'\x87' # build 3-tuple from stack + int tuple_size = opcode == 0x85 ? 1 : (opcode == 0x86 ? 
2 : 3); + if ((int)stack.size() < tuple_size) { + return false; + } + std::vector items(stack.end() - tuple_size, stack.end()); + stack.erase(stack.end() - tuple_size, stack.end()); + stack.push_back(make_tuple_value(std::move(items))); + } break; + case 't': { // TUPLE = b't' # build tuple from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + std::vector items(stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + stack.push_back(make_tuple_value(std::move(items))); + } break; + case 'Q': { // BINPERSID = b'Q' # persistent id from stack + if (stack.empty()) { + return false; + } + PickleValue pid = stack.back(); + stack.pop_back(); + if (pid.kind != PickleValue::TUPLE || pid.items.size() < 5 || pid.items[0].kind != PickleValue::STRING || + pid.items[1].kind != PickleValue::GLOBAL || pid.items[4].kind != PickleValue::INT || + pid.items[0].str_value != "storage") { + return false; + } + + PickleStorageInfo storage; + storage.key = pickle_value_to_string(pid.items[2]); + if (storage.key.empty() || !parse_storage_type(pid.items[1].str_value, &storage)) { + return false; + } + storage.nbytes = (uint64_t)pid.items[4].int_value * storage.raw_element_nbytes; + storage_nbytes[storage.key] = storage.nbytes; + stack.push_back(make_storage_value(storage)); + } break; + case 'R': { // REDUCE = b'R' # apply callable to args + if (stack.size() < 2) { + return false; + } + PickleValue args = stack.back(); + stack.pop_back(); + PickleValue callable = stack.back(); + stack.pop_back(); + if (callable.kind != PickleValue::GLOBAL || args.kind != PickleValue::TUPLE) { + stack.push_back(make_none_value()); + break; + } + + if (callable.str_value == "collections.OrderedDict" && args.items.empty()) { + stack.push_back(make_dict_value(true)); + break; + } + + if ((callable.str_value == "torch._utils._rebuild_tensor_v2" || callable.str_value == "torch._utils._rebuild_tensor") && + args.items.size() >= 4 && args.items[0].kind == PickleValue::STORAGE && + args.items[1].kind == PickleValue::INT && args.items[2].kind == PickleValue::TUPLE && + args.items[3].kind == PickleValue::TUPLE) { + PickleTensorInfo tensor; + tensor.tensor_storage.type = args.items[0].storage.type; + tensor.tensor_storage.is_f64 = args.items[0].storage.is_f64; + tensor.tensor_storage.is_i64 = args.items[0].storage.is_i64; + tensor.tensor_storage.storage_key = args.items[0].storage.key; + tensor.tensor_storage.offset = (uint64_t)args.items[1].int_value * args.items[0].storage.raw_element_nbytes; + + for (const auto& item : args.items[2].items) { + if (item.kind != PickleValue::INT || tensor.tensor_storage.n_dims >= SD_MAX_DIMS) { + return false; + } + tensor.tensor_storage.ne[tensor.tensor_storage.n_dims++] = item.int_value; + } + + for (const auto& item : args.items[3].items) { + if (item.kind != PickleValue::INT || tensor.stride_n_dims >= SD_MAX_DIMS) { + return false; + } + tensor.stride[tensor.stride_n_dims++] = item.int_value; + } + + if (!tensor_is_contiguous(tensor)) { + return false; + } + stack.push_back(make_tensor_value(tensor)); + break; + } + + // Non-tensor checkpoint metadata can use REDUCE for arbitrary + // Python objects. Do not execute it; keep stack shape only. 
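// (torch._utils._rebuild_tensor_v2 receives (storage, storage_offset, size,
// stride, requires_grad, backward_hooks); the tensor branch above consumes
// the first four, converting storage_offset from elements to bytes via
// raw_element_nbytes and rejecting non-contiguous strides. Anything else
// reduces to the placeholder None below.)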
+ stack.push_back(make_none_value()); + break; + } + case 'b': // BUILD = b'b' # build object state + if (stack.size() < 2) { + return false; + } + stack.pop_back(); + break; + case 'u': { // SETITEMS = b'u' # add mark-delimited items to dict + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0) { + return false; + } + PickleValue& dict = stack[mark_idx - 1]; + if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) { + return false; + } + for (int i = mark_idx + 1; i + 1 < (int)stack.size(); i += 2) { + dict.dict_items.emplace_back(stack[i], stack[i + 1]); + } + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 's': { // SETITEM = b's' # add key/value to dict + if (stack.size() < 3) { + return false; + } + PickleValue value = stack.back(); + stack.pop_back(); + PickleValue key = stack.back(); + stack.pop_back(); + PickleValue& dict = stack.back(); + if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) { + return false; + } + dict.dict_items.emplace_back(key, value); + } break; + case 'e': { // APPENDS = b'e' # extend list with mark-delimited items + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) { + return false; + } + PickleValue& list_value = stack[mark_idx - 1]; + list_value.items.insert(list_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 'a': { // APPEND = b'a' # append item to list + if (stack.size() < 2) { + return false; + } + PickleValue item = stack.back(); + stack.pop_back(); + if (stack.back().kind != PickleValue::LIST) { + return false; + } + stack.back().items.push_back(item); + } break; + default: + set_error(error, + "unsupported torch pickle opcode 0x" + sd_format("%02X", opcode) + + " at offset " + std::to_string((p - buffer) - 1)); + return false; + } + } + + set_error(error, "unterminated torch state_dict pickle"); + return false; +} diff --git a/otherarch/sdcpp/model_io/pickle_io.h b/otherarch/sdcpp/model_io/pickle_io.h new file mode 100644 index 000000000..6a3db37b9 --- /dev/null +++ b/otherarch/sdcpp/model_io/pickle_io.h @@ -0,0 +1,21 @@ +#ifndef __SD_MODEL_IO_PICKLE_IO_H__ +#define __SD_MODEL_IO_PICKLE_IO_H__ + +#include +#include +#include +#include +#include + +#include "tensor_storage.h" + +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size); +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size); +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value); +bool parse_torch_state_dict_pickle(const uint8_t* buffer, + size_t buffer_size, + std::vector& tensor_storages, + std::unordered_map& storage_nbytes, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_PICKLE_IO_H__ diff --git a/otherarch/sdcpp/model_io/safetensors_io.cpp b/otherarch/sdcpp/model_io/safetensors_io.cpp new file mode 100644 index 000000000..889352218 --- /dev/null +++ b/otherarch/sdcpp/model_io/safetensors_io.cpp @@ -0,0 +1,316 @@ +#include "safetensors_io.h" + +#include +#include +#include +#include +#include + +#include "binary_io.h" +#include "json.hpp" +#include "util.h" + +static constexpr size_t ST_HEADER_SIZE_LEN = 8; + +static void set_error(std::string* error, const std::string& message) 
{ + if (error != nullptr) { + *error = message; + } +} + +bool is_safetensors_file(const std::string& file_path) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + return false; + } + + // get file size + file.seekg(0, file.end); + size_t file_size_ = file.tellg(); + file.seekg(0, file.beg); + + // read header size + if (file_size_ <= ST_HEADER_SIZE_LEN) { + return false; + } + + uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; + file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); + if (!file) { + return false; + } + + size_t header_size_ = model_io::read_u64(header_size_buf); + if (header_size_ >= file_size_ || header_size_ <= 2) { + return false; + } + + // read header + std::vector header_buf; + header_buf.resize(header_size_ + 1); + header_buf[header_size_] = '\0'; + file.read(header_buf.data(), header_size_); + if (!file) { + return false; + } + try { + nlohmann::json header_ = nlohmann::json::parse(header_buf.data()); + } catch (const std::exception&) { + return false; + } + return true; +} + +static ggml_type safetensors_dtype_to_ggml_type(const std::string& dtype) { + ggml_type ttype = GGML_TYPE_COUNT; + if (dtype == "F16") { + ttype = GGML_TYPE_F16; + } else if (dtype == "BF16") { + ttype = GGML_TYPE_BF16; + } else if (dtype == "F32") { + ttype = GGML_TYPE_F32; + } else if (dtype == "F64") { + ttype = GGML_TYPE_F32; + } else if (dtype == "F8_E4M3") { + ttype = GGML_TYPE_F16; + } else if (dtype == "F8_E5M2") { + ttype = GGML_TYPE_F16; + } else if (dtype == "I32") { + ttype = GGML_TYPE_I32; + } else if (dtype == "I64") { + ttype = GGML_TYPE_I32; + } + return ttype; +} + +// https://huggingface.co/docs/safetensors/index +bool read_safetensors_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + set_error(error, "failed to open '" + file_path + "'"); + return false; + } + + // get file size + file.seekg(0, file.end); + size_t file_size_ = file.tellg(); + file.seekg(0, file.beg); + + // read header size + if (file_size_ <= ST_HEADER_SIZE_LEN) { + set_error(error, "invalid safetensor file '" + file_path + "'"); + return false; + } + + uint8_t header_size_buf[ST_HEADER_SIZE_LEN]; + file.read((char*)header_size_buf, ST_HEADER_SIZE_LEN); + if (!file) { + set_error(error, "read safetensors header size failed: '" + file_path + "'"); + return false; + } + + size_t header_size_ = model_io::read_u64(header_size_buf); + if (header_size_ >= file_size_) { + set_error(error, "invalid safetensor file '" + file_path + "'"); + return false; + } + + // read header + std::vector header_buf; + header_buf.resize(header_size_ + 1); + header_buf[header_size_] = '\0'; + file.read(header_buf.data(), header_size_); + if (!file) { + set_error(error, "read safetensors header failed: '" + file_path + "'"); + return false; + } + + nlohmann::json header_; + try { + header_ = nlohmann::json::parse(header_buf.data()); + } catch (const std::exception&) { + set_error(error, "parsing safetensors header failed: '" + file_path + "'"); + return false; + } + + tensor_storages.clear(); + for (auto& item : header_.items()) { + std::string name = item.key(); + nlohmann::json tensor_info = item.value(); + // LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str()); + + if (name == "__metadata__") { + continue; + } + + std::string dtype = tensor_info["dtype"]; + nlohmann::json shape = tensor_info["shape"]; + + if (dtype == "U8") { + continue; + } + + size_t begin = 
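// Note the deliberate widenings in safetensors_dtype_to_ggml_type() above:
// F64/I64 load as F32/I32 and F8_E4M3/F8_E5M2 load as F16, so the size checks
// below compare against double or half the on-disk byte count accordingly.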
tensor_info["data_offsets"][0].get(); + size_t end = tensor_info["data_offsets"][1].get(); + + ggml_type type = safetensors_dtype_to_ggml_type(dtype); + if (type == GGML_TYPE_COUNT) { + set_error(error, "unsupported dtype '" + dtype + "' (tensor '" + name + "')"); + return false; + } + + if (shape.size() > SD_MAX_DIMS) { + set_error(error, "invalid tensor '" + name + "'"); + return false; + } + + int n_dims = (int)shape.size(); + int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + for (int i = 0; i < n_dims; i++) { + ne[i] = shape[i].get(); + } + + if (n_dims == 5) { + n_dims = 4; + ne[0] = ne[0] * ne[1]; + ne[1] = ne[2]; + ne[2] = ne[3]; + ne[3] = ne[4]; + } + + // ggml_n_dims returns 1 for scalars + if (n_dims == 0) { + n_dims = 1; + } + + TensorStorage tensor_storage(name, type, ne, n_dims, 0, ST_HEADER_SIZE_LEN + header_size_ + begin); + tensor_storage.reverse_ne(); + + size_t tensor_data_size = end - begin; + + bool tensor_size_ok; + if (dtype == "F8_E4M3") { + tensor_storage.is_f8_e4m3 = true; + // f8 -> f16 + tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2); + } else if (dtype == "F8_E5M2") { + tensor_storage.is_f8_e5m2 = true; + // f8 -> f16 + tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2); + } else if (dtype == "F64") { + tensor_storage.is_f64 = true; + // f64 -> f32 + tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size); + } else if (dtype == "I64") { + tensor_storage.is_i64 = true; + // i64 -> i32 + tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size); + } else { + tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size); + } + if (!tensor_size_ok) { + set_error(error, "size mismatch for tensor '" + name + "' (" + dtype + ")"); + return false; + } + + tensor_storages.push_back(tensor_storage); + + // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str()); + } + + return true; +} + +static bool ggml_type_to_safetensors_dtype(ggml_type type, std::string* dtype) { + switch (type) { + case GGML_TYPE_F16: + *dtype = "F16"; + return true; + case GGML_TYPE_BF16: + *dtype = "BF16"; + return true; + case GGML_TYPE_F32: + *dtype = "F32"; + return true; + case GGML_TYPE_I32: + *dtype = "I32"; + return true; + default: + return false; + } +} + +bool write_safetensors_file(const std::string& file_path, + const std::vector& tensors, + std::string* error) { + nlohmann::ordered_json header = nlohmann::ordered_json::object(); + + uint64_t data_offset = 0; + for (const TensorWriteInfo& write_tensor : tensors) { + ggml_tensor* tensor = write_tensor.tensor; + if (tensor == nullptr) { + set_error(error, "null tensor cannot be written to safetensors"); + return false; + } + + const std::string name = ggml_get_name(tensor); + std::string dtype; + if (!ggml_type_to_safetensors_dtype(tensor->type, &dtype)) { + set_error(error, + "unsupported safetensors dtype '" + std::string(ggml_type_name(tensor->type)) + + "' for tensor '" + name + "'"); + return false; + } + + const uint64_t tensor_nbytes = ggml_nbytes(tensor); + + nlohmann::ordered_json json_tensor_info = nlohmann::ordered_json::object(); + json_tensor_info["dtype"] = dtype; + + nlohmann::ordered_json shape = nlohmann::ordered_json::array(); + for (int i = 0; i < write_tensor.n_dims; ++i) { + shape.push_back(write_tensor.ne[write_tensor.n_dims - 1 - i]); + } + json_tensor_info["shape"] = shape; + + nlohmann::ordered_json data_offsets = nlohmann::ordered_json::array(); + data_offsets.push_back(data_offset); + data_offsets.push_back(data_offset + tensor_nbytes); + 
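+        // Offsets are assigned contiguously in iteration order, so the data
+        // section is tiled without gaps, which safetensors readers expect.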
json_tensor_info["data_offsets"] = data_offsets; + + header[name] = json_tensor_info; + data_offset += tensor_nbytes; + } + + const std::string header_str = header.dump(); + + std::ofstream file(file_path, std::ios::binary); + if (!file.is_open()) { + set_error(error, "failed to open '" + file_path + "' for writing"); + return false; + } + + LOG_INFO("trying to save tensors to %s", file_path.c_str()); + model_io::write_u64(file, header_str.size()); + file.write(header_str.data(), header_str.size()); + if (!file) { + set_error(error, "failed to write safetensors header to '" + file_path + "'"); + return false; + } + + for (const TensorWriteInfo& write_tensor : tensors) { + ggml_tensor* tensor = write_tensor.tensor; + const std::string name = ggml_get_name(tensor); + const size_t tensor_nbytes = ggml_nbytes(tensor); + file.write((const char*)tensor->data, tensor_nbytes); + if (!file) { + set_error(error, + "failed to write tensor '" + name + "' to '" + file_path + "'"); + return false; + } + } + + return true; +} diff --git a/otherarch/sdcpp/model_io/safetensors_io.h b/otherarch/sdcpp/model_io/safetensors_io.h new file mode 100644 index 000000000..08a1bc1f3 --- /dev/null +++ b/otherarch/sdcpp/model_io/safetensors_io.h @@ -0,0 +1,17 @@ +#ifndef __SD_MODEL_IO_SAFETENSORS_IO_H__ +#define __SD_MODEL_IO_SAFETENSORS_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool is_safetensors_file(const std::string& file_path); +bool read_safetensors_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); +bool write_safetensors_file(const std::string& file_path, + const std::vector& tensors, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_SAFETENSORS_IO_H__ diff --git a/otherarch/sdcpp/model_io/tensor_storage.h b/otherarch/sdcpp/model_io/tensor_storage.h new file mode 100644 index 000000000..c0cf079c5 --- /dev/null +++ b/otherarch/sdcpp/model_io/tensor_storage.h @@ -0,0 +1,132 @@ +#ifndef __SD_TENSOR_STORAGE_H__ +#define __SD_TENSOR_STORAGE_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include "ggml.h" + +#define SD_MAX_DIMS 5 + +struct TensorStorage { + std::string name; + ggml_type type = GGML_TYPE_F32; + ggml_type expected_type = GGML_TYPE_COUNT; + bool is_f8_e4m3 = false; + bool is_f8_e5m2 = false; + bool is_f64 = false; + bool is_i64 = false; + int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + int n_dims = 0; + + std::string storage_key; + size_t file_index = 0; + int index_in_zip = -1; // >= means stored in a zip file + uint64_t offset = 0; // offset in file + + TensorStorage() = default; + + TensorStorage(std::string name, ggml_type type, const int64_t* ne, int n_dims, size_t file_index, size_t offset = 0) + : name(std::move(name)), type(type), n_dims(n_dims), file_index(file_index), offset(offset) { + for (int i = 0; i < n_dims; i++) { + this->ne[i] = ne[i]; + } + } + + int64_t nelements() const { + int64_t n = 1; + for (int i = 0; i < SD_MAX_DIMS; i++) { + n *= ne[i]; + } + return n; + } + + int64_t nbytes() const { + return nelements() * ggml_type_size(type) / ggml_blck_size(type); + } + + int64_t nbytes_to_read() const { + if (is_f8_e4m3 || is_f8_e5m2) { + return nbytes() / 2; + } else if (is_f64 || is_i64) { + return nbytes() * 2; + } else { + return nbytes(); + } + } + + void unsqueeze() { + if (n_dims == 2) { + n_dims = 4; + ne[3] = ne[1]; + ne[2] = ne[0]; + ne[1] = 1; + ne[0] = 1; + } + } + + std::vector chunk(size_t n) { + std::vector chunks; + uint64_t chunk_size = nbytes_to_read() / n; + // 
printf("%d/%d\n", chunk_size, nbytes_to_read()); + reverse_ne(); + for (size_t i = 0; i < n; i++) { + TensorStorage chunk_i = *this; + chunk_i.ne[0] = ne[0] / n; + chunk_i.offset = offset + i * chunk_size; + chunk_i.reverse_ne(); + chunks.push_back(chunk_i); + } + reverse_ne(); + return chunks; + } + + void reverse_ne() { + int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + for (int i = 0; i < n_dims; i++) { + new_ne[i] = ne[n_dims - 1 - i]; + } + for (int i = 0; i < n_dims; i++) { + ne[i] = new_ne[i]; + } + } + + std::string to_string() const { + std::stringstream ss; + const char* type_name = ggml_type_name(type); + if (is_f8_e4m3) { + type_name = "f8_e4m3"; + } else if (is_f8_e5m2) { + type_name = "f8_e5m2"; + } else if (is_f64) { + type_name = "f64"; + } else if (is_i64) { + type_name = "i64"; + } + ss << name << " | " << type_name << " | "; + ss << n_dims << " ["; + for (int i = 0; i < SD_MAX_DIMS; i++) { + ss << ne[i]; + if (i != SD_MAX_DIMS - 1) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); + } +}; + +struct TensorWriteInfo { + int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + int n_dims = 0; + ggml_tensor* tensor = nullptr; +}; + +typedef std::function on_new_tensor_cb_t; + +#endif // __SD_TENSOR_STORAGE_H__ diff --git a/otherarch/sdcpp/model_io/torch_legacy_io.cpp b/otherarch/sdcpp/model_io/torch_legacy_io.cpp new file mode 100644 index 000000000..816547252 --- /dev/null +++ b/otherarch/sdcpp/model_io/torch_legacy_io.cpp @@ -0,0 +1,252 @@ +#include "torch_legacy_io.h" + +#include +#include +#include +#include +#include +#include + +#include "pickle_io.h" +#include "util.h" + +// torch.save format background: +// +// - Before PyTorch 1.6.0, torch.save used this legacy non-zip format by +// default. +// - Since PyTorch 1.6.0, torch.save defaults to an uncompressed ZIP64 archive +// containing data.pkl, data/, version, and, since PyTorch 2.1.0, byteorder. +// - The old format can still be produced explicitly with: +// torch.save(obj, path, _use_new_zipfile_serialization=False) +// +// Whether obj is a state_dict or a whole nn.Module does not change the outer +// container format selected by torch.save. It changes the pickled object inside: +// +// - state_dict: usually an OrderedDict[str, Tensor]. pickle_io.cpp supports a +// restricted subset of this layout because tensor metadata and raw storages +// can be recovered without executing pickle callables. +// - whole module/checkpoint object: arbitrary Python object graph. This may +// require importing user classes and executing pickle GLOBAL/REDUCE rebuild +// logic, so it is intentionally not supported here. +// +// Legacy non-zip PyTorch files are not a single pickle object: +// +// 1. pickle object: PyTorch legacy magic number +// 2. pickle object: legacy protocol version, expected to be 1001 +// 3. pickle object: sys_info metadata, ignored by this reader +// 4. pickle object: state_dict metadata, parsed by pickle_io.cpp +// 5. pickle object: serialized storage key list, skipped here +// 6. 
raw storage data payloads +// - PyTorch writes storages after the pickles, ordered by storage key +// - each storage has an 8-byte legacy storage header followed by raw bytes +static constexpr size_t LEGACY_STORAGE_HEADER_SIZE = 8; + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +static std::string bytes_to_hex(const std::vector& bytes) { + static const char* hex = "0123456789ABCDEF"; + std::string result; + result.reserve(bytes.size() * 3); + for (size_t i = 0; i < bytes.size(); ++i) { + if (i > 0) { + result.push_back('-'); + } + result.push_back(hex[(bytes[i] >> 4) & 0x0F]); + result.push_back(hex[bytes[i] & 0x0F]); + } + return result; +} + +static bool is_probably_tar_file(const std::vector& header) { + return header.size() >= 262 && + header[257] == 'u' && + header[258] == 's' && + header[259] == 't' && + header[260] == 'a' && + header[261] == 'r'; +} + +static std::string torch_legacy_diagnostics(const std::string& file_path, const std::vector& buffer) { + if (!ends_with(file_path, ".pt") && !ends_with(file_path, ".pth")) { + return ""; + } + if (buffer.empty()) { + return "unsupported PyTorch file '" + file_path + "': empty file"; + } + + size_t short_len = std::min(buffer.size(), 32); + std::vector short_header(buffer.begin(), buffer.begin() + short_len); + const bool raw_pickle = buffer[0] == 0x80; + const bool tar_file = is_probably_tar_file(buffer); + + std::string message = "unsupported PyTorch file '" + file_path + "': first bytes " + + bytes_to_hex(short_header) + + ", raw_pickle=" + (raw_pickle ? "true" : "false") + + ", tar=" + (tar_file ? "true" : "false"); + if (raw_pickle) { + message += "; raw pickle did not match the restricted state_dict layouts currently supported"; + } else if (tar_file) { + message += "; legacy tar PyTorch checkpoints are not supported yet"; + } + return message; +} + +bool read_torch_legacy_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + set_error(error, "failed to open '" + file_path + "'"); + return false; + } + + file.seekg(0, file.end); + size_t file_size = (size_t)file.tellg(); + file.seekg(0, file.beg); + if (file_size == 0) { + set_error(error, "empty file '" + file_path + "'"); + return false; + } + + std::vector buffer(file_size); + file.read((char*)buffer.data(), file_size); + if (!file) { + set_error(error, "failed to read '" + file_path + "'"); + return false; + } + + auto finalize_tensor_offsets = [&](size_t storage_data_offset, + const std::unordered_map& legacy_storage_map) -> bool { + if (storage_data_offset > file_size) { + return false; + } + + std::vector storage_keys; + storage_keys.reserve(legacy_storage_map.size()); + for (const auto& [storage_key, _] : legacy_storage_map) { + storage_keys.push_back(storage_key); + } + std::sort(storage_keys.begin(), storage_keys.end()); + + std::unordered_map storage_offsets; + uint64_t current_offset = storage_data_offset; + for (const auto& storage_key : storage_keys) { + auto it = legacy_storage_map.find(storage_key); + if (it == legacy_storage_map.end()) { + return false; + } + if (current_offset + LEGACY_STORAGE_HEADER_SIZE + it->second > file_size) { + return false; + } + storage_offsets[storage_key] = current_offset + LEGACY_STORAGE_HEADER_SIZE; + current_offset += LEGACY_STORAGE_HEADER_SIZE + it->second; + } + + for (auto& tensor_storage : tensor_storages) { + if 
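+            // storage_offsets now maps each storage key to the absolute start
+            // of its payload: storages follow the pickles sorted by key, each
+            // laid out as [8-byte legacy header][raw bytes], so
+            //   offset(key[n+1]) = offset(key[n]) + nbytes(key[n]) + 8.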
(tensor_storage.storage_key.empty()) { + continue; + } + + auto it_offset = storage_offsets.find(tensor_storage.storage_key); + auto it_size = legacy_storage_map.find(tensor_storage.storage_key); + if (it_offset == storage_offsets.end() || it_size == legacy_storage_map.end()) { + return false; + } + + uint64_t base_offset = it_offset->second; + uint64_t storage_nbytes = it_size->second; + uint64_t tensor_nbytes = tensor_storage.nbytes_to_read(); + if (tensor_storage.offset + tensor_nbytes > storage_nbytes) { + return false; + } + + tensor_storage.offset = base_offset + tensor_storage.offset; + tensor_storage.storage_key.clear(); + } + + return true; + }; + + auto parse_state_dict_at = [&](size_t state_dict_offset, size_t state_dict_size, size_t* storage_data_offset) -> bool { + tensor_storages.clear(); + std::unordered_map legacy_storage_map; + if (!parse_torch_state_dict_pickle(buffer.data() + state_dict_offset, + state_dict_size, + tensor_storages, + legacy_storage_map, + error)) { + return false; + } + + size_t offset_after_state_dict = state_dict_offset + state_dict_size; + size_t storage_keys_size = 0; + if (!skip_pickle_object(buffer.data() + offset_after_state_dict, + buffer.size() - offset_after_state_dict, + &storage_keys_size)) { + return false; + } + + *storage_data_offset = offset_after_state_dict + storage_keys_size; + return finalize_tensor_offsets(*storage_data_offset, legacy_storage_map); + }; + + size_t object_size_1 = 0; + size_t offset = 0; + + if (skip_pickle_object(buffer.data(), buffer.size(), &object_size_1) && + pickle_object_is_torch_magic_number(buffer.data(), object_size_1)) { + offset += object_size_1; + + size_t object_size_2 = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_2)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + uint32_t protocol_version = 0; + if (!parse_pickle_uint32_object(buffer.data() + offset, object_size_2, &protocol_version) || protocol_version != 1001) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + offset += object_size_2; + + size_t object_size_3 = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_3)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + offset += object_size_3; + + size_t state_dict_size = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &state_dict_size)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + + size_t storage_data_offset = 0; + if (parse_state_dict_at(offset, state_dict_size, &storage_data_offset)) { + return true; + } + + if (error != nullptr && error->empty()) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + } + return false; + } + + size_t state_dict_size = 0; + if (skip_pickle_object(buffer.data(), buffer.size(), &state_dict_size)) { + size_t storage_data_offset = 0; + if (parse_state_dict_at(0, state_dict_size, &storage_data_offset)) { + return true; + } + } + + if (error != nullptr && error->empty()) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + } + return false; +} diff --git a/otherarch/sdcpp/model_io/torch_legacy_io.h b/otherarch/sdcpp/model_io/torch_legacy_io.h new file mode 100644 index 000000000..6680e02a1 --- /dev/null +++ b/otherarch/sdcpp/model_io/torch_legacy_io.h @@ -0,0 +1,13 @@ +#ifndef __SD_MODEL_IO_TORCH_LEGACY_IO_H__ +#define __SD_MODEL_IO_TORCH_LEGACY_IO_H__ + +#include +#include 
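+// Minimal usage sketch (hypothetical caller; error handling abbreviated):
+//   std::vector<TensorStorage> storages;
+//   std::string err;
+//   if (!read_torch_legacy_file("model.pth", storages, &err)) {
+//       fprintf(stderr, "%s\n", err.c_str());
+//   }
+//   // on success each TensorStorage carries an absolute file offset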
+ +#include "tensor_storage.h" + +bool read_torch_legacy_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_TORCH_LEGACY_IO_H__ diff --git a/otherarch/sdcpp/model_io/torch_zip_io.cpp b/otherarch/sdcpp/model_io/torch_zip_io.cpp new file mode 100644 index 000000000..9eaf6c53a --- /dev/null +++ b/otherarch/sdcpp/model_io/torch_zip_io.cpp @@ -0,0 +1,140 @@ +#include "torch_zip_io.h" + +#include +#include +#include +#include +#include + +#include "pickle_io.h" + +#include "zip.h" + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +bool is_torch_zip_file(const std::string& file_path) { + zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); + if (zip == nullptr) { + return false; + } + zip_close(zip); + return true; +} + +static bool find_zip_entry(zip_t* zip, const std::string& entry_name, int* index, uint64_t* size) { + size_t n = zip_entries_total(zip); + for (size_t i = 0; i < n; ++i) { + zip_entry_openbyindex(zip, i); + std::string name = zip_entry_name(zip); + if (name == entry_name) { + *index = (int)i; + *size = zip_entry_size(zip); + zip_entry_close(zip); + return true; + } + zip_entry_close(zip); + } + return false; +} + +static bool parse_zip_data_pkl(const uint8_t* buffer, + size_t buffer_size, + zip_t* zip, + const std::string& dir, + std::vector& tensor_storages, + std::string* error) { + std::vector parsed_tensors; + std::unordered_map storage_nbytes; + if (!parse_torch_state_dict_pickle(buffer, buffer_size, parsed_tensors, storage_nbytes, error)) { + if (error != nullptr && error->empty()) { + *error = "failed to parse torch zip pickle metadata"; + } + return false; + } + + for (auto& tensor_storage : parsed_tensors) { + if (tensor_storage.storage_key.empty()) { + set_error(error, "tensor '" + tensor_storage.name + "' has no storage key"); + return false; + } + + const std::string entry_name = dir + "data/" + tensor_storage.storage_key; + int zip_index = -1; + uint64_t entry_size = 0; + if (!find_zip_entry(zip, entry_name, &zip_index, &entry_size)) { + set_error(error, "storage entry '" + entry_name + "' was not found"); + return false; + } + + auto it_storage_size = storage_nbytes.find(tensor_storage.storage_key); + if (it_storage_size != storage_nbytes.end() && entry_size < it_storage_size->second) { + set_error(error, "storage entry '" + entry_name + "' is smaller than pickle metadata"); + return false; + } + + uint64_t tensor_nbytes = tensor_storage.nbytes_to_read(); + if (tensor_storage.offset + tensor_nbytes > entry_size) { + set_error(error, "tensor '" + tensor_storage.name + "' exceeds storage entry '" + entry_name + "'"); + return false; + } + + tensor_storage.index_in_zip = zip_index; + tensor_storage.storage_key.clear(); + tensor_storages.push_back(tensor_storage); + } + + return true; +} + +bool read_torch_zip_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); + if (zip == nullptr) { + set_error(error, "failed to open '" + file_path + "'"); + return false; + } + + tensor_storages.clear(); + bool success = true; + bool found_data_pkl = false; + int n = (int)zip_entries_total(zip); + for (int i = 0; i < n; ++i) { + zip_entry_openbyindex(zip, i); + std::string name = zip_entry_name(zip); + size_t pos = name.find("data.pkl"); + if (pos != std::string::npos) { + found_data_pkl = true; + std::string dir = name.substr(0, pos); + void* 
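+            // A torch.save zip archive typically contains
+            //   <prefix>/data.pkl       pickled state_dict metadata
+            //   <prefix>/data/<key>     one raw payload per storage key
+            //   <prefix>/version        (and byteorder on newer PyTorch)
+            // The <prefix> directory name varies, hence the substring match above.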
pkl_data = nullptr; + size_t pkl_size = 0; + zip_entry_read(zip, &pkl_data, &pkl_size); + + if (pkl_data == nullptr || pkl_size == 0) { + set_error(error, "failed to read '" + name + "' from '" + file_path + "'"); + success = false; + } else if (!parse_zip_data_pkl((const uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) { + success = false; + } + + free(pkl_data); + } + zip_entry_close(zip); + + if (!success) { + break; + } + } + + if (success && !found_data_pkl) { + set_error(error, "data.pkl was not found in '" + file_path + "'"); + success = false; + } + + zip_close(zip); + return success; +} diff --git a/otherarch/sdcpp/model_io/torch_zip_io.h b/otherarch/sdcpp/model_io/torch_zip_io.h new file mode 100644 index 000000000..54fb099a7 --- /dev/null +++ b/otherarch/sdcpp/model_io/torch_zip_io.h @@ -0,0 +1,14 @@ +#ifndef __SD_MODEL_IO_TORCH_ZIP_IO_H__ +#define __SD_MODEL_IO_TORCH_ZIP_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool is_torch_zip_file(const std::string& file_path); +bool read_torch_zip_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_TORCH_ZIP_IO_H__ diff --git a/otherarch/sdcpp/sdtype_adapter.cpp b/otherarch/sdcpp/sdtype_adapter.cpp index 00eb66707..e583cd6bd 100644 --- a/otherarch/sdcpp/sdtype_adapter.cpp +++ b/otherarch/sdcpp/sdtype_adapter.cpp @@ -21,7 +21,31 @@ #include "util.cpp" #include "name_conversion.cpp" #include "upscaler.cpp" + +#include "zip.c" +#include "model_io/binary_io.h" +namespace pickle { +#include "model_io/pickle_io.cpp" +} +namespace gguf { +#include "model_io/gguf_io.cpp" +} +namespace safetensors { +#include "model_io/safetensors_io.cpp" +} +using namespace pickle; +namespace torch_legacy { +#include "model_io/torch_legacy_io.cpp" +} +namespace torch_zip { +#include "model_io/torch_zip_io.cpp" +} +using namespace gguf; +using namespace safetensors; +using namespace torch_legacy; +using namespace torch_zip; #include "model.cpp" + #include "tokenizers/bpe_tokenizer.cpp" #include "tokenizers/clip_tokenizer.cpp" #include "tokenizers/mistral_tokenizer.cpp" @@ -29,7 +53,6 @@ #include "tokenizers/t5_unigram_tokenizer.cpp" #include "tokenizers/tokenizer.cpp" #include "tokenizers/tokenize_util.cpp" -#include "zip.c" #include "otherarch/utils.h" diff --git a/otherarch/sdcpp/stable-diffusion.cpp b/otherarch/sdcpp/stable-diffusion.cpp index a6ecd9e38..4b619bc5d 100644 --- a/otherarch/sdcpp/stable-diffusion.cpp +++ b/otherarch/sdcpp/stable-diffusion.cpp @@ -17,6 +17,7 @@ #include "pmid.hpp" #include "sample-cache.h" #include "tae.hpp" +#include "upscaler.h" #include "vae.hpp" #include "latent-preview.h" @@ -234,11 +235,11 @@ public: device = 0; } if (device >= device_count) { - LOG_WARN("Cannot find targeted vulkan device (%llu). Falling back to device 0.", device); + LOG_WARN("Cannot find targeted vulkan device (%zu). 
Falling back to device 0.", device); device = 0; } } - LOG_INFO("Vulkan: Using device %llu", device); + LOG_INFO("Vulkan: Using device %zu", device); backend = ggml_backend_vk_init(device); } if (!backend) { @@ -2380,6 +2381,35 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) { return LORA_APPLY_MODE_COUNT; } +const char* hires_upscaler_to_str[] = { + "None", + "Latent", + "Latent (nearest)", + "Latent (nearest-exact)", + "Latent (antialiased)", + "Latent (bicubic)", + "Latent (bicubic antialiased)", + "Lanczos", + "Nearest", + "Model", +}; + +const char* sd_hires_upscaler_name(enum sd_hires_upscaler_t upscaler) { + if (upscaler >= SD_HIRES_UPSCALER_NONE && upscaler < SD_HIRES_UPSCALER_COUNT) { + return hires_upscaler_to_str[upscaler]; + } + return NONE_STR; +} + +enum sd_hires_upscaler_t str_to_sd_hires_upscaler(const char* str) { + for (int i = 0; i < SD_HIRES_UPSCALER_COUNT; i++) { + if (!strcmp(str, hires_upscaler_to_str[i])) { + return (enum sd_hires_upscaler_t)i; + } + } + return SD_HIRES_UPSCALER_COUNT; +} + void sd_cache_params_init(sd_cache_params_t* cache_params) { *cache_params = {}; cache_params->mode = SD_CACHE_DISABLED; @@ -2408,6 +2438,19 @@ void sd_cache_params_init(sd_cache_params_t* cache_params) { cache_params->spectrum_stop_percent = 0.9f; } +void sd_hires_params_init(sd_hires_params_t* hires_params) { + *hires_params = {}; + hires_params->enabled = false; + hires_params->upscaler = SD_HIRES_UPSCALER_LATENT; + hires_params->model_path = nullptr; + hires_params->scale = 2.0f; + hires_params->target_width = 0; + hires_params->target_height = 0; + hires_params->steps = 0; + hires_params->denoising_strength = 0.7f; + hires_params->upscale_tile_size = 128; +} + void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { *sd_ctx_params = {}; sd_ctx_params->vae_decode_only = true; @@ -2577,6 +2620,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f}; sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; sd_cache_params_init(&sd_img_gen_params->cache); + sd_hires_params_init(&sd_img_gen_params->hires); } char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { @@ -2603,7 +2647,8 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { "increase_ref_index: %s\n" "control_strength: %.2f\n" "photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n" - "VAE tiling: %s\n", + "VAE tiling: %s\n" + "hires: {enabled=%s, upscaler=%s, model_path=%s, scale=%.2f, target=%dx%d, steps=%d, denoising_strength=%.2f}\n", SAFE_STR(sd_img_gen_params->prompt), SAFE_STR(sd_img_gen_params->negative_prompt), sd_img_gen_params->clip_skip, @@ -2620,7 +2665,15 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->pm_params.style_strength, sd_img_gen_params->pm_params.id_images_count, SAFE_STR(sd_img_gen_params->pm_params.id_embed_path), - BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled)); + BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled), + BOOL_STR(sd_img_gen_params->hires.enabled), + sd_hires_upscaler_name(sd_img_gen_params->hires.upscaler), + SAFE_STR(sd_img_gen_params->hires.model_path), + sd_img_gen_params->hires.scale, + sd_img_gen_params->hires.target_width, + sd_img_gen_params->hires.target_height, + sd_img_gen_params->hires.steps, + sd_img_gen_params->hires.denoising_strength); const char* cache_mode_str = "disabled"; if (sd_img_gen_params->cache.mode == 
SD_CACHE_EASYCACHE) { cache_mode_str = "easycache"; @@ -2724,8 +2777,10 @@ enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_me return EXPONENTIAL_SCHEDULER; } } - if (sample_method == LCM_SAMPLE_METHOD) { + if (sample_method == LCM_SAMPLE_METHOD || sample_method == TCD_SAMPLE_METHOD) { return LCM_SCHEDULER; + } else if (sample_method == DDIM_TRAILING_SAMPLE_METHOD) { + return SIMPLE_SCHEDULER; } return DISCRETE_SCHEDULER; } @@ -2799,6 +2854,7 @@ struct GenerationRequest { sd_guidance_params_t guidance = {}; sd_guidance_params_t high_noise_guidance = {}; sd_pm_params_t pm_params = {}; + sd_hires_params_t hires = {}; int frames = -1; float vace_strength = 1.f; @@ -2820,6 +2876,7 @@ struct GenerationRequest { auto_resize_ref_image = sd_img_gen_params->auto_resize_ref_image; guidance = sd_img_gen_params->sample_params.guidance; pm_params = sd_img_gen_params->pm_params; + hires = sd_img_gen_params->hires; cache_params = &sd_img_gen_params->cache; resolve(sd_ctx); } @@ -2842,26 +2899,76 @@ struct GenerationRequest { } void align_generation_request_size() { + align_image_size(&width, &height, "generation request"); + } + + void align_image_size(int* target_width, int* target_height, const char* label) { int spatial_multiple = vae_scale_factor * diffusion_model_down_factor; - int width_offset = align_up_offset(width, spatial_multiple); - int height_offset = align_up_offset(height, spatial_multiple); + int width_offset = align_up_offset(*target_width, spatial_multiple); + int height_offset = align_up_offset(*target_height, spatial_multiple); if (width_offset <= 0 && height_offset <= 0) { return; } - int original_width = width; - int original_height = height; + int original_width = *target_width; + int original_height = *target_height; - width += width_offset; - height += height_offset; - LOG_WARN("align up %dx%d to %dx%d (multiple=%d)", + *target_width += width_offset; + *target_height += height_offset; + LOG_WARN("align %s up %dx%d to %dx%d (multiple=%d)", + label, original_width, original_height, - width, - height, + *target_width, + *target_height, spatial_multiple); } + void resolve_hires() { + if (!hires.enabled) { + return; + } + if (hires.upscaler == SD_HIRES_UPSCALER_NONE) { + hires.enabled = false; + return; + } + if (hires.upscaler < SD_HIRES_UPSCALER_NONE || hires.upscaler >= SD_HIRES_UPSCALER_COUNT) { + LOG_WARN("hires upscaler '%d' is invalid, disabling hires", hires.upscaler); + hires.enabled = false; + return; + } + if (hires.upscaler == SD_HIRES_UPSCALER_MODEL && strlen(SAFE_STR(hires.model_path)) == 0) { + LOG_WARN("hires model upscaler requires a model path, disabling hires"); + hires.enabled = false; + return; + } + if (hires.scale <= 0.f && hires.target_width <= 0 && hires.target_height <= 0) { + LOG_WARN("hires scale must be positive when no target size is set, disabling hires"); + hires.enabled = false; + return; + } + hires.denoising_strength = std::clamp(hires.denoising_strength, 0.0001f, 1.f); + hires.steps = std::max(0, hires.steps); + + if (hires.target_width > 0 && hires.target_height > 0) { + // pass + } else if (hires.target_width > 0) { + hires.target_height = hires.target_width; + } else if (hires.target_height > 0) { + hires.target_width = hires.target_height; + } else { + hires.target_width = static_cast(std::round(width * hires.scale)); + hires.target_height = static_cast(std::round(height * hires.scale)); + } + + if (hires.target_width <= 0 || hires.target_height <= 0) { + LOG_WARN("hires target size is not positive, disabling 
hires"); + hires.enabled = false; + return; + } + align_image_size(&hires.target_width, &hires.target_height, "hires target"); + } + static void resolve_guidance(sd_ctx_t* sd_ctx, sd_guidance_params_t* guidance, bool* use_uncond, @@ -2902,6 +3009,7 @@ struct GenerationRequest { void resolve(sd_ctx_t* sd_ctx) { align_generation_request_size(); + resolve_hires(); seed = resolve_seed(seed); resolve_guidance(sd_ctx, &guidance, &use_uncond, &use_img_cond); @@ -3392,7 +3500,7 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, } decoded_images.push_back(std::move(image)); int64_t t2 = ggml_time_ms(); - LOG_INFO("latent %" PRId64 " decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000); + LOG_INFO("latent %zu decoded, taking %.2fs", i + 1, (t2 - t1) * 1.0f / 1000); } int64_t t4 = ggml_time_ms(); @@ -3414,6 +3522,135 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, return result_images; } +static sd::Tensor upscale_hires_latent(sd_ctx_t* sd_ctx, + const sd::Tensor& latent, + const GenerationRequest& request, + UpscalerGGML* upscaler) { + auto get_hires_latent_target_shape = [&]() { + std::vector target_shape = latent.shape(); + if (target_shape.size() < 2) { + target_shape.clear(); + return target_shape; + } + target_shape[0] = request.hires.target_width / request.vae_scale_factor; + target_shape[1] = request.hires.target_height / request.vae_scale_factor; + return target_shape; + }; + + if (request.hires.upscaler == SD_HIRES_UPSCALER_LATENT || + request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_NEAREST || + request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_NEAREST_EXACT || + request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_ANTIALIASED || + request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_BICUBIC || + request.hires.upscaler == SD_HIRES_UPSCALER_LATENT_BICUBIC_ANTIALIASED) { + std::vector target_shape = get_hires_latent_target_shape(); + if (target_shape.empty()) { + LOG_ERROR("latent has invalid shape for hires upscale"); + return {}; + } + + sd::ops::InterpolateMode mode = sd::ops::InterpolateMode::Nearest; + bool antialias = false; + switch (request.hires.upscaler) { + case SD_HIRES_UPSCALER_LATENT: + mode = sd::ops::InterpolateMode::Bilinear; + break; + case SD_HIRES_UPSCALER_LATENT_NEAREST: + mode = sd::ops::InterpolateMode::Nearest; + break; + case SD_HIRES_UPSCALER_LATENT_NEAREST_EXACT: + mode = sd::ops::InterpolateMode::NearestExact; + break; + case SD_HIRES_UPSCALER_LATENT_ANTIALIASED: + mode = sd::ops::InterpolateMode::Bilinear; + antialias = true; + break; + case SD_HIRES_UPSCALER_LATENT_BICUBIC: + mode = sd::ops::InterpolateMode::Bicubic; + break; + case SD_HIRES_UPSCALER_LATENT_BICUBIC_ANTIALIASED: + mode = sd::ops::InterpolateMode::Bicubic; + antialias = true; + break; + default: + break; + } + + LOG_INFO("hires %s upscale %" PRId64 "x%" PRId64 " -> %" PRId64 "x%" PRId64, + sd_hires_upscaler_name(request.hires.upscaler), + latent.shape()[0], + latent.shape()[1], + target_shape[0], + target_shape[1]); + + return sd::ops::interpolate(latent, target_shape, mode, false, antialias); + } else if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL || + request.hires.upscaler == SD_HIRES_UPSCALER_LANCZOS || + request.hires.upscaler == SD_HIRES_UPSCALER_NEAREST) { + if (sd_ctx->sd->vae_decode_only) { + LOG_ERROR("hires %s upscaler requires VAE encoder weights; create the context with vae_decode_only=false", + sd_hires_upscaler_name(request.hires.upscaler)); + return {}; + } + if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL && upscaler == nullptr) { + 
LOG_ERROR("hires model upscaler context is null"); + return {}; + } + + sd::Tensor decoded = sd_ctx->sd->decode_first_stage(latent); + if (decoded.empty()) { + LOG_ERROR("decode_first_stage failed before hires %s upscale", + sd_hires_upscaler_name(request.hires.upscaler)); + return {}; + } + + sd::Tensor upscaled_tensor; + if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) { + upscaled_tensor = upscaler->upscale_tensor(decoded); + if (upscaled_tensor.empty()) { + LOG_ERROR("hires model upscale failed"); + return {}; + } + + if (upscaled_tensor.shape()[0] != request.hires.target_width || + upscaled_tensor.shape()[1] != request.hires.target_height) { + upscaled_tensor = sd::ops::interpolate(upscaled_tensor, + {request.hires.target_width, + request.hires.target_height, + upscaled_tensor.shape()[2], + upscaled_tensor.shape()[3]}); + } + } else { + sd::ops::InterpolateMode mode = request.hires.upscaler == SD_HIRES_UPSCALER_LANCZOS + ? sd::ops::InterpolateMode::Lanczos + : sd::ops::InterpolateMode::Nearest; + LOG_INFO("hires %s image upscale %" PRId64 "x%" PRId64 " -> %dx%d", + sd_hires_upscaler_name(request.hires.upscaler), + decoded.shape()[0], + decoded.shape()[1], + request.hires.target_width, + request.hires.target_height); + upscaled_tensor = sd::ops::interpolate(decoded, + {request.hires.target_width, + request.hires.target_height, + decoded.shape()[2], + decoded.shape()[3]}, + mode); + upscaled_tensor = sd::ops::clamp(upscaled_tensor, 0.0f, 1.0f); + } + + sd::Tensor upscaled_latent = sd_ctx->sd->encode_first_stage(upscaled_tensor); + if (upscaled_latent.empty()) { + LOG_ERROR("encode_first_stage failed after hires %s upscale", + sd_hires_upscaler_name(request.hires.upscaler)); + } + return upscaled_latent; + } + + LOG_ERROR("unsupported hires upscaler '%s'", sd_hires_upscaler_name(request.hires.upscaler)); + return {}; +} + SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) { if (sd_ctx == nullptr || sd_img_gen_params == nullptr) { return nullptr; @@ -3501,14 +3738,139 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s } return nullptr; } - if (sd_ctx->sd->free_params_immediately) { + if (sd_ctx->sd->free_params_immediately && !request.hires.enabled) { sd_ctx->sd->diffusion_model->free_params_buffer(); } int64_t denoise_end = ggml_time_ms(); - LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", + LOG_INFO("generating %zu latent images completed, taking %.2fs", final_latents.size(), (denoise_end - denoise_start) * 1.0f / 1000); + if (request.hires.enabled && request.hires.target_width > 0) { + LOG_INFO("hires fix: upscaling to %dx%d", request.hires.target_width, request.hires.target_height); + + std::unique_ptr hires_upscaler; + if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) { + LOG_INFO("hires fix: loading model upscaler from '%s'", request.hires.model_path); + hires_upscaler = std::make_unique(sd_ctx->sd->n_threads, + false, + request.hires.upscale_tile_size); + if (!hires_upscaler->load_from_file(request.hires.model_path, + sd_ctx->sd->offload_params_to_cpu, + sd_ctx->sd->n_threads)) { + LOG_ERROR("load hires model upscaler failed"); + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->diffusion_model->free_params_buffer(); + } + return nullptr; + } + } + + int hires_steps = request.hires.steps > 0 ? 
request.hires.steps : plan.sample_steps; + + // sd-webui behavior: scale up total steps so trimming by denoising_strength yields exactly hires_steps effective steps, + // unlike img2img which trims from a fixed step count + hires_steps = static_cast(hires_steps / request.hires.denoising_strength); + + std::vector hires_sigmas = sd_ctx->sd->denoiser->get_sigmas( + hires_steps, + sd_ctx->sd->get_image_seq_len(request.hires.target_height, request.hires.target_width), + sd_img_gen_params->sample_params.scheduler, + sd_ctx->sd->version); + + size_t t_enc = static_cast(hires_steps * request.hires.denoising_strength); + if (t_enc >= static_cast(hires_steps)) { + t_enc = static_cast(hires_steps) - 1; + } + std::vector hires_sigma_sched(hires_sigmas.begin() + hires_steps - static_cast(t_enc) - 1, + hires_sigmas.end()); + LOG_INFO("hires fix: %d steps, denoising_strength=%.2f, sigma_sched_size=%zu", + hires_steps, + request.hires.denoising_strength, + hires_sigma_sched.size()); + + std::vector> hires_final_latents; + int64_t hires_denoise_start = ggml_time_ms(); + for (int b = 0; b < (int)final_latents.size(); b++) { + int64_t cur_seed = request.seed + b; + sd_ctx->sd->rng->manual_seed(cur_seed); + sd_ctx->sd->sampler_rng->manual_seed(cur_seed); + + sd::Tensor upscaled = upscale_hires_latent(sd_ctx, + final_latents[b], + request, + hires_upscaler.get()); + if (upscaled.empty()) { + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->diffusion_model->free_params_buffer(); + } + return nullptr; + } + + sd::Tensor noise = sd::randn_like(upscaled, sd_ctx->sd->rng); + + sd::Tensor hires_denoise_mask; + if (!latents.denoise_mask.empty()) { + std::vector mask_shape = latents.denoise_mask.shape(); + mask_shape[0] = upscaled.shape()[0]; + mask_shape[1] = upscaled.shape()[1]; + hires_denoise_mask = sd::ops::interpolate(latents.denoise_mask, + mask_shape, + sd::ops::InterpolateMode::NearestMax); + } + + int64_t hires_sample_start = ggml_time_ms(); + sd::Tensor x_0 = sd_ctx->sd->sample(sd_ctx->sd->diffusion_model, + true, + upscaled, + std::move(noise), + embeds.cond, + embeds.uncond, + embeds.img_cond, + embeds.id_cond, + latents.control_image, + request.control_strength, + request.guidance, + plan.eta, + request.shifted_timestep, + plan.sample_method, + sd_ctx->sd->is_flow_denoiser(), + hires_sigma_sched, + plan.start_merge_step, + latents.ref_latents, + request.increase_ref_index, + hires_denoise_mask, + sd::Tensor(), + 1.f, + request.cache_params); + int64_t hires_sample_end = ggml_time_ms(); + if (!x_0.empty()) { + LOG_INFO("hires sampling %d/%d completed, taking %.2fs", + b + 1, + (int)final_latents.size(), + (hires_sample_end - hires_sample_start) * 1.0f / 1000); + hires_final_latents.push_back(std::move(x_0)); + continue; + } + + LOG_ERROR("hires sampling for image %d/%d failed after %.2fs", + b + 1, + (int)final_latents.size(), + (hires_sample_end - hires_sample_start) * 1.0f / 1000); + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->diffusion_model->free_params_buffer(); + } + return nullptr; + } + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->diffusion_model->free_params_buffer(); + } + int64_t hires_denoise_end = ggml_time_ms(); + LOG_INFO("hires fix completed, taking %.2fs", (hires_denoise_end - hires_denoise_start) * 1.0f / 1000); + + final_latents = std::move(hires_final_latents); + } + auto result = decode_image_outputs(sd_ctx, request, final_latents); if (result == nullptr) { return nullptr; diff --git a/otherarch/sdcpp/stable-diffusion.h 
b/otherarch/sdcpp/stable-diffusion.h index 076e9cb27..0841a3c96 100644 --- a/otherarch/sdcpp/stable-diffusion.h +++ b/otherarch/sdcpp/stable-diffusion.h @@ -290,6 +290,32 @@ typedef struct { const char* path; } sd_lora_t; +enum sd_hires_upscaler_t { + SD_HIRES_UPSCALER_NONE, + SD_HIRES_UPSCALER_LATENT, + SD_HIRES_UPSCALER_LATENT_NEAREST, + SD_HIRES_UPSCALER_LATENT_NEAREST_EXACT, + SD_HIRES_UPSCALER_LATENT_ANTIALIASED, + SD_HIRES_UPSCALER_LATENT_BICUBIC, + SD_HIRES_UPSCALER_LATENT_BICUBIC_ANTIALIASED, + SD_HIRES_UPSCALER_LANCZOS, + SD_HIRES_UPSCALER_NEAREST, + SD_HIRES_UPSCALER_MODEL, + SD_HIRES_UPSCALER_COUNT, +}; + +typedef struct { + bool enabled; + enum sd_hires_upscaler_t upscaler; + const char* model_path; + float scale; + int target_width; + int target_height; + int steps; + float denoising_strength; + int upscale_tile_size; +} sd_hires_params_t; + typedef struct { const sd_lora_t* loras; uint32_t lora_count; @@ -313,6 +339,7 @@ typedef struct { sd_pm_params_t pm_params; sd_tiling_params_t vae_tiling_params; sd_cache_params_t cache; + sd_hires_params_t hires; } sd_img_gen_params_t; typedef struct { @@ -366,8 +393,11 @@ SD_API const char* sd_preview_name(enum preview_t preview); SD_API enum preview_t str_to_preview(const char* str); SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode); SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str); +SD_API const char* sd_hires_upscaler_name(enum sd_hires_upscaler_t upscaler); +SD_API enum sd_hires_upscaler_t str_to_sd_hires_upscaler(const char* str); SD_API void sd_cache_params_init(sd_cache_params_t* cache_params); +SD_API void sd_hires_params_init(sd_hires_params_t* hires_params); SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params); SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params); diff --git a/otherarch/sdcpp/tensor.hpp b/otherarch/sdcpp/tensor.hpp index 33302b056..f45551940 100644 --- a/otherarch/sdcpp/tensor.hpp +++ b/otherarch/sdcpp/tensor.hpp @@ -815,11 +815,202 @@ namespace sd { namespace ops { enum class InterpolateMode { Nearest, + NearestExact, NearestMax, NearestMin, NearestAvg, + Bilinear, + Bicubic, + Lanczos, }; + inline bool is_nearest_like_interpolate_mode(InterpolateMode mode) { + return mode == InterpolateMode::Nearest || + mode == InterpolateMode::NearestExact || + mode == InterpolateMode::NearestMax || + mode == InterpolateMode::NearestMin || + mode == InterpolateMode::NearestAvg; + } + + inline bool is_2d_filter_interpolate_mode(InterpolateMode mode) { + return mode == InterpolateMode::Bilinear || + mode == InterpolateMode::Bicubic || + mode == InterpolateMode::Lanczos; + } + + inline int64_t nearest_exact_interpolate_index(int64_t output_index, + int64_t input_size, + int64_t output_size) { + const double scale = static_cast(input_size) / static_cast(output_size); + const double center = (static_cast(output_index) + 0.5) * scale - 0.5; + return std::min(std::max(static_cast(std::floor(center + 0.5)), 0), input_size - 1); + } + + inline double linear_interpolate_weight(double x) { + x = std::abs(x); + return x < 1.0 ? 1.0 - x : 0.0; + } + + inline double cubic_interpolate_weight(double x) { + constexpr double a = -0.75; // Match PyTorch bicubic interpolation. 
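+            // Keys cubic convolution kernel with a = -0.75:
+            //   W(x) = (a+2)|x|^3 - (a+3)|x|^2 + 1      for |x| <= 1
+            //   W(x) = a|x|^3 - 5a|x|^2 + 8a|x| - 4a    for 1 < |x| < 2
+            //   W(x) = 0                                otherwise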
+ x = std::abs(x); + if (x <= 1.0) { + return ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0; + } + if (x < 2.0) { + return ((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a; + } + return 0.0; + } + + inline double sinc(double x) { + constexpr double pi = 3.14159265358979323846; + if (std::abs(x) < 1e-12) { + return 1.0; + } + const double pix = pi * x; + return std::sin(pix) / pix; + } + + inline double lanczos_interpolate_weight(double x) { + constexpr double radius = 3.0; + x = std::abs(x); + if (x >= radius) { + return 0.0; + } + return sinc(x) * sinc(x / radius); + } + + struct InterpolateContributor { + int64_t index; + double weight; + }; + + inline std::vector> make_interpolate_contributors( + int64_t input_size, + int64_t output_size, + InterpolateMode mode, + bool antialias) { + std::vector> contributors(static_cast(output_size)); + const double scale = static_cast(input_size) / static_cast(output_size); + const double filter_scale = antialias ? std::max(1.0, scale) : 1.0; + + for (int64_t out = 0; out < output_size; ++out) { + const double center = (static_cast(out) + 0.5) * scale - 0.5; + int64_t start = 0; + int64_t end = 0; + + if (mode == InterpolateMode::Bilinear) { + const double support = filter_scale; + start = static_cast(std::ceil(center - support)); + end = static_cast(std::floor(center + support)); + } else if (mode == InterpolateMode::Bicubic) { + const double support = 2.0 * filter_scale; + start = static_cast(std::ceil(center - support)); + end = static_cast(std::floor(center + support)); + } else if (mode == InterpolateMode::Lanczos) { + const double support = 3.0 * filter_scale; + start = static_cast(std::ceil(center - support)); + end = static_cast(std::floor(center + support)); + } else { + tensor_throw_invalid_argument("Unsupported 2D filter interpolate mode: mode=" + + std::to_string(static_cast(mode))); + } + + double weight_sum = 0.0; + std::vector& axis_contributors = contributors[static_cast(out)]; + axis_contributors.reserve(static_cast(end - start + 1)); + + for (int64_t in = start; in <= end; ++in) { + double weight = 0.0; + if (mode == InterpolateMode::Bilinear) { + weight = linear_interpolate_weight((center - static_cast(in)) / filter_scale); + } else if (mode == InterpolateMode::Bicubic) { + weight = cubic_interpolate_weight((center - static_cast(in)) / filter_scale); + } else { + weight = lanczos_interpolate_weight((center - static_cast(in)) / filter_scale); + } + + if (weight == 0.0) { + continue; + } + + const int64_t clamped_index = std::min(std::max(in, 0), input_size - 1); + axis_contributors.push_back({clamped_index, weight}); + weight_sum += weight; + } + + if ((antialias || mode == InterpolateMode::Lanczos) && + std::abs(weight_sum) > 1e-12) { + for (auto& contributor : axis_contributors) { + contributor.weight /= weight_sum; + } + } + + if (axis_contributors.empty()) { + const int64_t nearest = std::min( + std::max(static_cast(std::floor(center + 0.5)), 0), + input_size - 1); + axis_contributors.push_back({nearest, 1.0}); + } + } + + return contributors; + } + + template + inline Tensor interpolate_2d_filter(const Tensor& input, + const std::vector& output_shape, + InterpolateMode mode, + bool antialias) { + if (input.dim() < 2) { + tensor_throw_invalid_argument("2D filter interpolate requires rank >= 2: input_shape=" + + tensor_shape_to_string(input.shape()) + ", output_shape=" + + tensor_shape_to_string(output_shape)); + } + for (size_t i = 2; i < output_shape.size(); ++i) { + if (input.shape()[i] != output_shape[i]) { + 
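+            // The resize below is separable: a contributor list of
+            // (index, weight) pairs is precomputed per output column and per
+            // output row, then each output sample is
+            //   sum_y sum_x w_y * w_x * input(x, y).
+            // With antialias the filter support widens by the downscale factor,
+            // mirroring PIL/PyTorch behavior.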
tensor_throw_invalid_argument("2D filter interpolate only supports resizing dimensions 0 and 1: input_shape=" + + tensor_shape_to_string(input.shape()) + ", output_shape=" + + tensor_shape_to_string(output_shape)); + } + } + + Tensor output(output_shape); + const int64_t input_width = input.shape()[0]; + const int64_t input_height = input.shape()[1]; + const int64_t output_width = output_shape[0]; + const int64_t output_height = output_shape[1]; + const int64_t input_plane = input_width * input_height; + const int64_t output_plane = output_width * output_height; + const int64_t plane_count = input.numel() / input_plane; + + auto x_contributors = make_interpolate_contributors(input_width, output_width, mode, antialias); + auto y_contributors = make_interpolate_contributors(input_height, output_height, mode, antialias); + + for (int64_t plane = 0; plane < plane_count; ++plane) { + const int64_t input_plane_offset = plane * input_plane; + const int64_t output_plane_offset = plane * output_plane; + for (int64_t y = 0; y < output_height; ++y) { + const auto& y_axis = y_contributors[static_cast(y)]; + for (int64_t x = 0; x < output_width; ++x) { + const auto& x_axis = x_contributors[static_cast(x)]; + double value = 0.0; + for (const auto& yc : y_axis) { + const int64_t input_row_offset = input_plane_offset + yc.index * input_width; + for (const auto& xc : x_axis) { + value += static_cast(input.data()[input_row_offset + xc.index]) * + xc.weight * yc.weight; + } + } + output.data()[output_plane_offset + y * output_width + x] = static_cast(value); + } + } + } + + return output; + } + inline int64_t normalize_slice_bound(int64_t index, int64_t dim_size) { if (index < 0) { index += dim_size; @@ -1014,17 +1205,20 @@ namespace sd { inline Tensor interpolate(const Tensor& input, std::vector output_shape, InterpolateMode mode = InterpolateMode::Nearest, - bool align_corners = false) { - const bool is_nearest_like_mode = (mode == InterpolateMode::Nearest || - mode == InterpolateMode::NearestMax || - mode == InterpolateMode::NearestMin || - mode == InterpolateMode::NearestAvg); - if (!is_nearest_like_mode) { - tensor_throw_invalid_argument("Only nearest-like interpolate modes are implemented, got mode=" + + bool align_corners = false, + bool antialias = false) { + const bool is_nearest_like_mode = is_nearest_like_interpolate_mode(mode); + const bool is_2d_filter_mode = is_2d_filter_interpolate_mode(mode); + if (!is_nearest_like_mode && !is_2d_filter_mode) { + tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" + + std::to_string(static_cast(mode))); + } + if (antialias && !is_2d_filter_mode) { + tensor_throw_invalid_argument("Tensor interpolate antialias requires a 2D filter mode: mode=" + std::to_string(static_cast(mode))); } if (align_corners) { - tensor_throw_invalid_argument("align_corners is not supported for nearest-like interpolate: input_shape=" + + tensor_throw_invalid_argument("align_corners is not supported for tensor interpolate: input_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(output_shape)); } @@ -1051,6 +1245,10 @@ namespace sd { } } + if (is_2d_filter_mode) { + return interpolate_2d_filter(input, output_shape, mode, antialias); + } + bool has_downsampling = false; for (int64_t i = 0; i < input.dim(); ++i) { if (input.shape()[i] > output_shape[i]) { @@ -1060,12 +1258,20 @@ namespace sd { } Tensor output(std::move(output_shape)); - if (mode == InterpolateMode::Nearest || !has_downsampling) { + if (mode == 
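+            // Plain nearest maps out -> floor(out * in_size / out_size), while
+            // nearest-exact centers the sample grid first:
+            //   round((out + 0.5) * in_size / out_size - 0.5),
+            // matching PyTorch's "nearest-exact" mode.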
InterpolateMode::Nearest || + mode == InterpolateMode::NearestExact || + !has_downsampling) { for (int64_t flat = 0; flat < output.numel(); ++flat) { std::vector output_coord = tensor_unravel_index(flat, output.shape()); std::vector input_coord(static_cast(input.dim()), 0); for (size_t i = 0; i < static_cast(input.dim()); ++i) { - input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i]; + if (mode == InterpolateMode::NearestExact) { + input_coord[i] = nearest_exact_interpolate_index(output_coord[i], + input.shape()[i], + output.shape()[i]); + } else { + input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i]; + } } output[flat] = input.index(input_coord); } @@ -1083,6 +1289,12 @@ namespace sd { return T(0); case InterpolateMode::Nearest: return T(0); + case InterpolateMode::NearestExact: + return T(0); + case InterpolateMode::Bilinear: + case InterpolateMode::Bicubic: + case InterpolateMode::Lanczos: + break; } tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" + @@ -1102,6 +1314,12 @@ namespace sd { break; case InterpolateMode::Nearest: break; + case InterpolateMode::NearestExact: + break; + case InterpolateMode::Bilinear: + case InterpolateMode::Bicubic: + case InterpolateMode::Lanczos: + break; } }; @@ -1157,17 +1375,20 @@ namespace sd { const std::optional>& size, const std::optional>& scale_factor, InterpolateMode mode = InterpolateMode::Nearest, - bool align_corners = false) { - const bool is_nearest_like_mode = (mode == InterpolateMode::Nearest || - mode == InterpolateMode::NearestMax || - mode == InterpolateMode::NearestMin || - mode == InterpolateMode::NearestAvg); - if (!is_nearest_like_mode) { - tensor_throw_invalid_argument("Only nearest-like interpolate modes are implemented, got mode=" + + bool align_corners = false, + bool antialias = false) { + const bool is_nearest_like_mode = is_nearest_like_interpolate_mode(mode); + const bool is_2d_filter_mode = is_2d_filter_interpolate_mode(mode); + if (!is_nearest_like_mode && !is_2d_filter_mode) { + tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" + + std::to_string(static_cast(mode))); + } + if (antialias && !is_2d_filter_mode) { + tensor_throw_invalid_argument("Tensor interpolate antialias requires a 2D filter mode: mode=" + std::to_string(static_cast(mode))); } if (align_corners) { - tensor_throw_invalid_argument("align_corners is not supported for nearest-like interpolate: input_shape=" + + tensor_throw_invalid_argument("align_corners is not supported for tensor interpolate: input_shape=" + tensor_shape_to_string(input.shape())); } if (size.has_value() == scale_factor.has_value()) { @@ -1211,7 +1432,7 @@ namespace sd { } } - return interpolate(input, std::move(output_shape), mode, align_corners); + return interpolate(input, std::move(output_shape), mode, align_corners, antialias); } template @@ -1219,12 +1440,14 @@ namespace sd { const std::optional>& size, double scale_factor, InterpolateMode mode = InterpolateMode::Nearest, - bool align_corners = false) { + bool align_corners = false, + bool antialias = false) { return interpolate(input, size, std::vector(size.has_value() ? 
size->size() : input.dim(), scale_factor), mode, - align_corners); + align_corners, + antialias); } template diff --git a/otherarch/sdcpp/tokenizers/clip_tokenizer.cpp b/otherarch/sdcpp/tokenizers/clip_tokenizer.cpp index 57319306f..70d637724 100644 --- a/otherarch/sdcpp/tokenizers/clip_tokenizer.cpp +++ b/otherarch/sdcpp/tokenizers/clip_tokenizer.cpp @@ -62,7 +62,7 @@ void CLIPTokenizer::load_from_merges(const std::string& merges_utf8_str) { } vocab.push_back(utf8_to_utf32("<|startoftext|>")); vocab.push_back(utf8_to_utf32("<|endoftext|>")); - LOG_DEBUG("vocab size: %llu", vocab.size()); + LOG_DEBUG("vocab size: %zu", vocab.size()); int i = 0; for (const auto& token : vocab) { encoder[token] = i; diff --git a/otherarch/sdcpp/tokenizers/mistral_tokenizer.cpp b/otherarch/sdcpp/tokenizers/mistral_tokenizer.cpp index 0a56542aa..9b0624e3a 100644 --- a/otherarch/sdcpp/tokenizers/mistral_tokenizer.cpp +++ b/otherarch/sdcpp/tokenizers/mistral_tokenizer.cpp @@ -28,7 +28,7 @@ void MistralTokenizer::load_from_merges(const std::string& merges_utf8_str, cons byte_decoder[pair.second] = pair.first; } std::vector merges = split_utf32(merges_utf8_str); - LOG_DEBUG("merges size %llu", merges.size()); + LOG_DEBUG("merges size %zu", merges.size()); std::vector> merge_pairs; for (const auto& merge : merges) { size_t space_pos = merge.find(' '); diff --git a/otherarch/sdcpp/tokenizers/qwen2_tokenizer.cpp b/otherarch/sdcpp/tokenizers/qwen2_tokenizer.cpp index 5ddaf4ed1..9929ea387 100644 --- a/otherarch/sdcpp/tokenizers/qwen2_tokenizer.cpp +++ b/otherarch/sdcpp/tokenizers/qwen2_tokenizer.cpp @@ -11,7 +11,7 @@ void Qwen2Tokenizer::load_from_merges(const std::string& merges_utf8_str) { } std::vector merges = split_utf32(merges_utf8_str); - LOG_DEBUG("merges size %llu", merges.size()); + LOG_DEBUG("merges size %zu", merges.size()); std::vector> merge_pairs; for (const auto& merge : merges) { size_t space_pos = merge.find(' '); diff --git a/otherarch/sdcpp/upscaler.cpp b/otherarch/sdcpp/upscaler.cpp index 03f7714e5..ed7bb89a0 100644 --- a/otherarch/sdcpp/upscaler.cpp +++ b/otherarch/sdcpp/upscaler.cpp @@ -1,125 +1,115 @@ -#include "esrgan.hpp" +#include "upscaler.h" #include "ggml_extend.hpp" #include "model.h" #include "stable-diffusion.h" #include "util.h" -struct UpscalerGGML { - ggml_backend_t backend = nullptr; // general backend - ggml_type model_data_type = GGML_TYPE_F16; - std::shared_ptr esrgan_upscaler; - std::string esrgan_path; - int n_threads; - bool direct = false; - int tile_size = 128; +UpscalerGGML::UpscalerGGML(int n_threads, + bool direct, + int tile_size) + : n_threads(n_threads), + direct(direct), + tile_size(tile_size) { +} - UpscalerGGML(int n_threads, - bool direct = false, - int tile_size = 128) - : n_threads(n_threads), - direct(direct), - tile_size(tile_size) { - } - - bool load_from_file(const std::string& esrgan_path, - bool offload_params_to_cpu, - int n_threads) { - ggml_log_set(ggml_log_callback_default, nullptr); +bool UpscalerGGML::load_from_file(const std::string& esrgan_path, + bool offload_params_to_cpu, + int n_threads) { + ggml_log_set(ggml_log_callback_default, nullptr); #ifdef SD_USE_CUDA - LOG_DEBUG("Using CUDA backend"); - backend = ggml_backend_cuda_init(0); + LOG_DEBUG("Using CUDA backend"); + backend = ggml_backend_cuda_init(0); #endif #ifdef SD_USE_METAL - LOG_DEBUG("Using Metal backend"); - backend = ggml_backend_metal_init(); + LOG_DEBUG("Using Metal backend"); + backend = ggml_backend_metal_init(); #endif #ifdef SD_USE_VULKAN - LOG_DEBUG("Using Vulkan backend"); - 
diff --git a/otherarch/sdcpp/upscaler.cpp b/otherarch/sdcpp/upscaler.cpp
index 03f7714e5..ed7bb89a0 100644
--- a/otherarch/sdcpp/upscaler.cpp
+++ b/otherarch/sdcpp/upscaler.cpp
@@ -1,125 +1,115 @@
-#include "esrgan.hpp"
+#include "upscaler.h"
 #include "ggml_extend.hpp"
 #include "model.h"
 #include "stable-diffusion.h"
 #include "util.h"
 
-struct UpscalerGGML {
-    ggml_backend_t backend = nullptr;  // general backend
-    ggml_type model_data_type = GGML_TYPE_F16;
-    std::shared_ptr<ESRGAN> esrgan_upscaler;
-    std::string esrgan_path;
-    int n_threads;
-    bool direct = false;
-    int tile_size = 128;
+UpscalerGGML::UpscalerGGML(int n_threads,
+                           bool direct,
+                           int tile_size)
+    : n_threads(n_threads),
+      direct(direct),
+      tile_size(tile_size) {
+}
 
-    UpscalerGGML(int n_threads,
-                 bool direct = false,
-                 int tile_size = 128)
-        : n_threads(n_threads),
-          direct(direct),
-          tile_size(tile_size) {
-    }
-
-    bool load_from_file(const std::string& esrgan_path,
-                        bool offload_params_to_cpu,
-                        int n_threads) {
-        ggml_log_set(ggml_log_callback_default, nullptr);
+bool UpscalerGGML::load_from_file(const std::string& esrgan_path,
+                                  bool offload_params_to_cpu,
+                                  int n_threads) {
+    ggml_log_set(ggml_log_callback_default, nullptr);
 #ifdef SD_USE_CUDA
-        LOG_DEBUG("Using CUDA backend");
-        backend = ggml_backend_cuda_init(0);
+    LOG_DEBUG("Using CUDA backend");
+    backend = ggml_backend_cuda_init(0);
 #endif
 #ifdef SD_USE_METAL
-        LOG_DEBUG("Using Metal backend");
-        backend = ggml_backend_metal_init();
+    LOG_DEBUG("Using Metal backend");
+    backend = ggml_backend_metal_init();
 #endif
 #ifdef SD_USE_VULKAN
-        LOG_DEBUG("Using Vulkan backend");
-        backend = ggml_backend_vk_init(0);
+    LOG_DEBUG("Using Vulkan backend");
+    backend = ggml_backend_vk_init(0);
 #endif
 #ifdef SD_USE_OPENCL
-        LOG_DEBUG("Using OpenCL backend");
-        backend = ggml_backend_opencl_init();
+    LOG_DEBUG("Using OpenCL backend");
+    backend = ggml_backend_opencl_init();
 #endif
 #ifdef SD_USE_SYCL
-        LOG_DEBUG("Using SYCL backend");
-        backend = ggml_backend_sycl_init(0);
+    LOG_DEBUG("Using SYCL backend");
+    backend = ggml_backend_sycl_init(0);
 #endif
-        ModelLoader model_loader;
-        if (!model_loader.init_from_file_and_convert_name(esrgan_path)) {
-            LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str());
-        }
-        model_loader.set_wtype_override(model_data_type);
-        if (!backend) {
-            LOG_DEBUG("Using CPU backend");
-            backend = ggml_backend_cpu_init();
-        }
-        LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
-        esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map());
-        if (direct) {
-            esrgan_upscaler->set_conv2d_direct_enabled(true);
-        }
-        if (!esrgan_upscaler->load_from_file(esrgan_path, n_threads)) {
-            return false;
-        }
-        return true;
+    ModelLoader model_loader;
+    if (!model_loader.init_from_file_and_convert_name(esrgan_path)) {
+        LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str());
     }
-
-    sd::Tensor upscale_tensor(const sd::Tensor& input_tensor) {
-        sd::Tensor upscaled;
-        if (tile_size <= 0 || (input_tensor.shape()[0] <= tile_size && input_tensor.shape()[1] <= tile_size)) {
-            upscaled = esrgan_upscaler->compute(n_threads, input_tensor);
-        } else {
-            auto on_processing = [&](const sd::Tensor& input_tile) -> sd::Tensor {
-                auto output_tile = esrgan_upscaler->compute(n_threads, input_tile);
-                if (output_tile.empty()) {
-                    LOG_ERROR("esrgan compute failed while processing a tile");
-                    return {};
-                }
-                return output_tile;
-            };
-
-            upscaled = process_tiles_2d(input_tensor,
-                                        static_cast<int64_t>(input_tensor.shape()[0] * esrgan_upscaler->scale),
-                                        static_cast<int64_t>(input_tensor.shape()[1] * esrgan_upscaler->scale),
-                                        esrgan_upscaler->scale,
-                                        tile_size,
-                                        tile_size,
-                                        0.25f,
-                                        false,
-                                        false,
-                                        on_processing);
-        }
-        esrgan_upscaler->free_compute_buffer();
-        if (upscaled.empty()) {
-            LOG_ERROR("esrgan compute failed");
-            return {};
-        }
-        return upscaled;
+    model_loader.set_wtype_override(model_data_type);
+    if (!backend) {
+        LOG_DEBUG("Using CPU backend");
+        backend = ggml_backend_cpu_init();
     }
+    LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
+    esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map());
+    if (direct) {
+        esrgan_upscaler->set_conv2d_direct_enabled(true);
+    }
+    if (!esrgan_upscaler->load_from_file(esrgan_path, n_threads)) {
+        return false;
+    }
+    return true;
+}
 
-    sd_image_t upscale(sd_image_t input_image, uint32_t upscale_factor) {
-        // upscale_factor, unused for RealESRGAN_x4plus_anime_6B.pth
-        sd_image_t upscaled_image = {0, 0, 0, nullptr};
-        int output_width = (int)input_image.width * esrgan_upscaler->scale;
-        int output_height = (int)input_image.height * esrgan_upscaler->scale;
-        LOG_INFO("upscaling from (%i x %i) to (%i x %i)",
-                 input_image.width, input_image.height, output_width, output_height);
+sd::Tensor UpscalerGGML::upscale_tensor(const sd::Tensor& input_tensor) {
+    sd::Tensor upscaled;
+    if (tile_size <= 0 || (input_tensor.shape()[0] <= tile_size && input_tensor.shape()[1] <= tile_size)) {
+        upscaled = esrgan_upscaler->compute(n_threads, input_tensor);
+    } else {
+        auto on_processing = [&](const sd::Tensor& input_tile) -> sd::Tensor {
+            auto output_tile = esrgan_upscaler->compute(n_threads, input_tile);
+            if (output_tile.empty()) {
+                LOG_ERROR("esrgan compute failed while processing a tile");
+                return {};
+            }
+            return output_tile;
+        };
 
-        sd::Tensor input_tensor = sd_image_to_tensor(input_image);
-        sd::Tensor upscaled;
-        int64_t t0 = ggml_time_ms();
-        upscaled = upscale_tensor(input_tensor);
-        if (upscaled.empty()) {
-            return upscaled_image;
-        }
-        sd_image_t upscaled_data = tensor_to_sd_image(upscaled);
-        int64_t t3 = ggml_time_ms();
-        LOG_INFO("input_image_tensor upscaled, taking %.2fs", (t3 - t0) / 1000.0f);
-        upscaled_image = upscaled_data;
+        upscaled = process_tiles_2d(input_tensor,
+                                    static_cast<int64_t>(input_tensor.shape()[0] * esrgan_upscaler->scale),
+                                    static_cast<int64_t>(input_tensor.shape()[1] * esrgan_upscaler->scale),
+                                    esrgan_upscaler->scale,
+                                    tile_size,
+                                    tile_size,
+                                    0.25f,
+                                    false,
+                                    false,
+                                    on_processing);
+    }
+    esrgan_upscaler->free_compute_buffer();
+    if (upscaled.empty()) {
+        LOG_ERROR("esrgan compute failed");
+        return {};
+    }
+    return upscaled;
+}
+
+sd_image_t UpscalerGGML::upscale(sd_image_t input_image, uint32_t upscale_factor) {
+    // upscale_factor, unused for RealESRGAN_x4plus_anime_6B.pth
+    sd_image_t upscaled_image = {0, 0, 0, nullptr};
+    int output_width = (int)input_image.width * esrgan_upscaler->scale;
+    int output_height = (int)input_image.height * esrgan_upscaler->scale;
+    LOG_INFO("upscaling from (%i x %i) to (%i x %i)",
+             input_image.width, input_image.height, output_width, output_height);
+
+    sd::Tensor input_tensor = sd_image_to_tensor(input_image);
+    sd::Tensor upscaled;
+    int64_t t0 = ggml_time_ms();
+    upscaled = upscale_tensor(input_tensor);
+    if (upscaled.empty()) {
         return upscaled_image;
     }
-};
+    sd_image_t upscaled_data = tensor_to_sd_image(upscaled);
+    int64_t t3 = ggml_time_ms();
+    LOG_INFO("input_image_tensor upscaled, taking %.2fs", (t3 - t0) / 1000.0f);
+    upscaled_image = upscaled_data;
+    return upscaled_image;
+}
 
 struct upscaler_ctx_t {
     UpscalerGGML* upscaler = nullptr;
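[Reviewer note: the tiled path above hands seam handling to process_tiles_2d(), whose implementation is not part of this patch. A rough sketch of the tile-origin arithmetic implied by a tile_size x tile_size grid with the 0.25f overlap factor; the helper name and the exact stepping rule are assumptions for illustration, not the patched code:

    #include <cstdint>
    #include <vector>

    // Illustrative only: with overlap = 0.25 each tile advances by 75% of its size,
    // so neighbouring tiles share 25%, and the last tile is clamped to the border.
    static std::vector<int64_t> tile_origins_sketch(int64_t extent, int64_t tile, float overlap) {
        int64_t step = static_cast<int64_t>(tile * (1.0f - overlap));
        std::vector<int64_t> origins;
        for (int64_t x = 0;; x += step) {
            if (x + tile >= extent) {  // final tile: align its end with the image border
                origins.push_back(extent > tile ? extent - tile : 0);
                break;
            }
            origins.push_back(x);
        }
        return origins;
    }

For example, extent=300 and tile=128 yields origins {0, 96, 172}, covering the full range with overlapping bands that the real implementation presumably blends.]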
diff --git a/otherarch/sdcpp/upscaler.h b/otherarch/sdcpp/upscaler.h
new file mode 100644
index 000000000..b11f004a6
--- /dev/null
+++ b/otherarch/sdcpp/upscaler.h
@@ -0,0 +1,31 @@
+#ifndef __SD_UPSCALER_H__
+#define __SD_UPSCALER_H__
+
+#include "esrgan.hpp"
+#include "stable-diffusion.h"
+#include "tensor.hpp"
+
+#include <memory>
+#include <string>
+
+struct UpscalerGGML {
+    ggml_backend_t backend = nullptr;  // general backend
+    ggml_type model_data_type = GGML_TYPE_F16;
+    std::shared_ptr<ESRGAN> esrgan_upscaler;
+    std::string esrgan_path;
+    int n_threads;
+    bool direct = false;
+    int tile_size = 128;
+
+    UpscalerGGML(int n_threads,
+                 bool direct = false,
+                 int tile_size = 128);
+
+    bool load_from_file(const std::string& esrgan_path,
+                        bool offload_params_to_cpu,
+                        int n_threads);
+    sd::Tensor upscale_tensor(const sd::Tensor& input_tensor);
+    sd_image_t upscale(sd_image_t input_image, uint32_t upscale_factor);
+};
+
+#endif  // __SD_UPSCALER_H__
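[Reviewer note: for reference, a minimal usage sketch of the interface declared above, assuming a caller-supplied ESRGAN checkpoint path and default tiling; error handling is trimmed, and the wrapper function is hypothetical, not part of the patch:

    #include "upscaler.h"

    sd_image_t upscale_once(sd_image_t input, const std::string& model_path) {
        UpscalerGGML upscaler(/*n_threads=*/4, /*direct=*/false, /*tile_size=*/128);
        if (!upscaler.load_from_file(model_path, /*offload_params_to_cpu=*/false, /*n_threads=*/4)) {
            return {0, 0, 0, nullptr};  // loading failed; mirrors the patch's zero-image idiom
        }
        // The loaded ESRGAN model fixes the scale; the factor argument is ignored for x4 models.
        return upscaler.upscale(input, /*upscale_factor=*/4);
    }
]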
diff --git a/otherarch/sdcpp/util.cpp b/otherarch/sdcpp/util.cpp
index 23e324838..2daa01c09 100644
--- a/otherarch/sdcpp/util.cpp
+++ b/otherarch/sdcpp/util.cpp
@@ -128,10 +128,10 @@ std::unique_ptr<MmapWrapper> MmapWrapper::create(const std::string& filename) {
                                     filename.c_str(),
                                     GENERIC_READ,
                                     FILE_SHARE_READ,
-                                    NULL,
+                                    nullptr,
                                     OPEN_EXISTING,
                                     FILE_ATTRIBUTE_NORMAL,
-                                    NULL);
+                                    nullptr);
 
     if (file_handle == INVALID_HANDLE_VALUE) {
         return nullptr;
@@ -145,16 +145,16 @@ std::unique_ptr<MmapWrapper> MmapWrapper::create(const std::string& filename) {
     file_size = static_cast<size_t>(size.QuadPart);
 
-    HANDLE mapping_handle = CreateFileMapping(file_handle, NULL, PAGE_READONLY, 0, 0, NULL);
+    HANDLE mapping_handle = CreateFileMapping(file_handle, nullptr, PAGE_READONLY, 0, 0, nullptr);
 
-    if (mapping_handle == NULL) {
+    if (mapping_handle == nullptr) {
         CloseHandle(file_handle);
         return nullptr;
     }
 
     mapped_data = MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, file_size);
 
-    if (mapped_data == NULL) {
+    if (mapped_data == nullptr) {
         CloseHandle(mapping_handle);
         CloseHandle(file_handle);
         return nullptr;
@@ -217,7 +217,7 @@ std::unique_ptr<MmapWrapper> MmapWrapper::create(const std::string& filename) {
     size_t file_size = sb.st_size;
 
-    void* mapped_data = mmap(NULL, file_size, PROT_READ, mmap_flags, file_descriptor, 0);
+    void* mapped_data = mmap(nullptr, file_size, PROT_READ, mmap_flags, file_descriptor, 0);
 
     close(file_descriptor);
 
diff --git a/otherarch/sdcpp/vae.hpp b/otherarch/sdcpp/vae.hpp
index dc69535e8..54bd88abf 100644
--- a/otherarch/sdcpp/vae.hpp
+++ b/otherarch/sdcpp/vae.hpp
@@ -142,9 +142,10 @@ public:
                                  "vae encode compute failed while processing a tile");
         } else {
             output = _compute(n_threads, input, false);
-            free_compute_buffer();
         }
 
+        free_compute_buffer();
+
         if (output.empty()) {
             LOG_ERROR("vae encode compute failed");
             return {};