re-added smart context due to people complaining

Commit eff01660e4 by Concedo, 2024-05-11 17:25:03 +08:00 (parent 702be65ed1)
3 changed files with 20 additions and 8 deletions

@@ -44,6 +44,7 @@ struct load_model_inputs
     const char * mmproj_filename;
     const bool use_mmap;
     const bool use_mlock;
+    const bool use_smartcontext;
     const bool use_contextshift;
     const int clblast_info = 0;
     const int cublas_info = 0;

@@ -92,6 +92,7 @@ static int current_llava_identifier = LLAVA_TOKEN_IDENTIFIER_A;
 static gpt_params * kcpp_params = nullptr;
 static int max_context_limit_at_load = 0;
 static int n_past = 0;
+static bool useSmartContext = false;
 static bool useContextShift = false;
 static int debugmode = 0; //-1 = hide all, 0 = normal, 1 = showall
 static std::string modelname;
@@ -786,6 +787,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     }
     kcpp_params->flash_attn = inputs.flash_attention;
     modelname = kcpp_params->model = inputs.model_filename;
+    useSmartContext = inputs.use_smartcontext;
     useContextShift = inputs.use_contextshift;
     debugmode = inputs.debugmode;
@@ -1939,7 +1941,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
         else
         {
-            bool triggersc = useContextShift;
+            bool triggersc = useSmartContext;
             if(useContextShift && (file_format == FileFormat::GGUF_GENERIC))
             {
                 PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);

@@ -45,6 +45,7 @@ class load_model_inputs(ctypes.Structure):
                 ("mmproj_filename", ctypes.c_char_p),
                 ("use_mmap", ctypes.c_bool),
                 ("use_mlock", ctypes.c_bool),
+                ("use_smartcontext", ctypes.c_bool),
                 ("use_contextshift", ctypes.c_bool),
                 ("clblast_info", ctypes.c_int),
                 ("cublas_info", ctypes.c_int),
@@ -371,6 +372,7 @@ def load_model(model_filename):
     inputs.lora_base = args.lora[1].encode("UTF-8")
     inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
+    inputs.use_smartcontext = args.smartcontext
     inputs.use_contextshift = (0 if args.noshift else 1)
     inputs.flash_attention = args.flashattention
     inputs.blasbatchsize = args.blasbatchsize
@@ -1673,6 +1675,7 @@ def show_new_gui():
     contextshift = ctk.IntVar(value=1)
     remotetunnel = ctk.IntVar(value=0)
+    smartcontext = ctk.IntVar()
     flashattention = ctk.IntVar(value=0)
     context_var = ctk.IntVar()
     customrope_var = ctk.IntVar()
@@ -1956,7 +1959,10 @@ def show_new_gui():
     gpulayers_var.trace("w", changed_gpulayers)
 
     def togglectxshift(a,b,c):
-        pass
+        if contextshift.get()==0:
+            smartcontextbox.grid(row=1, column=0, padx=8, pady=1, stick="nw")
+        else:
+            smartcontextbox.grid_forget()
 
     def guibench():
         args.benchmark = "stdout"
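The togglectxshift change above relies on tkinter's grid/grid_forget pair: grid_forget removes the SmartContext checkbox from the layout without destroying it, and a later grid(...) call shows it again, so the box only appears while ContextShift is unticked. A self-contained sketch of the same show/hide pattern in plain tkinter (hypothetical widget and variable names, standard sticky option rather than customtkinter specifics):

import tkinter as tk

root = tk.Tk()
shift_on = tk.IntVar(value=1)  # ContextShift enabled by default

# Checkbox that is only relevant while context shifting is off
smart_box = tk.Checkbutton(root, text="Use SmartContext")

def toggle_shift(*_):
    if shift_on.get() == 0:
        smart_box.grid(row=1, column=0, sticky="nw")  # re-show with fresh grid options
    else:
        smart_box.grid_forget()  # hide, but keep the widget and its state

tk.Checkbutton(root, text="Use ContextShift", variable=shift_on).grid(row=0, column=0, sticky="nw")
shift_on.trace_add("write", toggle_shift)  # re-evaluate whenever the variable changes
toggle_shift()  # set the initial visibility
root.mainloop()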
@@ -2110,6 +2116,7 @@ def show_new_gui():
     # Tokens Tab
     tokens_tab = tabcontent["Tokens"]
     # tokens checkboxes
+    smartcontextbox = makecheckbox(tokens_tab, "Use SmartContext", smartcontext, 1,tooltiptxt="Uses SmartContext. Now considered outdated and not recommended.\nCheck the wiki for more info.")
     makecheckbox(tokens_tab, "Use ContextShift", contextshift, 2,tooltiptxt="Uses Context Shifting to reduce reprocessing.\nRecommended. Check the wiki for more info.", command=togglectxshift)
     togglectxshift(1,1,1)
@@ -2206,6 +2213,7 @@ def show_new_gui():
         args.launch = launchbrowser.get()==1
         args.highpriority = highpriority.get()==1
         args.nommap = disablemmap.get()==1
+        args.smartcontext = smartcontext.get()==1
         args.flashattention = flashattention.get()==1
         args.noshift = contextshift.get()==0
         args.remotetunnel = remotetunnel.get()==1
@@ -2301,6 +2309,7 @@ def show_new_gui():
         launchbrowser.set(1 if "launch" in dict and dict["launch"] else 0)
         highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
         disablemmap.set(1 if "nommap" in dict and dict["nommap"] else 0)
+        smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
         flashattention.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
         contextshift.set(0 if "noshift" in dict and dict["noshift"] else 1)
         remotetunnel.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0)
@@ -2730,7 +2739,7 @@ def check_deprecation_warning():
     if using_outdated_flags:
         print(f"\n=== !!! IMPORTANT WARNING !!! ===")
         print("You are using one or more OUTDATED config files or launch flags!")
-        print("The flags --smartcontext, --hordeconfig and --sdconfig have been DEPRECATED, and MAY be REMOVED in future!")
+        print("The flags --hordeconfig and --sdconfig have been DEPRECATED, and MAY be REMOVED in future!")
         print("They will still work for now, but you SHOULD switch to the updated flags instead, to avoid future issues!")
         print("New flags are: --hordemodelname --hordeworkername --hordekey --hordemaxctx --hordegenlen --sdmodel --sdthreads --sdquant --sdclamped")
         print("For more information on these flags, please check --help")
@@ -3393,11 +3402,7 @@ if __name__ == '__main__':
     advparser.add_argument("--chatcompletionsadapter", help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="")
     advparser.add_argument("--flashattention", help="Enables flash attention (Experimental).", action='store_true')
     advparser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
-
-    deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
-    deprecatedgroup.add_argument("--smartcontext", help="Command is DEPRECATED and should NOT be used! Instead, use --noshift instead to toggle smartcontext off on old GGML models.", action='store_true')
-    deprecatedgroup.add_argument("--hordeconfig", help="Command is DEPRECATED and should NOT be used! Instead, use non-positional flags --hordemodelname --hordeworkername --hordekey --hordemaxctx --hordegenlen instead.", nargs='+')
-    deprecatedgroup.add_argument("--sdconfig", help="Command is DEPRECATED and should NOT be used! Instead, use non-positional flags --sdmodel --sdthreads --sdquant --sdclamped instead.", nargs='+')
+    advparser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently. Not recommended.", action='store_true')
 
     hordeparsergroup = parser.add_argument_group('Horde Worker Commands')
     hordeparsergroup.add_argument("--hordemodelname", metavar=('[name]'), help="Sets your AI Horde display model name.", default="")
@@ -3412,4 +3417,8 @@ if __name__ == '__main__':
     sdparsergroup.add_argument("--sdquant", help="If specified, loads the model quantized to save memory.", action='store_true')
     sdparsergroup.add_argument("--sdclamped", help="If specified, limit generation steps and resolution settings for shared use.", action='store_true')
+
+    deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
+    deprecatedgroup.add_argument("--hordeconfig", help=argparse.SUPPRESS, nargs='+')
+    deprecatedgroup.add_argument("--sdconfig", help=argparse.SUPPRESS, nargs='+')
 
     main(parser.parse_args(),start_server=True)
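One detail worth noting in the relocated deprecated group: help=argparse.SUPPRESS keeps a flag parseable but hides it from the --help listing, so stale --hordeconfig/--sdconfig invocations keep working without being advertised. A quick illustration with hypothetical flag names:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--newflag", help="The documented replacement.", action='store_true')
# Still accepted on the command line, but omitted from --help output:
parser.add_argument("--oldflag", help=argparse.SUPPRESS, action='store_true')

args = parser.parse_args(["--oldflag"])
print(args.oldflag)  # True, even though --help never mentions the flag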