mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-13 15:32:25 +00:00
fix: llama3 prompt
This commit is contained in:
parent
62f0fa43bc
commit
7899429d4f
3 changed files with 46 additions and 41 deletions
|
|
@@ -180,6 +180,7 @@ public class CompletionRequestProvider {
|
|||
.setTop_p(settings.getTopP())
|
||||
.setMin_p(settings.getMinP())
|
||||
.setRepeat_penalty(settings.getRepeatPenalty())
|
||||
.setStop(promptTemplate.getStopTokens())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,11 +1,9 @@
|
|||
package ee.carlrobert.codegpt.completions.llama;
|
||||
|
||||
import static java.util.stream.Stream.concat;
|
||||
import static java.util.Collections.emptyList;
|
||||
|
||||
import ee.carlrobert.codegpt.conversations.message.Message;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public enum PromptTemplate {
|
||||
|
||||
|
|
@@ -59,24 +57,31 @@ public enum PromptTemplate {
|
|||
.toString();
|
||||
}
|
||||
},
|
||||
LLAMA_3("Llama 3") {
|
||||
LLAMA_3("Llama 3", List.of("<|eot_id|>")) {
|
||||
@Override
|
||||
public String buildPrompt(String systemPrompt, String userPrompt, List<Message> history) {
|
||||
return concat(concat(Stream.ofNullable(systemPrompt)
|
||||
.filter(s -> !s.isBlank())
|
||||
.flatMap(system -> Stream.of(
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n",
|
||||
system,
|
||||
"<|eot_id|>")),
|
||||
history.stream().flatMap(message -> mapMessage(
|
||||
message,
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"<|eot_id|>"))), Stream.of(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
userPrompt,
|
||||
"<|eot_id|>"))
|
||||
.collect(Collectors.joining());
|
||||
var prompt = new StringBuilder("<|begin_of_text|>");
|
||||
if (systemPrompt != null && !systemPrompt.isBlank()) {
|
||||
prompt
|
||||
.append("<|start_header_id|>system<|end_header_id|>\n\n")
|
||||
.append(systemPrompt)
|
||||
.append("<|eot_id|>");
|
||||
}
|
||||
|
||||
for (var message : history) {
|
||||
prompt
|
||||
.append("<|start_header_id|>user<|end_header_id|>\n\n")
|
||||
.append(message.getPrompt())
|
||||
.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
|
||||
.append(message.getResponse())
|
||||
.append("<|eot_id|>");
|
||||
}
|
||||
|
||||
return prompt
|
||||
.append("<|start_header_id|>user<|end_header_id|>\n\n")
|
||||
.append(userPrompt)
|
||||
.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>")
|
||||
.toString();
|
||||
}
|
||||
},
|
||||
MIXTRAL_INSTRUCT("Mixtral Instruct") {
|
||||
|
|
@@ -126,10 +131,10 @@ public enum PromptTemplate {
|
|||
StringBuilder prompt = new StringBuilder();
|
||||
|
||||
prompt.append("""
|
||||
Below is an instruction that describes a task. \
|
||||
Write a response that appropriately completes the request.
|
||||
Below is an instruction that describes a task. \
|
||||
Write a response that appropriately completes the request.
|
||||
|
||||
""");
|
||||
""");
|
||||
|
||||
for (Message message : history) {
|
||||
prompt.append("### Instruction\n")
|
||||
|
|
@@ -184,26 +189,25 @@ public enum PromptTemplate {
|
|||
};
|
||||
|
||||
private final String label;
|
||||
private final List<String> stopTokens;
|
||||
|
||||
PromptTemplate(String label) {
|
||||
this(label, emptyList());
|
||||
}
|
||||
|
||||
PromptTemplate(String label, List<String> stopTokens) {
|
||||
this.label = label;
|
||||
this.stopTokens = stopTokens;
|
||||
}
|
||||
|
||||
public abstract String buildPrompt(String systemPrompt, String userPrompt, List<Message> history);
|
||||
|
||||
public List<String> getStopTokens() {
|
||||
return stopTokens;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return label;
|
||||
}
|
||||
|
||||
private static Stream<String> mapMessage(Message message,
|
||||
String prefix, String infix, String suffix) {
|
||||
return Stream.of(
|
||||
prefix,
|
||||
message.getPrompt(),
|
||||
infix,
|
||||
message.getResponse(),
|
||||
suffix
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -43,11 +43,11 @@ class PromptTemplateTest {
|
|||
val prompt = LLAMA_3.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
TEST_SYSTEM_PROMPT<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
TEST_USER_PROMPT<|eot_id|>""".trimIndent()
|
||||
TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent()
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@@ -58,9 +58,9 @@ class PromptTemplateTest {
|
|||
val prompt = LLAMA_3.buildPrompt(systemPrompt, USER_PROMPT, listOf())
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|start_header_id|>user<|end_header_id|>
|
||||
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
TEST_USER_PROMPT<|eot_id|>""".trimIndent()
|
||||
TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent()
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@@ -69,7 +69,7 @@ class PromptTemplateTest {
|
|||
val prompt = LLAMA_3.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
|
||||
TEST_SYSTEM_PROMPT<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
|
|
@@ -81,7 +81,7 @@ class PromptTemplateTest {
|
|||
|
||||
TEST_PREV_RESPONSE_2<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
TEST_USER_PROMPT<|eot_id|>""".trimIndent())
|
||||
TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent())
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
|
@@ -91,7 +91,7 @@ class PromptTemplateTest {
|
|||
val prompt = LLAMA_3.buildPrompt(systemPrompt, USER_PROMPT, HISTORY)
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|start_header_id|>user<|end_header_id|>
|
||||
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
TEST_PREV_PROMPT_1<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
|
|
@@ -101,7 +101,7 @@ class PromptTemplateTest {
|
|||
|
||||
TEST_PREV_RESPONSE_2<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
TEST_USER_PROMPT<|eot_id|>""".trimIndent())
|
||||
TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue