fixed race condition when generating

This commit is contained in:
Concedo 2024-08-20 20:17:55 +08:00
parent 7ee359a59b
commit c1ae350e5b
2 changed files with 6 additions and 9 deletions

View file

@@ -1941,6 +1941,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         llama_reset_timings(llama_ctx_v4);
     }
+    generation_finished = false; // Set current generation status
+    generated_tokens.clear(); // New Generation, new tokens
     concat_output_mtx.lock();
     concat_output = "";
     concat_output_reader_copy_poll = "";
@@ -2140,8 +2143,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     bool allow_regular_prints = (debugmode!=-1 && !inputs.quiet) || debugmode >= 1;
-    generation_finished = false; // Set current generation status
-    generated_tokens.clear(); // New Generation, new tokens
     std::string grammarstr = inputs.grammar;
     bool grammar_retain_state = inputs.grammar_retain_state;

View file

@@ -41,7 +41,7 @@ maxhordelen = 400
 modelbusy = threading.Lock()
 requestsinqueue = 0
 defaultport = 5001
-KcppVersion = "1.73"
+KcppVersion = "1.73.1"
 showdebug = True
 guimode = False
 showsamplerwarning = True
@@ -1412,11 +1412,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     global last_non_horde_req_time
     last_non_horde_req_time = time.time()
-    return generate(
-        genparams=genparams,
-        is_quiet=is_quiet,
-        stream_flag=stream_flag
-    )
+    return generate(genparams=genparams,is_quiet=is_quiet,stream_flag=stream_flag)
 genout = {"text": "", "status": -1, "stopreason": -1}
 if stream_flag:
if stream_flag: if stream_flag:
@@ -1486,7 +1482,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     current_token = 0
     incomplete_token_buffer = bytearray()
     async_sleep_short = 0.02
-    await asyncio.sleep(0.3) #anti race condition, prevent check from overtaking generate
+    await asyncio.sleep(0.5) #anti race condition, prevent check from overtaking generate
     try:
         tokenReserve = "" #keeps fully formed tokens that we cannot send out yet
         while True: