refactor and clean identifiers for sd, fix cmake

Concedo 2024-02-29 18:28:45 +08:00
parent 66134bb36e
commit 5a44d4de2b
9 changed files with 69 additions and 134 deletions

CMakeLists.txt

@@ -80,9 +80,8 @@ if (LLAMA_CUBLAS)
enable_language(CUDA)
add_compile_definitions(GGML_USE_CUBLAS)
#add_compile_definitions(GGML_CUDA_CUBLAS) #remove to not use cublas
+add_compile_definitions(SD_USE_CUBLAS)
-add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y})
#add_compile_definitions(GGML_CUDA_FORCE_DMMV) #non dmmv broken for me
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
-add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
@@ -150,7 +149,7 @@ if (LLAMA_HIPBLAS)
if (${hipblas_FOUND} AND ${hip_FOUND})
message(STATUS "HIP and hipBLAS found")
-add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS SD_USE_CUBLAS)
add_library(ggml-rocm OBJECT ${GGML_SOURCES_CUDA})
if (LLAMA_CUDA_FORCE_DMMV)
target_compile_definitions(ggml-rocm PUBLIC GGML_CUDA_FORCE_DMMV)
@@ -425,14 +424,21 @@ add_library(common2
common/common.h
common/grammar-parser.h
common/grammar-parser.cpp)
-target_include_directories(common2 PUBLIC . ./otherarch ./otherarch/tools ./examples ./common)
+target_include_directories(common2 PUBLIC . ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
target_compile_features(common2 PUBLIC cxx_std_11) # don't bump
target_link_libraries(common2 PRIVATE ggml ${LLAMA_EXTRA_LIBS})
set_target_properties(common2 PROPERTIES POSITION_INDEPENDENT_CODE ON)
+add_library(sdtype_adapter
+sdtype_adapter.cpp)
+target_include_directories(sdtype_adapter PUBLIC . ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
+target_compile_features(sdtype_adapter PUBLIC cxx_std_11) # don't bump
+target_link_libraries(sdtype_adapter PRIVATE common2 ggml ${LLAMA_EXTRA_LIBS})
+set_target_properties(sdtype_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
add_library(gpttype_adapter
gpttype_adapter.cpp)
-target_include_directories(gpttype_adapter PUBLIC . ./otherarch ./otherarch/tools ./examples ./common)
+target_include_directories(gpttype_adapter PUBLIC . ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
target_compile_features(gpttype_adapter PUBLIC cxx_std_11) # don't bump
target_link_libraries(gpttype_adapter PRIVATE common2 ggml ${LLAMA_EXTRA_LIBS})
set_target_properties(gpttype_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
@@ -440,24 +446,24 @@ set_target_properties(gpttype_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
if (LLAMA_CUBLAS)
set(TARGET koboldcpp_cublas)
add_library(${TARGET} SHARED expose.cpp expose.h)
-target_include_directories(${TARGET} PUBLIC . ./otherarch ./otherarch/tools ./examples ./common)
+target_include_directories(${TARGET} PUBLIC . ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
target_compile_features(${TARGET} PUBLIC cxx_std_11) # don't bump
set_target_properties(${TARGET} PROPERTIES PREFIX "")
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME "koboldcpp_cublas")
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
-target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter ${LLAMA_EXTRA_LIBS})
+target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter sdtype_adapter ${LLAMA_EXTRA_LIBS})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
endif()
if (LLAMA_HIPBLAS)
set(TARGET koboldcpp_hipblas)
add_library(${TARGET} SHARED expose.cpp expose.h)
-target_include_directories(${TARGET} PUBLIC . ./otherarch ./otherarch/tools ./examples ./common)
+target_include_directories(${TARGET} PUBLIC . ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
target_compile_features(${TARGET} PUBLIC cxx_std_11) # don't bump
set_target_properties(${TARGET} PROPERTIES PREFIX "")
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME "koboldcpp_hipblas")
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
-target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter ${LLAMA_EXTRA_LIBS})
+target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter sdtype_adapter ${LLAMA_EXTRA_LIBS})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
endif()

Makefile

@@ -43,6 +43,7 @@ CFLAGS = -I. -I./include -I./include/CL -I./otherarch -I./otherarch
CXXFLAGS = -I. -I./common -I./include -I./include/CL -I./otherarch -I./otherarch/tools -I./otherarch/sdcpp -I./otherarch/sdcpp/thirdparty -I./include/vulkan -O3 -DNDEBUG -std=c++11 -fPIC -DLOG_DISABLE_LOGS -D_GNU_SOURCE
LDFLAGS =
FASTCFLAGS = $(subst -O3,-Ofast,$(CFLAGS))
+FASTCXXFLAGS = $(subst -O3,-Ofast,$(CXXFLAGS))
# these are used on windows, to build some libraries with extra old device compatibility
SIMPLECFLAGS =
@@ -54,7 +55,7 @@ CLBLAST_FLAGS = -DGGML_USE_CLBLAST
FAILSAFE_FLAGS = -DUSE_FAILSAFE
VULKAN_FLAGS = -DGGML_USE_VULKAN
ifdef LLAMA_CUBLAS
-CUBLAS_FLAGS = -DGGML_USE_CUBLAS
+CUBLAS_FLAGS = -DGGML_USE_CUBLAS -DSD_USE_CUBLAS
else
CUBLAS_FLAGS =
endif
@@ -141,7 +142,7 @@ endif
# it is recommended to use the CMAKE file to build for cublas if you can - will likely work better
ifdef LLAMA_CUBLAS
-CUBLAS_FLAGS = -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
+CUBLAS_FLAGS = -DGGML_USE_CUBLAS -DSD_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
CUBLASLD_FLAGS = -lcuda -lcublas -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/lib/wsl/lib
CUBLAS_OBJS = ggml-cuda.o ggml_v3-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
NVCC = nvcc
@@ -225,7 +226,7 @@ ifdef LLAMA_HIPBLAS
LLAMA_CUDA_DMMV_X ?= 32
LLAMA_CUDA_MMV_Y ?= 1
LLAMA_CUDA_KQUANTS_ITER ?= 2
-HIPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS $(shell $(ROCM_PATH)/bin/hipconfig -C)
+HIPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS -DSD_USE_CUBLAS $(shell $(ROCM_PATH)/bin/hipconfig -C)
HIPLDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib -lhipblas -lamdhip64 -lrocblas
HIP_OBJS += ggml-cuda.o ggml_v3-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
ggml-cuda.o: HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) \
@@ -256,8 +257,8 @@ endif # LLAMA_HIPBLAS
ifdef LLAMA_METAL
-CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
-CXXFLAGS += -DGGML_USE_METAL
+CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG -DSD_USE_METAL
+CXXFLAGS += -DGGML_USE_METAL -DSD_USE_METAL
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
OBJS += ggml-metal.o
@@ -479,8 +480,10 @@ expose.o: expose.cpp expose.h
$(CXX) $(CXXFLAGS) -c $< -o $@
# sd.cpp objects
-sdcpp_default.o: otherarch/sdcpp/sd_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
-$(CXX) $(CXXFLAGS) -c $< -o $@
+sdcpp_default.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
+$(CXX) $(FASTCXXFLAGS) -c $< -o $@
+sdcpp_cublas.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
+$(CXX) $(FASTCXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
# idiotic "for easier compilation"
GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp llama.cpp otherarch/utils.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml.h ggml-cuda.h llama.h otherarch/llama-util.h

expose.cpp

@@ -211,11 +211,11 @@ extern "C"
return gpttype_generate(inputs);
}
-bool load_model_sd(const load_sd_model_inputs inputs)
+bool sd_load_model(const sd_load_model_inputs inputs)
{
return sdtype_load_model(inputs);
}
-sd_generation_outputs generate_sd(const sd_generation_inputs inputs)
+sd_generation_outputs sd_generate(const sd_generation_inputs inputs)
{
return sdtype_generate(inputs);
}
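The renamed entry points keep the same thin-wrapper pattern as the text API: sd_load_model and sd_generate simply forward to the sdtype_* implementations. A minimal ctypes sketch of how a client would resolve the renamed exports; the library filename and the abridged two-field struct are assumptions for illustration (koboldcpp.py below carries the real bindings):

import ctypes

# Hypothetical mirror of expose.h, abridged to the two fields shown in the diff;
# field order must match the C struct exactly.
class sd_load_model_inputs(ctypes.Structure):
    _fields_ = [("model_filename", ctypes.c_char_p),
                ("debugmode", ctypes.c_int)]

handle = ctypes.CDLL("./koboldcpp_cublas.so")  # assumed backend library name
handle.sd_load_model.argtypes = [sd_load_model_inputs]
handle.sd_load_model.restype = ctypes.c_bool

inputs = sd_load_model_inputs()
inputs.model_filename = b"/path/to/sd_model.safetensors"
inputs.debugmode = 0
print(handle.sd_load_model(inputs))  # True on success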

expose.h

@@ -99,7 +99,7 @@ struct token_count_outputs
int count = 0;
int * ids; //we'll just use shared memory for this one, bit of a hack
};
-struct load_sd_model_inputs
+struct sd_load_model_inputs
{
const char * model_filename;
const int debugmode = 0;
@@ -116,6 +116,7 @@ struct sd_generation_inputs
struct sd_generation_outputs
{
int status = -1;
+unsigned int data_length = 0;
const char * data;
};
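The new data_length field lets the C side report the payload size explicitly instead of leaving callers to infer it from string length, and the sdtype_adapter.cpp hunks at the end of this commit zero it on every early-return path. Any ctypes mirror has to declare the field in the same position, as koboldcpp.py now does. A small sketch of a caller that uses it; the length check is illustrative, not part of this commit:

import ctypes

# Mirrors sd_generation_outputs after this commit; order must match expose.h.
class sd_generation_outputs(ctypes.Structure):
    _fields_ = [("status", ctypes.c_int),
                ("data_length", ctypes.c_uint),
                ("data", ctypes.c_char_p)]

def read_payload(out):
    if out.status != 1 or not out.data:
        return ""
    raw = out.data  # c_char_p auto-converts to bytes up to the first NUL
    if out.data_length and len(raw) != out.data_length:
        raise ValueError("image payload truncated")  # hypothetical guard
    return raw.decode("UTF-8", "ignore")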

klite.embd

@@ -5,8 +5,9 @@ Kobold Lite WebUI is a standalone WebUI for use with KoboldAI United, AI Horde,
It requires no dependencies, installation or setup.
Just copy this single static HTML file anywhere and open it in a browser, or from a webserver.
Please go to https://github.com/LostRuins/lite.koboldai.net for updates on Kobold Lite.
+If you are submitting a pull request for Lite, PLEASE use the above repo, not the KoboldCpp one.
Kobold Lite is under the AGPL v3.0 License unless otherwise exempted. Please do not remove this line.
-Current version: 116
+Current version: 117
-Concedo
-->
@@ -3254,7 +3255,7 @@ Current version: 116
//casualwriter casual-markdown, under MIT license
function simpleMarkdown(e){var r=function(e){return e.replace(/</g,"&lt;").replace(/\>/g,"&gt;")},l=function(e,r){return"<pre><code>"+(r=(r=(r=(r=(r=r.replace(/</g,"&lt;").replace(/\>/g,"&gt;")).replace(/\t/g," ").replace(/\^\^\^(.+?)\^\^\^/g,"<mark>$1</mark>")).replace(/^\/\/(.*)/gm,"<rem>//$1</rem>").replace(/\s\/\/(.*)/gm," <rem>//$1</rem>")).replace(/(\s?)(function|procedure|return|exit|if|then|else|end|loop|while|or|and|case|when)(\s)/gim,"$1<b>$2</b>$3")).replace(/(\s?)(var|let|const|=>|for|next|do|while|loop|continue|break|switch|try|catch|finally)(\s)/gim,"$1<b>$2</b>$3"))+"</code></pre>"},c=function(e){return(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=(e=e.replace(/^###### (.*?)\s*#*$/gm,"<h6>$1</h6>").replace(/^##### (.*?)\s*#*$/gm,"<h5>$1</h5>").replace(/^#### (.*?)\s*#*$/gm,"<h4>$1</h4>").replace(/^### (.*?)\s*#*$/gm,"<h3>$1</h3>").replace(/^## (.*?)\s*#*$/gm,"<h2>$1</h2>").replace(/^# (.*?)\s*#*$/gm,"<h1>$1</h1>")
.replace(/^<h(\d)\>(.*?)\s*{(.*)}\s*<\/h\d\>$/gm,'<h$1 id="$3">$2</h$1>')).replace(/^-{3,}|^\_{3,}|^\*{3,}$/gm,"<hr/>")).replace(/``(.*?)``/gm,function(e,l){return"<code>"+r(l).replace(/`/g,"&#96;")+"</code>"})).replace(/`(.*?)`/gm,"<code>$1</code>")).replace(/^\>\> (.*$)/gm,"<blockquote><blockquote>$1</blockquote></blockquote>")).replace(/^\> (.*$)/gm,"<blockquote>$1</blockquote>")).replace(/<\/blockquote\>\n<blockquote\>/g,"\n")).replace(/<\/blockquote\>\n<blockquote\>/g,"\n<br>")).replace(/!\[(.*?)\]\((.*?) "(.*?)"\)/gm,'<img alt="$1" src="$2" $3 />')).replace(/!\[(.*?)\]\((.*?)\)/gm,'<img alt="$1" src="$2" />')).replace(/\[(.*?)\]\((.*?) "new"\)/gm,'<a href="$2" target=_new>$1</a>')).replace(/\[(.*?)\]\((.*?) "(.*?)"\)/gm,'<a href="$2" title="$3">$1</a>')).replace(/<http(.*?)\>/gm,'<a href="http$1">http$1</a>')).replace(/\[(.*?)\]\(\)/gm,'<a href="$1">$1</a>')).replace(/\[(.*?)\]\((.*?)\)/gm,'<a href="$2">$1</a>'))
-.replace(/^[\*+-][ .](.*)/gm,"<ul><li>$1</li></ul>")).replace(/\%SpcEtg\%(\d\d?)[ .](.*)([\n]?)/gm,"\%SpcEtg\%\n$1.$2\n").replace(/^\d\d?[ .](.*)([\n]??)/gm,"<ol><li>$1</li></ol>").replace(/<\/li><\/ol><ol><li>/gm,"</li><li>")).replace(/^<[ou]l><li>(.*\%SpcStg\%.*\%SpcEtg\%.*)<\/li><\/[ou]l\>/gm,"$1").replace(/^\s{2,6}[\*+-][ .](.*)/gm,"<ul><ul><li>$1</li></ul></ul>")).replace(/^\s{2,6}\d[ .](.*)/gm,"<ul><ol><li>$1</li></ol></ul>")).replace(/<\/[ou]l\>\n\n<[ou]l\>/gm,"\n").replace(/<\/[ou]l\>\n<[ou]l\>/g,"")).replace(/<\/[ou]l\>\n<[ou]l\>/g,"\n").replace(/<\/li><\/ul><ul><li>/gm,"</li><li>")).replace(/\*\*\*(\w.*?[^\\])\*\*\*/gm,"<b><em>$1</em></b>")).replace(/\*\*(\w.*?[^\\])\*\*/gm,"<b>$1</b>")).replace(/\*(\w.*?[^\\])\*/gm,"<em>$1</em>")).replace(/___(\w.*?[^\\])___/gm,"<b><em>$1</em></b>")).replace(/__(\w.*?[^\\])__/gm,"<u>$1</u>")).replace(/~~(\w.*?)~~/gm,"<del>$1</del>")).replace(/\^\^(\w.*?)\^\^/gm,"<ins>$1</ins>")).replace(/\{\{(\w.*?)\}\}/gm,"<mark>$1</mark>")).replace(/^((?:\|[^|\r\n]*[^|\r\n\s]\s*)+\|(?:\r?\n|\r|))+/gm,function (matchedTable){return convertMarkdownTableToHtml(matchedTable);})).replace(/ \n/g,"\n<br/>")
+.replace(/^[\*+-][ .](.*)/gm,"<ul><li>$1</li></ul>")).replace(/\%SpcEtg\%(\d\d?)[ .](.*)([\n]?)/gm,"\%SpcEtg\%\n$1.$2\n").replace(/^\d\d?[ .] (.*)([\n]??)/gm,"<ol><li>$1</li></ol>").replace(/<\/li><\/ol><ol><li>/gm,"</li><li>")).replace(/^<[ou]l><li>(.*\%SpcStg\%.*\%SpcEtg\%.*)<\/li><\/[ou]l\>/gm,"$1").replace(/^\s{2,6}[\*+-][ .](.*)/gm,"<ul><ul><li>$1</li></ul></ul>")).replace(/^\s{2,6}\d[ .](.*)/gm,"<ul><ol><li>$1</li></ol></ul>")).replace(/<\/[ou]l\>\n\n<[ou]l\>/gm,"\n").replace(/<\/[ou]l\>\n<[ou]l\>/g,"")).replace(/<\/[ou]l\>\n<[ou]l\>/g,"\n").replace(/<\/li><\/ul><ul><li>/gm,"</li><li>")).replace(/\*\*\*(\w.*?[^\\])\*\*\*/gm,"<b><em>$1</em></b>")).replace(/\*\*(\w.*?[^\\])\*\*/gm,"<b>$1</b>")).replace(/\*(\w.*?[^\\])\*/gm,"<em>$1</em>")).replace(/___(\w.*?[^\\])___/gm,"<b><em>$1</em></b>")).replace(/__(\w.*?[^\\])__/gm,"<u>$1</u>")).replace(/~~(\w.*?)~~/gm,"<del>$1</del>")).replace(/\^\^(\w.*?)\^\^/gm,"<ins>$1</ins>")).replace(/\{\{(\w.*?)\}\}/gm,"<mark>$1</mark>")).replace(/^((?:\|[^|\r\n]*[^|\r\n\s]\s*)+\|(?:\r?\n|\r|))+/gm,function (matchedTable){return convertMarkdownTableToHtml(matchedTable);})).replace(/ \n/g,"\n<br/>")
//.replace(/\n\s*\n/g,"\n<p>\n")
).replace(/^ {4,10}(.*)/gm,function(e,l){return"<pre><code>"+r(l)+"</code></pre>"})).replace(/^\t(.*)/gm,function(e,l){return"<pre><code>"+r(l)+"</code></pre>"})).replace(/<\/code\><\/pre\>\n<pre\><code\>/g,"\n")).replace(/\\([`_~\*\+\-\.\^\\\<\>\(\)\[\]])/gm,"$1")},a=0,n=0,p="";for(e=(e=e.replace(/\r\n/g,"\n").replace(/\n~~~/g,"\n```")).replace(/```([^`]+)```/g,l);(a=e.indexOf("<code>"))>=0;)n=e.indexOf("</code>",a),p+=c(e.substr(0,a))+e.substr(a+6,n>0?n-a-6:mdtext.length),e=e.substr(n+7);return p+c(e)}
@@ -8119,7 +8120,7 @@ Current version: 116
localsettings.last_selected_preset = document.getElementById("presets").value;
//clean and clamp invalid values
-localsettings.max_context_length = cleannum(localsettings.max_context_length, 8, 99999);
+localsettings.max_context_length = cleannum(localsettings.max_context_length, 8, 999999);
localsettings.max_length = cleannum(localsettings.max_length, 1, (localsettings.max_context_length-1));
localsettings.temperature = cleannum(localsettings.temperature, 0.01, 5);
localsettings.rep_pen = cleannum(localsettings.rep_pen, 0.1, 5);
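The clamp change in this hunk raises the settings ceiling for max_context_length from 99999 to 999999. For reference, a rough Python equivalent of Lite's cleannum(value, min, max), assuming non-numeric input falls back to the lower bound:

def cleannum(value, lo, hi):
    # Coerce to a number, fall back to the minimum on garbage (assumption),
    # then clamp to the [lo, hi] range.
    try:
        v = float(value)
    except (TypeError, ValueError):
        v = lo
    return max(lo, min(hi, v))

assert cleannum(123456, 8, 999999) == 123456  # previously capped at 99999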

koboldcpp.py

@@ -22,6 +22,14 @@ logit_bias_max = 16
bias_min_value = -100.0
bias_max_value = 100.0
+class logit_bias(ctypes.Structure):
+_fields_ = [("token_id", ctypes.c_int32),
+("bias", ctypes.c_float)]
+class token_count_outputs(ctypes.Structure):
+_fields_ = [("count", ctypes.c_int),
+("ids", ctypes.POINTER(ctypes.c_int))]
class load_model_inputs(ctypes.Structure):
_fields_ = [("threads", ctypes.c_int),
("blasthreads", ctypes.c_int),
@@ -49,10 +57,6 @@ class load_model_inputs(ctypes.Structure):
("banned_tokens", ctypes.c_char_p * ban_token_max),
("tensor_split", ctypes.c_float * tensor_split_max)]
-class logit_bias(ctypes.Structure):
-_fields_ = [("token_id", ctypes.c_int32),
-("bias", ctypes.c_float)]
class generation_inputs(ctypes.Structure):
_fields_ = [("seed", ctypes.c_int),
("prompt", ctypes.c_char_p),
@@ -103,12 +107,9 @@ class sd_generation_inputs(ctypes.Structure):
class sd_generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
("data_length", ctypes.c_uint),
("data", ctypes.c_char_p)]
-class token_count_outputs(ctypes.Structure):
-_fields_ = [("count", ctypes.c_int),
-("ids", ctypes.POINTER(ctypes.c_int))]
handle = None
def getdirpath():
@@ -273,10 +274,10 @@ def init_library():
handle.abort_generate.restype = ctypes.c_bool
handle.token_count.restype = token_count_outputs
handle.get_pending_output.restype = ctypes.c_char_p
-handle.load_model_sd.argtypes = [sd_load_model_inputs]
-handle.load_model_sd.restype = ctypes.c_bool
-handle.generate_sd.argtypes = [sd_generation_inputs]
-handle.generate_sd.restype = sd_generation_outputs
+handle.sd_load_model.argtypes = [sd_load_model_inputs]
+handle.sd_load_model.restype = ctypes.c_bool
+handle.sd_generate.argtypes = [sd_generation_inputs]
+handle.sd_generate.restype = sd_generation_outputs
def load_model(model_filename):
global args
@@ -469,14 +470,29 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
return outstr
-def load_model_sd(model_filename):
+def sd_load_model(model_filename):
global args
inputs = sd_load_model_inputs()
inputs.debugmode = args.debugmode
inputs.model_filename = model_filename.encode("UTF-8")
-ret = handle.load_model_sd(inputs)
+ret = handle.sd_load_model(inputs)
return ret
+def sd_generate(prompt, negative_prompt="", cfg_scale=5, sample_steps=20, seed=-1, sample_method="euler a"):
+global maxctx, args, currentusergenkey, totalgens, pendingabortkey
+inputs = sd_generation_inputs()
+inputs.prompt = prompt.encode("UTF-8")
+inputs.negative_prompt = negative_prompt.encode("UTF-8")
+inputs.cfg_scale = cfg_scale
+inputs.sample_steps = sample_steps
+inputs.seed = seed
+inputs.sample_method = sample_method.encode("UTF-8")
+ret = handle.sd_generate(inputs)
+outstr = ""
+if ret.status==1:
+outstr = ret.data.decode("UTF-8","ignore")
+return outstr
def utfprint(str):
try:
print(str)
@@ -2567,7 +2583,7 @@ def main(launch_args,start_server=True):
time.sleep(3)
sys.exit(2)
imgmodel = os.path.abspath(imgmodel)
-loadok = load_model_sd(imgmodel)
+loadok = sd_load_model(imgmodel)
print("Load Image Model OK: " + str(loadok))
if not loadok:
exitcounter = 999
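Together the renamed wrappers give the full image path: sd_load_model once at startup, then sd_generate per request. A hypothetical end-to-end use, assuming the returned string is a base64-encoded image as the Lite UI and HTTP API consume it:

import base64

if sd_load_model("/path/to/sd_model.safetensors"):
    b64 = sd_generate("a watercolor fox", negative_prompt="blurry",
                      cfg_scale=7, sample_steps=20, seed=42,
                      sample_method="euler a")
    if b64:
        with open("out.png", "wb") as f:  # payload format is assumed
            f.write(base64.b64decode(b64))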

model_adapter.h

@@ -78,7 +78,7 @@ bool gpttype_generate_abort();
const std::string & gpttype_get_pending_output();
std::vector<int> gpttype_get_token_arr(const std::string & input);
-bool sdtype_load_model(const load_sd_model_inputs inputs);
+bool sdtype_load_model(const sd_load_model_inputs inputs);
sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs);
void timer_start();

otherarch/sdcpp/CMakeLists.txt (deleted)

@@ -1,95 +0,0 @@
cmake_minimum_required(VERSION 3.12)
project("stable-diffusion")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(SD_STANDALONE ON)
else()
set(SD_STANDALONE OFF)
endif()
#
# Option list
#
# general
option(SD_CUBLAS "sd: cuda backend" OFF)
option(SD_HIPBLAS "sd: rocm backend" OFF)
option(SD_METAL "sd: metal backend" OFF)
option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
if(SD_CUBLAS)
message("Use CUBLAS as backend stable-diffusion")
set(GGML_CUBLAS ON)
add_definitions(-DSD_USE_CUBLAS)
endif()
if(SD_METAL)
message("Use Metal as backend stable-diffusion")
set(GGML_METAL ON)
add_definitions(-DSD_USE_METAL)
endif()
if (SD_HIPBLAS)
message("Use HIPBLAS as backend stable-diffusion")
set(GGML_HIPBLAS ON)
add_definitions(-DSD_USE_CUBLAS)
if(SD_FAST_SOFTMAX)
set(GGML_CUDA_FAST_SOFTMAX ON)
endif()
endif ()
if(SD_FLASH_ATTN)
message("Use Flash Attention for memory optimization")
add_definitions(-DSD_USE_FLASH_ATTENTION)
endif()
set(SD_LIB stable-diffusion)
file(GLOB SD_LIB_SOURCES
"*.h"
"*.cpp"
"*.hpp"
)
# we can get only one share lib
if(SD_BUILD_SHARED_LIBS)
message("Build shared library")
set(BUILD_SHARED_LIBS OFF)
message(${SD_LIB_SOURCES})
add_library(${SD_LIB} SHARED ${SD_LIB_SOURCES})
add_definitions(-DSD_BUILD_SHARED_LIB)
target_compile_definitions(${SD_LIB} PRIVATE -DSD_BUILD_DLL)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
else()
message("Build static library")
add_library(${SD_LIB} STATIC ${SD_LIB_SOURCES})
endif()
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
# deps
add_subdirectory(ggml)
add_subdirectory(thirdparty)
target_link_libraries(${SD_LIB} PUBLIC ggml zip)
target_include_directories(${SD_LIB} PUBLIC . thirdparty)
target_compile_features(${SD_LIB} PUBLIC cxx_std_11)
add_subdirectory(examples)

otherarch/sdcpp/sdtype_adapter.cpp

@@ -125,7 +125,7 @@ static void sd_logger_callback(enum sd_log_level_t level, const char* log, void*
}
}
-bool sdtype_load_model(const load_sd_model_inputs inputs) {
+bool sdtype_load_model(const sd_load_model_inputs inputs) {
printf("\nSelected Image Model: %s\n",inputs.model_filename);
@@ -174,6 +174,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
printf("\nError: KCPP SD is not initialized!\n");
output.data = nullptr;
output.status = 0;
+output.data_length = 0;
return output;
}
uint8_t * input_image_buffer = NULL;
@@ -233,6 +234,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
printf("\nKCPP SD generate failed!\n");
output.data = nullptr;
output.status = 0;
+output.data_length = 0;
return output;
}
@@ -255,5 +257,6 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
output.data = nullptr;
output.status = 1;
+output.data_length = 0;
return output;
}