From 7899429d4fd5b2364746d124b2c14800dd248cb8 Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Sun, 21 Apr 2024 23:01:33 +0300 Subject: [PATCH] fix: llama3 prompt --- .../CompletionRequestProvider.java | 1 + .../completions/llama/PromptTemplate.java | 70 ++++++++++--------- .../codegpt/completions/PromptTemplateTest.kt | 16 ++--- 3 files changed, 46 insertions(+), 41 deletions(-) diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java index 160d7189..58f28c3a 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java @@ -180,6 +180,7 @@ public class CompletionRequestProvider { .setTop_p(settings.getTopP()) .setMin_p(settings.getMinP()) .setRepeat_penalty(settings.getRepeatPenalty()) + .setStop(promptTemplate.getStopTokens()) .build(); } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java b/src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java index 584539db..7101587a 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java @@ -1,11 +1,9 @@ package ee.carlrobert.codegpt.completions.llama; -import static java.util.stream.Stream.concat; +import static java.util.Collections.emptyList; import ee.carlrobert.codegpt.conversations.message.Message; import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.Stream; public enum PromptTemplate { @@ -59,24 +57,31 @@ public enum PromptTemplate { .toString(); } }, - LLAMA_3("Llama 3") { + LLAMA_3("Llama 3", List.of("<|eot_id|>")) { @Override public String buildPrompt(String systemPrompt, String userPrompt, List history) { - return concat(concat(Stream.ofNullable(systemPrompt) - .filter(s -> !s.isBlank()) - .flatMap(system -> Stream.of( - "<|start_header_id|>system<|end_header_id|>\n\n", - system, - "<|eot_id|>")), - history.stream().flatMap(message -> mapMessage( - message, - "<|start_header_id|>user<|end_header_id|>\n\n", - "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", - "<|eot_id|>"))), Stream.of( - "<|start_header_id|>user<|end_header_id|>\n\n", - userPrompt, - "<|eot_id|>")) - .collect(Collectors.joining()); + var prompt = new StringBuilder("<|begin_of_text|>"); + if (systemPrompt != null && !systemPrompt.isBlank()) { + prompt + .append("<|start_header_id|>system<|end_header_id|>\n\n") + .append(systemPrompt) + .append("<|eot_id|>"); + } + + for (var message : history) { + prompt + .append("<|start_header_id|>user<|end_header_id|>\n\n") + .append(message.getPrompt()) + .append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n") + .append(message.getResponse()) + .append("<|eot_id|>"); + } + + return prompt + .append("<|start_header_id|>user<|end_header_id|>\n\n") + .append(userPrompt) + .append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>") + .toString(); } }, MIXTRAL_INSTRUCT("Mixtral Instruct") { @@ -126,10 +131,10 @@ public enum PromptTemplate { StringBuilder prompt = new StringBuilder(); prompt.append(""" - Below is an instruction that describes a task. \ - Write a response that appropriately completes the request. + Below is an instruction that describes a task. \ + Write a response that appropriately completes the request. - """); + """); for (Message message : history) { prompt.append("### Instruction\n") @@ -184,26 +189,25 @@ public enum PromptTemplate { }; private final String label; + private final List stopTokens; PromptTemplate(String label) { + this(label, emptyList()); + } + + PromptTemplate(String label, List stopTokens) { this.label = label; + this.stopTokens = stopTokens; } public abstract String buildPrompt(String systemPrompt, String userPrompt, List history); + public List getStopTokens() { + return stopTokens; + } + @Override public String toString() { return label; } - - private static Stream mapMessage(Message message, - String prefix, String infix, String suffix) { - return Stream.of( - prefix, - message.getPrompt(), - infix, - message.getResponse(), - suffix - ); - } } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt index 5e54468b..84dad146 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt @@ -43,11 +43,11 @@ class PromptTemplateTest { val prompt = LLAMA_3.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf()) assertThat(prompt).isEqualTo(""" - <|start_header_id|>system<|end_header_id|> + <|begin_of_text|><|start_header_id|>system<|end_header_id|> TEST_SYSTEM_PROMPT<|eot_id|><|start_header_id|>user<|end_header_id|> - TEST_USER_PROMPT<|eot_id|>""".trimIndent() + TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent() ) } @@ -58,9 +58,9 @@ class PromptTemplateTest { val prompt = LLAMA_3.buildPrompt(systemPrompt, USER_PROMPT, listOf()) assertThat(prompt).isEqualTo(""" - <|start_header_id|>user<|end_header_id|> + <|begin_of_text|><|start_header_id|>user<|end_header_id|> - TEST_USER_PROMPT<|eot_id|>""".trimIndent() + TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent() ) } @@ -69,7 +69,7 @@ class PromptTemplateTest { val prompt = LLAMA_3.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY) assertThat(prompt).isEqualTo(""" - <|start_header_id|>system<|end_header_id|> + <|begin_of_text|><|start_header_id|>system<|end_header_id|> TEST_SYSTEM_PROMPT<|eot_id|><|start_header_id|>user<|end_header_id|> @@ -81,7 +81,7 @@ class PromptTemplateTest { TEST_PREV_RESPONSE_2<|eot_id|><|start_header_id|>user<|end_header_id|> - TEST_USER_PROMPT<|eot_id|>""".trimIndent()) + TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent()) } @ParameterizedTest @@ -91,7 +91,7 @@ class PromptTemplateTest { val prompt = LLAMA_3.buildPrompt(systemPrompt, USER_PROMPT, HISTORY) assertThat(prompt).isEqualTo(""" - <|start_header_id|>user<|end_header_id|> + <|begin_of_text|><|start_header_id|>user<|end_header_id|> TEST_PREV_PROMPT_1<|eot_id|><|start_header_id|>assistant<|end_header_id|> @@ -101,7 +101,7 @@ class PromptTemplateTest { TEST_PREV_RESPONSE_2<|eot_id|><|start_header_id|>user<|end_header_id|> - TEST_USER_PROMPT<|eot_id|>""".trimIndent()) + TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent()) } @Test