mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-12 01:54:37 +00:00
resync and updated sdcpp for flux and sd3 support
This commit is contained in:
parent
33721615b5
commit
f32a874966
30 changed files with 2434248 additions and 1729 deletions
|
@ -28,6 +28,9 @@
|
|||
#include "ggml.h"
|
||||
#include "stable-diffusion.h"
|
||||
|
||||
// #define STB_IMAGE_RESIZE_IMPLEMENTATION //already defined
|
||||
#include "stb_image_resize.h"
|
||||
|
||||
bool ends_with(const std::string& str, const std::string& ending) {
|
||||
if (str.length() >= ending.length()) {
|
||||
return (str.compare(str.length() - ending.length(), ending.length(), ending) == 0);
|
||||
|
@ -98,6 +101,43 @@ std::string get_full_path(const std::string& dir, const std::string& filename) {
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> get_files_from_dir(const std::string& dir) {
|
||||
std::vector<std::string> files;
|
||||
|
||||
WIN32_FIND_DATA findFileData;
|
||||
HANDLE hFind;
|
||||
|
||||
char currentDirectory[MAX_PATH];
|
||||
GetCurrentDirectory(MAX_PATH, currentDirectory);
|
||||
|
||||
char directoryPath[MAX_PATH]; // this is absolute path
|
||||
sprintf(directoryPath, "%s\\%s\\*", currentDirectory, dir.c_str());
|
||||
|
||||
// Find the first file in the directory
|
||||
hFind = FindFirstFile(directoryPath, &findFileData);
|
||||
|
||||
// Check if the directory was found
|
||||
if (hFind == INVALID_HANDLE_VALUE) {
|
||||
printf("Unable to find directory.\n");
|
||||
return files;
|
||||
}
|
||||
|
||||
// Loop through all files in the directory
|
||||
do {
|
||||
// Check if the found file is a regular file (not a directory)
|
||||
if (!(findFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) {
|
||||
files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName));
|
||||
}
|
||||
} while (FindNextFile(hFind, &findFileData) != 0);
|
||||
|
||||
// Close the handle
|
||||
FindClose(hFind);
|
||||
|
||||
sort(files.begin(), files.end());
|
||||
|
||||
return files;
|
||||
}
|
||||
|
||||
#else // Unix
|
||||
#include <dirent.h>
|
||||
#include <sys/stat.h>
|
||||
|
@ -112,6 +152,7 @@ bool is_directory(const std::string& path) {
|
|||
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
|
||||
}
|
||||
|
||||
// TODO: add windows version
|
||||
std::string get_full_path(const std::string& dir, const std::string& filename) {
|
||||
DIR* dp = opendir(dir.c_str());
|
||||
|
||||
|
@ -131,6 +172,27 @@ std::string get_full_path(const std::string& dir, const std::string& filename) {
|
|||
return "";
|
||||
}
|
||||
|
||||
std::vector<std::string> get_files_from_dir(const std::string& dir) {
|
||||
std::vector<std::string> files;
|
||||
|
||||
DIR* dp = opendir(dir.c_str());
|
||||
|
||||
if (dp != nullptr) {
|
||||
struct dirent* entry;
|
||||
|
||||
while ((entry = readdir(dp)) != nullptr) {
|
||||
std::string fname = dir + "/" + entry->d_name;
|
||||
if (!is_directory(fname))
|
||||
files.push_back(fname);
|
||||
}
|
||||
closedir(dp);
|
||||
}
|
||||
|
||||
sort(files.begin(), files.end());
|
||||
|
||||
return files;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// get_num_physical_cores is copy from
|
||||
|
@ -171,6 +233,9 @@ int32_t sd_get_num_physical_cores() {
|
|||
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
|
||||
}
|
||||
|
||||
static sd_progress_cb_t sd_progress_cb = NULL;
|
||||
void* sd_progress_cb_data = NULL;
|
||||
|
||||
std::u32string utf8_to_utf32(const std::string& utf8_str) {
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
||||
return converter.from_bytes(utf8_str);
|
||||
|
@ -214,9 +279,46 @@ std::string path_join(const std::string& p1, const std::string& p2) {
|
|||
return p1 + "/" + p2;
|
||||
}
|
||||
|
||||
sd_image_t* preprocess_id_image(sd_image_t* img) {
|
||||
int shortest_edge = 224;
|
||||
int size = shortest_edge;
|
||||
sd_image_t* resized = NULL;
|
||||
uint32_t w = img->width;
|
||||
uint32_t h = img->height;
|
||||
uint32_t c = img->channel;
|
||||
|
||||
// 1. do resize using stb_resize functions
|
||||
|
||||
unsigned char* buf = (unsigned char*)malloc(sizeof(unsigned char) * 3 * size * size);
|
||||
if (!stbir_resize_uint8(img->data, w, h, 0,
|
||||
buf, size, size, 0,
|
||||
c)) {
|
||||
fprintf(stderr, "%s: resize operation failed \n ", __func__);
|
||||
return resized;
|
||||
}
|
||||
|
||||
// 2. do center crop (likely unnecessary due to step 1)
|
||||
|
||||
// 3. do rescale
|
||||
|
||||
// 4. do normalize
|
||||
|
||||
// 3 and 4 will need to be done in float format.
|
||||
|
||||
resized = new sd_image_t{(uint32_t)shortest_edge,
|
||||
(uint32_t)shortest_edge,
|
||||
3,
|
||||
buf};
|
||||
return resized;
|
||||
}
|
||||
|
||||
static int sdloglevel = 0; //-1 = hide all, 0 = normal, 1 = showall
|
||||
static bool sdquiet = false;
|
||||
void pretty_progress(int step, int steps, float time) {
|
||||
if (sd_progress_cb) {
|
||||
sd_progress_cb(step, steps, time, sd_progress_cb_data);
|
||||
return;
|
||||
}
|
||||
if (step == 0) {
|
||||
return;
|
||||
}
|
||||
|
@ -296,23 +398,13 @@ void log_printf(sd_log_level_t level, const char* file, int line, const char* fo
|
|||
va_list args;
|
||||
va_start(args, format);
|
||||
|
||||
const char* level_str = "DEBUG";
|
||||
if (level == SD_LOG_INFO) {
|
||||
level_str = "INFO ";
|
||||
} else if (level == SD_LOG_WARN) {
|
||||
level_str = "WARN ";
|
||||
} else if (level == SD_LOG_ERROR) {
|
||||
level_str = "ERROR";
|
||||
}
|
||||
|
||||
static char log_buffer[LOG_BUFFER_SIZE];
|
||||
|
||||
int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "[%s] %s:%-4d - ", level_str, sd_basename(file).c_str(), line);
|
||||
static char log_buffer[LOG_BUFFER_SIZE + 1];
|
||||
int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "%s:%-4d - ", sd_basename(file).c_str(), line);
|
||||
|
||||
if (written >= 0 && written < LOG_BUFFER_SIZE) {
|
||||
vsnprintf(log_buffer + written, LOG_BUFFER_SIZE - written, format, args);
|
||||
strncat(log_buffer, "\n", LOG_BUFFER_SIZE - strlen(log_buffer) - 1);
|
||||
}
|
||||
strncat(log_buffer, "\n", LOG_BUFFER_SIZE - strlen(log_buffer));
|
||||
|
||||
if (sd_log_cb) {
|
||||
sd_log_cb(level, log_buffer, sd_log_cb_data);
|
||||
|
@ -325,7 +417,10 @@ void sd_set_log_callback(sd_log_cb_t cb, void* data) {
|
|||
sd_log_cb = cb;
|
||||
sd_log_cb_data = data;
|
||||
}
|
||||
|
||||
void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
|
||||
sd_progress_cb = cb;
|
||||
sd_progress_cb_data = data;
|
||||
}
|
||||
const char* sd_get_system_info() {
|
||||
static char buffer[1024];
|
||||
std::stringstream ss;
|
||||
|
@ -499,4 +594,111 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) {
|
|||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345
|
||||
//
|
||||
// Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
||||
// Accepted tokens are:
|
||||
// (abc) - increases attention to abc by a multiplier of 1.1
|
||||
// (abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||
// [abc] - decreases attention to abc by a multiplier of 1.1
|
||||
// \( - literal character '('
|
||||
// \[ - literal character '['
|
||||
// \) - literal character ')'
|
||||
// \] - literal character ']'
|
||||
// \\ - literal character '\'
|
||||
// anything else - just text
|
||||
//
|
||||
// >>> parse_prompt_attention('normal text')
|
||||
// [['normal text', 1.0]]
|
||||
// >>> parse_prompt_attention('an (important) word')
|
||||
// [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
||||
// >>> parse_prompt_attention('(unbalanced')
|
||||
// [['unbalanced', 1.1]]
|
||||
// >>> parse_prompt_attention('\(literal\]')
|
||||
// [['(literal]', 1.0]]
|
||||
// >>> parse_prompt_attention('(unnecessary)(parens)')
|
||||
// [['unnecessaryparens', 1.1]]
|
||||
// >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
||||
// [['a ', 1.0],
|
||||
// ['house', 1.5730000000000004],
|
||||
// [' ', 1.1],
|
||||
// ['on', 1.0],
|
||||
// [' a ', 1.1],
|
||||
// ['hill', 0.55],
|
||||
// [', sun, ', 1.1],
|
||||
// ['sky', 1.4641000000000006],
|
||||
// ['.', 1.1]]
|
||||
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text) {
|
||||
std::vector<std::pair<std::string, float>> res;
|
||||
std::vector<int> round_brackets;
|
||||
std::vector<int> square_brackets;
|
||||
|
||||
float round_bracket_multiplier = 1.1f;
|
||||
float square_bracket_multiplier = 1 / 1.1f;
|
||||
|
||||
std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|[^\\()\[\]:]+|:)");
|
||||
std::regex re_break(R"(\s*\bBREAK\b\s*)");
|
||||
|
||||
auto multiply_range = [&](int start_position, float multiplier) {
|
||||
for (int p = start_position; p < res.size(); ++p) {
|
||||
res[p].second *= multiplier;
|
||||
}
|
||||
};
|
||||
|
||||
std::smatch m;
|
||||
std::string remaining_text = text;
|
||||
|
||||
while (std::regex_search(remaining_text, m, re_attention)) {
|
||||
std::string text = m[0];
|
||||
std::string weight = m[1];
|
||||
|
||||
if (text == "(") {
|
||||
round_brackets.push_back((int)res.size());
|
||||
} else if (text == "[") {
|
||||
square_brackets.push_back((int)res.size());
|
||||
} else if (!weight.empty()) {
|
||||
if (!round_brackets.empty()) {
|
||||
multiply_range(round_brackets.back(), std::stof(weight));
|
||||
round_brackets.pop_back();
|
||||
}
|
||||
} else if (text == ")" && !round_brackets.empty()) {
|
||||
multiply_range(round_brackets.back(), round_bracket_multiplier);
|
||||
round_brackets.pop_back();
|
||||
} else if (text == "]" && !square_brackets.empty()) {
|
||||
multiply_range(square_brackets.back(), square_bracket_multiplier);
|
||||
square_brackets.pop_back();
|
||||
} else if (text == "\\(") {
|
||||
res.push_back({text.substr(1), 1.0f});
|
||||
} else {
|
||||
res.push_back({text, 1.0f});
|
||||
}
|
||||
|
||||
remaining_text = m.suffix();
|
||||
}
|
||||
|
||||
for (int pos : round_brackets) {
|
||||
multiply_range(pos, round_bracket_multiplier);
|
||||
}
|
||||
|
||||
for (int pos : square_brackets) {
|
||||
multiply_range(pos, square_bracket_multiplier);
|
||||
}
|
||||
|
||||
if (res.empty()) {
|
||||
res.push_back({"", 1.0f});
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
while (i + 1 < res.size()) {
|
||||
if (res[i].second == res[i + 1].second) {
|
||||
res[i].first += res[i + 1].first;
|
||||
res.erase(res.begin() + i + 1);
|
||||
} else {
|
||||
++i;
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue