fix: llama3 prompt

This commit is contained in:
Carl-Robert Linnupuu 2024-04-21 23:01:33 +03:00
parent 62f0fa43bc
commit 7899429d4f
3 changed files with 46 additions and 41 deletions

View file

@ -180,6 +180,7 @@ public class CompletionRequestProvider {
.setTop_p(settings.getTopP())
.setMin_p(settings.getMinP())
.setRepeat_penalty(settings.getRepeatPenalty())
.setStop(promptTemplate.getStopTokens())
.build();
}

View file

@ -1,11 +1,9 @@
package ee.carlrobert.codegpt.completions.llama;
import static java.util.stream.Stream.concat;
import static java.util.Collections.emptyList;
import ee.carlrobert.codegpt.conversations.message.Message;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public enum PromptTemplate {
@ -59,24 +57,31 @@ public enum PromptTemplate {
.toString();
}
},
LLAMA_3("Llama 3") {
LLAMA_3("Llama 3", List.of("<|eot_id|>")) {
@Override
public String buildPrompt(String systemPrompt, String userPrompt, List<Message> history) {
return concat(concat(Stream.ofNullable(systemPrompt)
.filter(s -> !s.isBlank())
.flatMap(system -> Stream.of(
"<|start_header_id|>system<|end_header_id|>\n\n",
system,
"<|eot_id|>")),
history.stream().flatMap(message -> mapMessage(
message,
"<|start_header_id|>user<|end_header_id|>\n\n",
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
"<|eot_id|>"))), Stream.of(
"<|start_header_id|>user<|end_header_id|>\n\n",
userPrompt,
"<|eot_id|>"))
.collect(Collectors.joining());
var prompt = new StringBuilder("<|begin_of_text|>");
if (systemPrompt != null && !systemPrompt.isBlank()) {
prompt
.append("<|start_header_id|>system<|end_header_id|>\n\n")
.append(systemPrompt)
.append("<|eot_id|>");
}
for (var message : history) {
prompt
.append("<|start_header_id|>user<|end_header_id|>\n\n")
.append(message.getPrompt())
.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
.append(message.getResponse())
.append("<|eot_id|>");
}
return prompt
.append("<|start_header_id|>user<|end_header_id|>\n\n")
.append(userPrompt)
.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>")
.toString();
}
},
MIXTRAL_INSTRUCT("Mixtral Instruct") {
@ -126,10 +131,10 @@ public enum PromptTemplate {
StringBuilder prompt = new StringBuilder();
prompt.append("""
Below is an instruction that describes a task. \
Write a response that appropriately completes the request.
Below is an instruction that describes a task. \
Write a response that appropriately completes the request.
""");
""");
for (Message message : history) {
prompt.append("### Instruction\n")
@ -184,26 +189,25 @@ public enum PromptTemplate {
};
// Human-readable name shown in the UI (see toString()).
private final String label;
// Model-specific stop sequences passed to the completion request
// (e.g. "<|eot_id|>" for Llama 3); empty when the template needs none.
private final List<String> stopTokens;
// Convenience constructor for templates without stop tokens.
PromptTemplate(String label) {
this(label, emptyList());
}
// NOTE(review): stopTokens is stored as-is; callers pass immutable lists
// (List.of / emptyList), so no defensive copy is made here.
PromptTemplate(String label, List<String> stopTokens) {
this.label = label;
this.stopTokens = stopTokens;
}
public abstract String buildPrompt(String systemPrompt, String userPrompt, List<Message> history);
/**
 * Stop sequences the completion request should send for this template;
 * empty when the template defines none.
 */
public List<String> getStopTokens() {
  return this.stopTokens;
}
/** The display label, so combo boxes render the template by name. */
@Override
public String toString() {
  return this.label;
}
/**
 * Interleaves one history message with its surrounding template tokens:
 * prefix, user prompt, infix, assistant response, suffix — in that order.
 */
private static Stream<String> mapMessage(Message message,
    String prefix, String infix, String suffix) {
  var parts = List.of(
      prefix, message.getPrompt(),
      infix, message.getResponse(),
      suffix);
  return parts.stream();
}
}