Fixed a bug with stop sequence processing (handle None entries in stop_sequence)

Concedo 2024-01-31 15:16:08 +08:00
parent 51fe7ac215
commit 916780eaf4


@@ -336,6 +336,7 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
     inputs.memory = memory.encode("UTF-8")
     if max_length >= max_context_length:
         max_length = max_context_length-1
+        print("\nWARNING: You are trying to generate with max_length near or exceeding max_context_length. Most of the context will be gone and your outputs will not be very coherent.")
     global showmaxctxwarning
     if max_context_length > maxctx:
         if showmaxctxwarning:
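
For reference, the line added in this hunk clamps max_length and prints a warning. A minimal standalone sketch of that behaviour (the values below are hypothetical, not from the commit):

    max_context_length = 512
    max_length = 600  # caller requests more new tokens than the context window holds

    if max_length >= max_context_length:
        max_length = max_context_length - 1  # clamp so some prompt context survives
        print("\nWARNING: You are trying to generate with max_length near or exceeding max_context_length. Most of the context will be gone and your outputs will not be very coherent.")

    print(max_length)  # 511
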
@@ -382,6 +383,8 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
     for n in range(stop_token_max):
         if not stop_sequence or n >= len(stop_sequence):
             inputs.stop_sequence[n] = "".encode("UTF-8")
+        elif stop_sequence[n]==None:
+            inputs.stop_sequence[n] = "".encode("UTF-8")
         else:
             inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
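
The bug the commit message refers to: a None entry inside stop_sequence previously fell through to stop_sequence[n].encode("UTF-8") and raised AttributeError ('NoneType' object has no attribute 'encode'). A minimal standalone sketch of the fixed loop, with the native inputs struct replaced by a plain Python list and hypothetical stop strings:

    stop_token_max = 4
    stop_sequence = ["\n###", None, "</s>"]  # the None entry used to crash .encode()

    encoded = []
    for n in range(stop_token_max):
        if not stop_sequence or n >= len(stop_sequence):
            encoded.append("".encode("UTF-8"))  # pad unused slots with empty bytes
        elif stop_sequence[n] == None:
            encoded.append("".encode("UTF-8"))  # the fix: treat None like an empty slot
        else:
            encoded.append(stop_sequence[n].encode("UTF-8"))

    print(encoded)  # [b'\n###', b'', b'</s>', b'']

Mapping None to the same empty-bytes padding keeps all stop_token_max slots initialized, which the native side presumably expects.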