wip submitting of llava image to backend

Concedo 2024-03-10 17:14:27 +08:00
parent 6990d07a26
commit d943c739a8
5 changed files with 51 additions and 33 deletions


@@ -273,7 +273,7 @@ class model_backend(InferenceModel):
     unbantokens=False, bantokens=None, usemirostat=None, forceversion=0, nommap=self.kcpp_nommap,
     usemlock=False, noavx2=self.kcpp_noavx2, debugmode=self.kcpp_debugmode, skiplauncher=True, hordeconfig=None, noblas=self.kcpp_noblas,
     useclblast=self.kcpp_useclblast, usecublas=self.kcpp_usecublas, usevulkan=self.kcpp_usevulkan, gpulayers=self.kcpp_gpulayers, tensor_split=self.kcpp_tensor_split, config=None,
-    onready='', multiuser=False, foreground=False, preloadstory=None, noshift=False, remotetunnel=False, ssl=False, benchmark=False, nocertify=False, sdconfig=None)
+    onready='', multiuser=False, foreground=False, preloadstory=None, noshift=False, remotetunnel=False, ssl=False, benchmark=False, nocertify=False, sdconfig=None, mmproj=None)
     #koboldcpp.main(kcppargs,False) #initialize library without enabling Lite http server


@@ -34,6 +34,7 @@ extern "C"
     std::string model = inputs.model_filename;
     lora_filename = inputs.lora_filename;
     lora_base = inputs.lora_base;
+    mmproj_filename = inputs.mmproj_filename;
     int forceversion = inputs.forceversion;


@@ -41,6 +41,7 @@ struct load_model_inputs
     const char * model_filename;
     const char * lora_filename;
     const char * lora_base;
+    const char * mmproj_filename;
     const bool use_mmap;
     const bool use_mlock;
     const bool use_smartcontext;
@@ -133,6 +134,7 @@ struct sd_generation_outputs
 extern std::string executable_path;
 extern std::string lora_filename;
 extern std::string lora_base;
+extern std::string mmproj_filename;
 extern std::vector<std::string> generated_tokens;
 extern bool generation_finished;
 extern float last_eval_time;


@@ -618,6 +618,24 @@ static void load_grammar(const std::string & gammarstr)
     }
 }
+static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
+    int n_embd = llama_n_embd(llama_get_model(ctx_llama));
+    for (int i = 0; i < num_img_tokens; i += n_batch) {
+        int n_eval = num_img_tokens - i;
+        if (n_eval > n_batch) {
+            n_eval = n_batch;
+        }
+        llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
+        if (llama_decode(ctx_llama, batch)) {
+            fprintf(stderr, "\n%s : failed to eval image\n", __func__);
+            return false;
+        }
+        *n_past += n_eval;
+    }
+    return true;
+}
 //given an old GGUF context and a new context that has some middle portion removed,
 //find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
 void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
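The new kcpp_eval_image helper above feeds precomputed image embeddings into the model in n_batch-sized slices, advancing n_past by each slice. Purely as an illustration (these names are not part of the codebase), the same chunking arithmetic in Python:

# Illustrative sketch of the slice boundaries kcpp_eval_image walks through
# when decoding num_img_tokens image-embedding tokens with batch size n_batch.
def chunk_ranges(num_img_tokens, n_batch):
    for i in range(0, num_img_tokens, n_batch):
        yield i, min(n_batch, num_img_tokens - i)

# A LLaVA 1.5 image is typically 576 embedding tokens; with n_batch=512 that
# works out to two llama_decode calls, one of 512 tokens and one of 64.
print(list(chunk_ranges(576, 512)))  # [(0, 512), (512, 64)]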
@@ -1064,6 +1082,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     if(mmproj_filename != "")
     {
+        printf("\nAttempting to apply Multimodal Projector: %s\n", mmproj_filename.c_str());
         clp_ctx = clip_model_load(mmproj_filename.c_str(), /*verbosity=*/ 1);
         if(clp_ctx == nullptr) {
             fprintf(stderr, "%s: error: failed to load mmproj model!\n", __func__);
@@ -1672,34 +1691,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
     }
-    // for (int i = 0; i < img.image_tokens; i += n_batch)
-    // {
-    //     int n_eval = img.image_tokens - i;
-    //     if (n_eval > n_batch)
-    //     {
-    //         n_eval = n_batch;
-    //     }
-    //     const int n_embd = llama_n_embd(model);
-    //     llama_batch batch_img = {
-    //         n_eval,
-    //         nullptr,
-    //         (img.image_embedding + i * n_embd),
-    //         nullptr,
-    //         nullptr,
-    //         nullptr,
-    //         nullptr,
-    //         slot.n_past,
-    //         1, 0
-    //     };
-    //     if (llama_decode(ctx, batch_img))
-    //     {
-    //         LOG_TEE("%s : failed to eval image\n", __func__);
-    //         return false;
-    //     }
-    //     slot.n_past += n_eval;
-    // }
     if(addedmemory!="")
     {
         TokenizeString(addedmemory, embd_inp_mem, file_format);


@@ -42,6 +42,7 @@ class load_model_inputs(ctypes.Structure):
                 ("model_filename", ctypes.c_char_p),
                 ("lora_filename", ctypes.c_char_p),
                 ("lora_base", ctypes.c_char_p),
+                ("mmproj_filename", ctypes.c_char_p),
                 ("use_mmap", ctypes.c_bool),
                 ("use_mlock", ctypes.c_bool),
                 ("use_smartcontext", ctypes.c_bool),
@@ -352,6 +353,8 @@ def load_model(model_filename):
         inputs.use_mmap = False
         if len(args.lora) > 1:
             inputs.lora_base = args.lora[1].encode("UTF-8")
+    inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
     inputs.use_smartcontext = args.smartcontext
     inputs.use_contextshift = (0 if args.noshift else 1)
     inputs.blasbatchsize = args.blasbatchsize
@@ -590,6 +593,7 @@ def bring_terminal_to_foreground():
 friendlymodelname = "inactive"
 friendlysdmodelname = "inactive"
 fullsdmodelpath = "" #if empty, it's not initialized
+mmprojpath = "" #if empty, it's not initialized
 maxctx = 2048
 maxhordectx = 2048
 maxhordelen = 256
@@ -938,7 +942,7 @@ Enter Prompt:<br>
         self.wfile.write(finalhtml)
     def do_GET(self):
-        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath
+        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath, mmprojpath
         self.path = self.path.rstrip('/')
         response_body = None
         content_type = 'application/json'
@@ -976,7 +980,9 @@ Enter Prompt:<br>
             response_body = (json.dumps({"value": maxctx}).encode())
         elif self.path.endswith(('/api/extra/version')):
-            response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion}).encode())
+            has_txt2img = not (friendlysdmodelname=="inactive" or fullsdmodelpath=="")
+            has_vision = (mmprojpath!="")
+            response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion,"txt2img":has_txt2img,"vision":has_vision}).encode())
         elif self.path.endswith(('/api/extra/perf')):
             lastp = handle.get_last_process_time()
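With this change, /api/extra/version reports whether image generation and vision are available, so clients can feature-detect before sending images. A minimal sketch of such a check, assuming a KoboldCpp instance on the default local port:

import json, urllib.request

# Query the capability flags added above; the URL and port are assumptions for a local instance.
with urllib.request.urlopen("http://localhost:5001/api/extra/version") as resp:
    info = json.load(resp)

if info.get("vision"):
    print("Backend reports vision (LLaVA) support, version", info.get("version"))
else:
    print("No multimodal projector loaded")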
@@ -1434,6 +1440,7 @@ def show_new_gui():
     lora_var = ctk.StringVar()
     lora_base_var = ctk.StringVar()
     preloadstory_var = ctk.StringVar()
+    mmproj_var = ctk.StringVar()
     port_var = ctk.StringVar(value=defaultport)
     host_var = ctk.StringVar(value="")
@@ -1882,7 +1889,8 @@ def show_new_gui():
     makefileentry(model_tab, "Model:", "Select GGML Model File", model_var, 1, onchoosefile=on_picked_model_file,tooltiptxt="Select a GGUF or GGML model file on disk to be loaded.")
     makefileentry(model_tab, "Lora:", "Select Lora File",lora_var, 3,tooltiptxt="Select an optional GGML LoRA adapter to use.\nLeave blank to skip.")
     makefileentry(model_tab, "Lora Base:", "Select Lora Base File", lora_base_var, 5,tooltiptxt="Select an optional F16 GGML LoRA base file to use.\nLeave blank to skip.")
-    makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 7,tooltiptxt="Select an optional KoboldAI JSON savefile \nto be served on launch to any client.")
+    makefileentry(model_tab, "LLaVA mmproj:", "Select LLaVA mmproj File", mmproj_var, 7,tooltiptxt="Select a mmproj file to use for LLaVA.\nLeave blank to skip.")
+    makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 9,tooltiptxt="Select an optional KoboldAI JSON savefile \nto be served on launch to any client.")
     # Network Tab
     network_tab = tabcontent["Network"]
@@ -2006,6 +2014,7 @@ def show_new_gui():
         args.model_param = None if model_var.get() == "" else model_var.get()
         args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()])
         args.preloadstory = None if preloadstory_var.get() == "" else preloadstory_var.get()
+        args.mmproj = None if mmproj_var.get() == "" else mmproj_var.get()
         args.ssl = None if (ssl_cert_var.get() == "" or ssl_key_var.get() == "") else ([ssl_cert_var.get(), ssl_key_var.get()])
@@ -2121,6 +2130,9 @@ def show_new_gui():
             else:
                 lora_var.set(dict["lora"][0])
+        if "mmproj" in dict and dict["mmproj"]:
+            mmproj_var.set(dict["mmproj"])
         if "ssl" in dict and dict["ssl"]:
             if len(dict["ssl"]) == 2:
                 ssl_cert_var.set(dict["ssl"][0])
@@ -2572,7 +2584,7 @@ def sanitize_string(input_string):
     return sanitized_string
 def main(launch_args,start_server=True):
-    global args, friendlymodelname, friendlysdmodelname, fullsdmodelpath
+    global args, friendlymodelname, friendlysdmodelname, fullsdmodelpath, mmprojpath
     args = launch_args
     embedded_kailite = None
     embedded_kcpp_docs = None
@@ -2696,6 +2708,17 @@ def main(launch_args,start_server=True):
             else:
                 args.lora[1] = os.path.abspath(args.lora[1])
+    if args.mmproj and args.mmproj!="":
+        if not os.path.exists(args.mmproj):
+            exitcounter = 999
+            print(f"Cannot find mmproj file: {args.mmproj}")
+            time.sleep(3)
+            sys.exit(2)
+        else:
+            global mmprojpath
+            args.mmproj = os.path.abspath(args.mmproj)
+            mmprojpath = args.mmproj
     if not args.blasthreads or args.blasthreads <= 0:
         args.blasthreads = args.threads
@@ -2943,5 +2966,6 @@ if __name__ == '__main__':
     parser.add_argument("--ssl", help="Allows all content to be served over SSL instead. A valid UNENCRYPTED SSL cert and key .pem files must be provided", metavar=('[cert_pem]', '[key_pem]'), nargs='+')
     parser.add_argument("--nocertify", help="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.", action='store_true')
     parser.add_argument("--sdconfig", help="Specify a stable diffusion safetensors model to enable image generation. If quick is specified, force optimal generation settings for speed.",metavar=('[sd_filename]', '[normal|quick|clamped] [threads] [quant|noquant]'), nargs='+')
+    parser.add_argument("--mmproj", help="Select a multimodal projector file for LLaVA.", default="")
     main(parser.parse_args(),start_server=True)
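For reference, an example launch using the new flag; the model and projector file names below are placeholders, not files shipped with this commit:

import subprocess

# Launch KoboldCpp with a base GGUF model plus a LLaVA multimodal projector (placeholder names).
subprocess.run(["python", "koboldcpp.py",
                "--model", "llava-v1.5-7b.Q4_K_M.gguf",
                "--mmproj", "mmproj-llava-v1.5-7b-f16.gguf"])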