Re-added SmartContext support, in response to user complaints about its removal

This commit is contained in:
Concedo 2024-05-11 17:25:03 +08:00
parent 702be65ed1
commit eff01660e4
3 changed files with 20 additions and 8 deletions

View file

@ -44,6 +44,7 @@ struct load_model_inputs
const char * mmproj_filename;
const bool use_mmap;
const bool use_mlock;
const bool use_smartcontext;
const bool use_contextshift;
const int clblast_info = 0;
const int cublas_info = 0;

View file

@ -92,6 +92,7 @@ static int current_llava_identifier = LLAVA_TOKEN_IDENTIFIER_A;
static gpt_params * kcpp_params = nullptr;
static int max_context_limit_at_load = 0;
static int n_past = 0;
static bool useSmartContext = false;
static bool useContextShift = false;
static int debugmode = 0; //-1 = hide all, 0 = normal, 1 = showall
static std::string modelname;
@ -786,6 +787,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
kcpp_params->flash_attn = inputs.flash_attention;
modelname = kcpp_params->model = inputs.model_filename;
useSmartContext = inputs.use_smartcontext;
useContextShift = inputs.use_contextshift;
debugmode = inputs.debugmode;
@ -1939,7 +1941,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
else
{
bool triggersc = useContextShift;
bool triggersc = useSmartContext;
if(useContextShift && (file_format == FileFormat::GGUF_GENERIC))
{
PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);

View file

@ -45,6 +45,7 @@ class load_model_inputs(ctypes.Structure):
("mmproj_filename", ctypes.c_char_p),
("use_mmap", ctypes.c_bool),
("use_mlock", ctypes.c_bool),
("use_smartcontext", ctypes.c_bool),
("use_contextshift", ctypes.c_bool),
("clblast_info", ctypes.c_int),
("cublas_info", ctypes.c_int),
@ -371,6 +372,7 @@ def load_model(model_filename):
inputs.lora_base = args.lora[1].encode("UTF-8")
inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
inputs.use_smartcontext = args.smartcontext
inputs.use_contextshift = (0 if args.noshift else 1)
inputs.flash_attention = args.flashattention
inputs.blasbatchsize = args.blasbatchsize
@ -1673,6 +1675,7 @@ def show_new_gui():
contextshift = ctk.IntVar(value=1)
remotetunnel = ctk.IntVar(value=0)
smartcontext = ctk.IntVar()
flashattention = ctk.IntVar(value=0)
context_var = ctk.IntVar()
customrope_var = ctk.IntVar()
@ -1956,7 +1959,10 @@ def show_new_gui():
gpulayers_var.trace("w", changed_gpulayers)
def togglectxshift(a,b,c):
pass
if contextshift.get()==0:
smartcontextbox.grid(row=1, column=0, padx=8, pady=1, stick="nw")
else:
smartcontextbox.grid_forget()
def guibench():
args.benchmark = "stdout"
@ -2110,6 +2116,7 @@ def show_new_gui():
# Tokens Tab
tokens_tab = tabcontent["Tokens"]
# tokens checkboxes
smartcontextbox = makecheckbox(tokens_tab, "Use SmartContext", smartcontext, 1,tooltiptxt="Uses SmartContext. Now considered outdated and not recommended.\nCheck the wiki for more info.")
makecheckbox(tokens_tab, "Use ContextShift", contextshift, 2,tooltiptxt="Uses Context Shifting to reduce reprocessing.\nRecommended. Check the wiki for more info.", command=togglectxshift)
togglectxshift(1,1,1)
@ -2206,6 +2213,7 @@ def show_new_gui():
args.launch = launchbrowser.get()==1
args.highpriority = highpriority.get()==1
args.nommap = disablemmap.get()==1
args.smartcontext = smartcontext.get()==1
args.flashattention = flashattention.get()==1
args.noshift = contextshift.get()==0
args.remotetunnel = remotetunnel.get()==1
@ -2301,6 +2309,7 @@ def show_new_gui():
launchbrowser.set(1 if "launch" in dict and dict["launch"] else 0)
highpriority.set(1 if "highpriority" in dict and dict["highpriority"] else 0)
disablemmap.set(1 if "nommap" in dict and dict["nommap"] else 0)
smartcontext.set(1 if "smartcontext" in dict and dict["smartcontext"] else 0)
flashattention.set(1 if "flashattention" in dict and dict["flashattention"] else 0)
contextshift.set(0 if "noshift" in dict and dict["noshift"] else 1)
remotetunnel.set(1 if "remotetunnel" in dict and dict["remotetunnel"] else 0)
@ -2730,7 +2739,7 @@ def check_deprecation_warning():
if using_outdated_flags:
print(f"\n=== !!! IMPORTANT WARNING !!! ===")
print("You are using one or more OUTDATED config files or launch flags!")
print("The flags --smartcontext, --hordeconfig and --sdconfig have been DEPRECATED, and MAY be REMOVED in future!")
print("The flags --hordeconfig and --sdconfig have been DEPRECATED, and MAY be REMOVED in future!")
print("They will still work for now, but you SHOULD switch to the updated flags instead, to avoid future issues!")
print("New flags are: --hordemodelname --hordeworkername --hordekey --hordemaxctx --hordegenlen --sdmodel --sdthreads --sdquant --sdclamped")
print("For more information on these flags, please check --help")
@ -3393,11 +3402,7 @@ if __name__ == '__main__':
advparser.add_argument("--chatcompletionsadapter", help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="")
advparser.add_argument("--flashattention", help="Enables flash attention (Experimental).", action='store_true')
advparser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
deprecatedgroup.add_argument("--smartcontext", help="Command is DEPRECATED and should NOT be used! Instead, use --noshift instead to toggle smartcontext off on old GGML models.", action='store_true')
deprecatedgroup.add_argument("--hordeconfig", help="Command is DEPRECATED and should NOT be used! Instead, use non-positional flags --hordemodelname --hordeworkername --hordekey --hordemaxctx --hordegenlen instead.", nargs='+')
deprecatedgroup.add_argument("--sdconfig", help="Command is DEPRECATED and should NOT be used! Instead, use non-positional flags --sdmodel --sdthreads --sdquant --sdclamped instead.", nargs='+')
advparser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently. Not recommended.", action='store_true')
hordeparsergroup = parser.add_argument_group('Horde Worker Commands')
hordeparsergroup.add_argument("--hordemodelname", metavar=('[name]'), help="Sets your AI Horde display model name.", default="")
@ -3412,4 +3417,8 @@ if __name__ == '__main__':
sdparsergroup.add_argument("--sdquant", help="If specified, loads the model quantized to save memory.", action='store_true')
sdparsergroup.add_argument("--sdclamped", help="If specified, limit generation steps and resolution settings for shared use.", action='store_true')
deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
deprecatedgroup.add_argument("--hordeconfig", help=argparse.SUPPRESS, nargs='+')
deprecatedgroup.add_argument("--sdconfig", help=argparse.SUPPRESS, nargs='+')
main(parser.parse_args(),start_server=True)