diff --git a/src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java b/src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java index 8238604d..4d6939b4 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java @@ -1,5 +1,6 @@ package ee.carlrobert.codegpt.completions; +import static ee.carlrobert.codegpt.completions.HuggingFaceModel.Model.SC3; import static ee.carlrobert.codegpt.completions.llama.LlamaModel.getDownloadedMarker; import static ee.carlrobert.codegpt.completions.llama.LlamaModel.getLlamaModelsPath; import static java.lang.String.format; @@ -116,8 +117,34 @@ public enum HuggingFaceModel { "CodeQwen1.5-7B-Chat.Q5_K_M.gguf", "RichardErkhov", 5.43), CODE_QWEN_1_5_7B_Q6_K(7, 6, "Qwen_-_CodeQwen1.5-7B-Chat-gguf", "CodeQwen1.5-7B-Chat.Q6_K.gguf", "RichardErkhov", 6.38), + + STABLE_CODE_3B_Q3_K_M(SC3, 3, "stable-code-instruct-3b-Q3_K_M.gguf", 1.39), + STABLE_CODE_3B_Q4_K_M(SC3, 4, "stable-code-instruct-3b-Q4_K_M.gguf", 1.71), + STABLE_CODE_3B_Q5_K_M(SC3, 5, "stable-code-instruct-3b-Q5_K_M.gguf", 1.99), + STABLE_CODE_3B_Q6_K(SC3, 6, "stable-code-instruct-3b-Q6_K.gguf", 2.3), + STABLE_CODE_3B_Q8_0(SC3, 8, "stable-code-instruct-3b-Q8_0.gguf", 2.97), ; + enum Model { + SC3("bartowski", 3, "stable-code-instruct-3b-GGUF"); + + private final String user; + private final int parameterSize; + private final String directory; + private final String prefix; + + Model(String user, int parameterSize, String directory) { + this(user, parameterSize, directory, null); + } + + Model(String user, int parameterSize, String directory, String prefix) { + this.user = user; + this.parameterSize = parameterSize; + this.directory = directory; + this.prefix = prefix; + } + } + private final int parameterSize; private final int quantization; private final String directory; @@ -134,6 +161,10 @@ public enum HuggingFaceModel { this(parameterSize, quantization, 
directory, fileName, "TheBloke", downloadSize); } + HuggingFaceModel(Model m, int quantization, String fileName, Double downloadSize) { + this(m.parameterSize, quantization, m.directory, fileName, m.user, downloadSize); + } + HuggingFaceModel(int parameterSize, int quantization, String directory, String fileName, String user, Double downloadSize) { this.parameterSize = parameterSize; diff --git a/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java index 509d1319..6140aeaa 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java @@ -151,6 +151,24 @@ public enum LlamaModel { HuggingFaceModel.CODE_QWEN_1_5_7B_Q4_K_M, HuggingFaceModel.CODE_QWEN_1_5_7B_Q5_K_M, HuggingFaceModel.CODE_QWEN_1_5_7B_Q6_K)), + STABLE_CODE( + "Stable Code Instruct", """ + stable-code-instruct-3b is a 2.7 billion parameter decoder-only language model tuned from \ + stable-code-3b. This model was trained on a mix of publicly available datasets, synthetic \ + datasets using Direct Preference Optimization (DPO). + This instruct tune demonstrates state-of-the-art performance (compared to models of similar \ + size) on the MultiPL-E metrics across multiple programming languages tested using BigCode's \ + Evaluation Harness, and on the code portions of MT Bench. 
The model is fine tuned to make it \ + usable in tasks like general purpose Code/Software Engineering like conversations and \ + SQL related generation and conversation.""", + PromptTemplate.STABLE_CODE, + InfillPromptTemplate.CODE_QWEN, + List.of( + HuggingFaceModel.STABLE_CODE_3B_Q3_K_M, + HuggingFaceModel.STABLE_CODE_3B_Q4_K_M, + HuggingFaceModel.STABLE_CODE_3B_Q5_K_M, + HuggingFaceModel.STABLE_CODE_3B_Q6_K, + HuggingFaceModel.STABLE_CODE_3B_Q8_0)), ; private final String label; @@ -269,6 +287,10 @@ public enum LlamaModel { .stream().toList(); } + public static List getSorted() { + return Arrays.stream(values()).sorted(Comparator.comparing(Enum::name)).toList(); + } + public record ModelSize(int size, boolean downloaded) implements Comparable { // Sort by size, but downloaded comes first: [ 7B, ✓ 13B, 13B, 34B] private static final Comparator sizeDownloadedFirst = Comparator diff --git a/src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java b/src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java index 9e7c161b..24fe7789 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java @@ -186,6 +186,30 @@ public enum PromptTemplate { .toString(); } }, + STABLE_CODE("Stable Code Instruct", List.of("<|endoftext|>", "<|im_end|>")) { + @Override + public String buildPrompt(String systemPrompt, String userPrompt, List history) { + StringBuilder prompt = new StringBuilder(); + + if (systemPrompt != null && !systemPrompt.isBlank()) { + prompt.append("<|im_start|>system\n") + .append(systemPrompt) + .append("<|im_end|>\n"); + } + + for (Message message : history) { + prompt.append("<|im_start|>user\n") + .append(message.getPrompt()) + .append("<|im_end|>\n<|im_start|>assistant\n") + .append(message.getResponse()).append("<|im_end|>\n"); + } + + return prompt.append("<|im_start|>user\n") + .append(userPrompt) + 
.append("<|im_end|>\n<|im_start|>assistant\n") + .toString(); + } + }, ALPACA("Alpaca/Vicuna") { @Override public String buildPrompt(String systemPrompt, String userPrompt, List history) { diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java index 4b270ea7..a888fd30 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java @@ -17,7 +17,6 @@ import com.intellij.openapi.ui.ComboBox; import com.intellij.openapi.ui.TextBrowseFolderListener; import com.intellij.openapi.ui.TextFieldWithBrowseButton; import com.intellij.openapi.ui.panel.ComponentPanelBuilder; -import com.intellij.ui.EnumComboBoxModel; import com.intellij.ui.components.AnActionLink; import com.intellij.ui.components.JBLabel; import com.intellij.ui.components.JBRadioButton; @@ -41,6 +40,7 @@ import java.util.Map; import javax.swing.Box; import javax.swing.BoxLayout; import javax.swing.ButtonGroup; +import javax.swing.ComboBoxModel; import javax.swing.DefaultComboBoxModel; import javax.swing.DefaultListCellRenderer; import javax.swing.JList; @@ -113,7 +113,8 @@ public class LlamaModelPreferencesForm { var llamaServerAgent = ApplicationManager.getApplication().getService(LlamaServerAgent.class); huggingFaceModelComboBox.setEnabled(!llamaServerAgent.isServerRunning()); var modelSizeComboBoxModel = new DefaultComboBoxModel(); - var modelComboBoxModel = new EnumComboBoxModel<>(LlamaModel.class); + var modelComboBoxModel = new DefaultComboBoxModel(); + modelComboBoxModel.addAll(LlamaModel.getSorted()); modelComboBox = createModelComboBox( modelComboBoxModel, llamaModel, llm, llamaServerAgent, modelSizeComboBoxModel); modelComboBox.setEnabled(!llamaServerAgent.isServerRunning()); @@ -302,7 +303,7 @@ public class 
LlamaModelPreferencesForm { } private ComboBox createModelComboBox( - EnumComboBoxModel llamaModelEnumComboBoxModel, + ComboBoxModel llamaModelEnumComboBoxModel, LlamaModel llamaModel, HuggingFaceModel llm, LlamaServerAgent llamaServerAgent, @@ -340,7 +341,7 @@ public class LlamaModelPreferencesForm { } private ComboBox createModelSizeComboBox( - EnumComboBoxModel llamaModelComboBoxModel, + ComboBoxModel llamaModelComboBoxModel, DefaultComboBoxModel modelSizeComboBoxModel, LlamaServerAgent llamaServerAgent, DefaultComboBoxModel huggingFaceComboBoxModel) { @@ -350,7 +351,7 @@ public class LlamaModelPreferencesForm { comboBox.setEnabled( modelSizeComboBoxModel.getSize() > 1 && !llamaServerAgent.isServerRunning()); comboBox.addItemListener(e -> { - var selectedModel = llamaModelComboBoxModel.getSelectedItem(); + var selectedModel = (LlamaModel) llamaModelComboBoxModel.getSelectedItem(); var models = selectedModel.filterSelectedModelsBySize( (ModelSize) modelSizeComboBoxModel.getSelectedItem()); comboBox.setEnabled( diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt index c78286b0..1669c247 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt @@ -7,6 +7,7 @@ import ee.carlrobert.codegpt.completions.llama.PromptTemplate.CODE_QWEN import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA_3 import ee.carlrobert.codegpt.completions.llama.PromptTemplate.PHI_3 +import ee.carlrobert.codegpt.completions.llama.PromptTemplate.STABLE_CODE import ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA import ee.carlrobert.codegpt.conversations.message.Message import org.assertj.core.api.Assertions.assertThat @@ -309,6 +310,78 @@ class PromptTemplateTest { """.trimIndent()) } + @Test + 
fun shouldBuildStableCodePromptWithoutHistory() { + val prompt = STABLE_CODE.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf()) + + assertThat(prompt).isEqualTo(""" + <|im_start|>system + TEST_SYSTEM_PROMPT<|im_end|> + <|im_start|>user + TEST_USER_PROMPT<|im_end|> + <|im_start|>assistant + + """.trimIndent()) + } + + @ParameterizedTest + @NullAndEmptySource + @ValueSource(strings = [" ", "\t", "\n"]) + fun shouldBuildStableCodePromptWithoutHistorySkippingBlankSystemPrompt(systemPrompt: String?) { + val prompt = STABLE_CODE.buildPrompt(systemPrompt, USER_PROMPT, listOf()) + + assertThat(prompt).isEqualTo(""" + <|im_start|>user + TEST_USER_PROMPT<|im_end|> + <|im_start|>assistant + + """.trimIndent()) + } + + @Test + fun shouldBuildStableCodePromptWithHistory() { + val prompt = STABLE_CODE.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY) + + assertThat(prompt).isEqualTo(""" + <|im_start|>system + TEST_SYSTEM_PROMPT<|im_end|> + <|im_start|>user + TEST_PREV_PROMPT_1<|im_end|> + <|im_start|>assistant + TEST_PREV_RESPONSE_1<|im_end|> + <|im_start|>user + TEST_PREV_PROMPT_2<|im_end|> + <|im_start|>assistant + TEST_PREV_RESPONSE_2<|im_end|> + <|im_start|>user + TEST_USER_PROMPT<|im_end|> + <|im_start|>assistant + + """.trimIndent()) + } + + @ParameterizedTest + @NullAndEmptySource + @ValueSource(strings = [" ", "\t", "\n"]) + fun shouldBuildStableCodePromptWithHistorySkippingBlankSystemPrompt(systemPrompt: String?) { + val prompt = STABLE_CODE.buildPrompt(systemPrompt, USER_PROMPT, HISTORY) + + assertThat(prompt).isEqualTo(""" + <|im_start|>user + TEST_PREV_PROMPT_1<|im_end|> + <|im_start|>assistant + TEST_PREV_RESPONSE_1<|im_end|> + <|im_start|>user + TEST_PREV_PROMPT_2<|im_end|> + <|im_start|>assistant + TEST_PREV_RESPONSE_2<|im_end|> + <|im_start|>user + TEST_USER_PROMPT<|im_end|> + <|im_start|>assistant + + """.trimIndent()) + } + @Test fun shouldBuildAlpacaPromptWithHistory() { val prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)