feat: Support Stable Code Instruct 3B (#552)

* feat: Support Stable Code Instruct 3B

* feat: Sort LLaMA models in settings
This commit is contained in:
Rene Leonhardt 2024-05-16 21:28:54 +02:00 committed by GitHub
parent 9705ab7511
commit 586cff421e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 156 additions and 5 deletions

View file

@@ -1,5 +1,6 @@
package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.completions.HuggingFaceModel.Model.SC3;
import static ee.carlrobert.codegpt.completions.llama.LlamaModel.getDownloadedMarker;
import static ee.carlrobert.codegpt.completions.llama.LlamaModel.getLlamaModelsPath;
import static java.lang.String.format;
@@ -116,8 +117,34 @@ public enum HuggingFaceModel {
"CodeQwen1.5-7B-Chat.Q5_K_M.gguf", "RichardErkhov", 5.43),
CODE_QWEN_1_5_7B_Q6_K(7, 6, "Qwen_-_CodeQwen1.5-7B-Chat-gguf",
"CodeQwen1.5-7B-Chat.Q6_K.gguf", "RichardErkhov", 6.38),
STABLE_CODE_3B_Q3_K_M(SC3, 3, "stable-code-instruct-3b-Q3_K_M.gguf", 1.39),
STABLE_CODE_3B_Q4_K_M(SC3, 4, "stable-code-instruct-3b-Q4_K_M.gguf", 1.71),
STABLE_CODE_3B_Q5_K_M(SC3, 5, "stable-code-instruct-3b-Q5_K_M.gguf", 1.99),
STABLE_CODE_3B_Q6_K(SC3, 6, "stable-code-instruct-3b-Q6_K.gguf", 2.3),
STABLE_CODE_3B_Q8_0(SC3, 8, "stable-code-instruct-3b-Q8_0.gguf", 2.97),
;
// Groups the Hugging Face repository metadata (user, parameter count, repository
// directory, optional file-name prefix) shared by every quantization of one base
// model, so the quantization constants above can reference a single Model entry.
enum Model {
// Stable Code Instruct 3B, published under the "bartowski" user on Hugging Face.
SC3("bartowski", 3, "stable-code-instruct-3b-GGUF");
// Hugging Face user name that owns the model repository.
private final String user;
// Parameter count in billions (e.g. 3 for a 3B model).
private final int parameterSize;
// Repository directory (model name) on Hugging Face.
private final String directory;
// Optional file-name prefix; null for SC3. Not read in the code shown here —
// NOTE(review): confirm its consumer elsewhere in this enum.
private final String prefix;
// Convenience constructor for repositories whose GGUF files carry no prefix.
Model(String user, int parameterSize, String directory) {
this(user, parameterSize, directory, null);
}
Model(String user, int parameterSize, String directory, String prefix) {
this.user = user;
this.parameterSize = parameterSize;
this.directory = directory;
this.prefix = prefix;
}
}
private final int parameterSize;
private final int quantization;
private final String directory;
@@ -134,6 +161,10 @@ public enum HuggingFaceModel {
this(parameterSize, quantization, directory, fileName, "TheBloke", downloadSize);
}
/**
 * Creates a quantization variant whose user, parameter size and repository
 * directory are taken from a shared {@link Model} entry instead of being
 * repeated literally on every constant (see the STABLE_CODE_3B_* constants).
 */
HuggingFaceModel(Model m, int quantization, String fileName, Double downloadSize) {
this(m.parameterSize, quantization, m.directory, fileName, m.user, downloadSize);
}
HuggingFaceModel(int parameterSize, int quantization, String directory, String fileName,
String user, Double downloadSize) {
this.parameterSize = parameterSize;

View file

@@ -151,6 +151,24 @@ public enum LlamaModel {
HuggingFaceModel.CODE_QWEN_1_5_7B_Q4_K_M,
HuggingFaceModel.CODE_QWEN_1_5_7B_Q5_K_M,
HuggingFaceModel.CODE_QWEN_1_5_7B_Q6_K)),
// Stable Code Instruct 3B. Description adapted from the Hugging Face model card;
// the card's "2.7B billion parameter" typo is corrected to "2.7 billion" here.
// The CODE_QWEN infill template is reused — presumably because both models use
// the same <fim_*> FIM tokens; confirm against the stable-code model card.
STABLE_CODE(
"Stable Code Instruct", """
stable-code-instruct-3b is a 2.7 billion parameter decoder-only language model tuned from \
stable-code-3b. This model was trained on a mix of publicly available datasets, synthetic \
datasets using Direct Preference Optimization (DPO).
This instruct tune demonstrates state-of-the-art performance (compared to models of similar \
size) on the MultiPL-E metrics across multiple programming languages tested using BigCode's \
Evaluation Harness, and on the code portions of MT Bench. The model is fine tuned to make it \
usable in tasks like general purpose Code/Software Engineering like conversations and \
SQL related generation and conversation.""",
PromptTemplate.STABLE_CODE,
InfillPromptTemplate.CODE_QWEN,
List.of(
HuggingFaceModel.STABLE_CODE_3B_Q3_K_M,
HuggingFaceModel.STABLE_CODE_3B_Q4_K_M,
HuggingFaceModel.STABLE_CODE_3B_Q5_K_M,
HuggingFaceModel.STABLE_CODE_3B_Q6_K,
HuggingFaceModel.STABLE_CODE_3B_Q8_0)),
;
private final String label;
@@ -269,6 +287,10 @@ public enum LlamaModel {
.stream().toList();
}
/**
 * Returns all models ordered alphabetically by enum constant name,
 * for a stable ordering in the settings combo box.
 */
public static List<LlamaModel> getSorted() {
// values() hands back a fresh copy, so sorting it in place is safe.
LlamaModel[] sorted = values();
Arrays.sort(sorted, Comparator.comparing(LlamaModel::name));
return List.of(sorted);
}
public record ModelSize(int size, boolean downloaded) implements Comparable<ModelSize> {
// Sort by size, but downloaded comes first: [ 7B, 13B, 13B, 34B]
private static final Comparator<ModelSize> sizeDownloadedFirst = Comparator

View file

@@ -186,6 +186,30 @@ public enum PromptTemplate {
.toString();
}
},
STABLE_CODE("Stable Code Instruct", List.of("<|endoftext|>", "<|im_end|>")) {
/**
 * Builds a ChatML-style prompt: an optional system block, then one
 * user/assistant pair per history entry, ending with an open assistant turn
 * for the model to complete.
 */
@Override
public String buildPrompt(String systemPrompt, String userPrompt, List<Message> history) {
var builder = new StringBuilder();
// A null or blank system prompt is skipped entirely — no empty system block.
if (systemPrompt != null && !systemPrompt.isBlank()) {
builder.append("<|im_start|>system\n").append(systemPrompt).append("<|im_end|>\n");
}
history.forEach(message -> builder
.append("<|im_start|>user\n")
.append(message.getPrompt())
.append("<|im_end|>\n<|im_start|>assistant\n")
.append(message.getResponse())
.append("<|im_end|>\n"));
builder.append("<|im_start|>user\n")
.append(userPrompt)
.append("<|im_end|>\n<|im_start|>assistant\n");
return builder.toString();
}
},
ALPACA("Alpaca/Vicuna") {
@Override
public String buildPrompt(String systemPrompt, String userPrompt, List<Message> history) {

View file

@@ -17,7 +17,6 @@ import com.intellij.openapi.ui.ComboBox;
import com.intellij.openapi.ui.TextBrowseFolderListener;
import com.intellij.openapi.ui.TextFieldWithBrowseButton;
import com.intellij.openapi.ui.panel.ComponentPanelBuilder;
import com.intellij.ui.EnumComboBoxModel;
import com.intellij.ui.components.AnActionLink;
import com.intellij.ui.components.JBLabel;
import com.intellij.ui.components.JBRadioButton;
@@ -41,6 +40,7 @@ import java.util.Map;
import javax.swing.Box;
import javax.swing.BoxLayout;
import javax.swing.ButtonGroup;
import javax.swing.ComboBoxModel;
import javax.swing.DefaultComboBoxModel;
import javax.swing.DefaultListCellRenderer;
import javax.swing.JList;
@@ -113,7 +113,8 @@ public class LlamaModelPreferencesForm {
var llamaServerAgent = ApplicationManager.getApplication().getService(LlamaServerAgent.class);
huggingFaceModelComboBox.setEnabled(!llamaServerAgent.isServerRunning());
var modelSizeComboBoxModel = new DefaultComboBoxModel<ModelSize>();
var modelComboBoxModel = new EnumComboBoxModel<>(LlamaModel.class);
var modelComboBoxModel = new DefaultComboBoxModel<LlamaModel>();
modelComboBoxModel.addAll(LlamaModel.getSorted());
modelComboBox = createModelComboBox(
modelComboBoxModel, llamaModel, llm, llamaServerAgent, modelSizeComboBoxModel);
modelComboBox.setEnabled(!llamaServerAgent.isServerRunning());
@@ -302,7 +303,7 @@ public class LlamaModelPreferencesForm {
}
private ComboBox<LlamaModel> createModelComboBox(
EnumComboBoxModel<LlamaModel> llamaModelEnumComboBoxModel,
ComboBoxModel<LlamaModel> llamaModelEnumComboBoxModel,
LlamaModel llamaModel,
HuggingFaceModel llm,
LlamaServerAgent llamaServerAgent,
@@ -340,7 +341,7 @@ public class LlamaModelPreferencesForm {
}
private ComboBox<ModelSize> createModelSizeComboBox(
EnumComboBoxModel<LlamaModel> llamaModelComboBoxModel,
ComboBoxModel<LlamaModel> llamaModelComboBoxModel,
DefaultComboBoxModel<ModelSize> modelSizeComboBoxModel,
LlamaServerAgent llamaServerAgent,
DefaultComboBoxModel<HuggingFaceModel> huggingFaceComboBoxModel) {
@@ -350,7 +351,7 @@ public class LlamaModelPreferencesForm {
comboBox.setEnabled(
modelSizeComboBoxModel.getSize() > 1 && !llamaServerAgent.isServerRunning());
comboBox.addItemListener(e -> {
var selectedModel = llamaModelComboBoxModel.getSelectedItem();
var selectedModel = (LlamaModel) llamaModelComboBoxModel.getSelectedItem();
var models = selectedModel.filterSelectedModelsBySize(
(ModelSize) modelSizeComboBoxModel.getSelectedItem());
comboBox.setEnabled(

View file

@@ -7,6 +7,7 @@ import ee.carlrobert.codegpt.completions.llama.PromptTemplate.CODE_QWEN
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA_3
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.PHI_3
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.STABLE_CODE
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA
import ee.carlrobert.codegpt.conversations.message.Message
import org.assertj.core.api.Assertions.assertThat
@@ -309,6 +310,78 @@ class PromptTemplateTest {
""".trimIndent())
}
// With no history: expect the system block, a single user turn, and an open
// assistant turn, using the ChatML markers of Stable Code Instruct.
@Test
fun shouldBuildStableCodePromptWithoutHistory() {
val prompt = STABLE_CODE.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
// Expected text is whitespace-sensitive: trimIndent() strips the common indent
// and the blank first/last raw-string lines.
assertThat(prompt).isEqualTo("""
<|im_start|>system
TEST_SYSTEM_PROMPT<|im_end|>
<|im_start|>user
TEST_USER_PROMPT<|im_end|>
<|im_start|>assistant
""".trimIndent())
}
// Null, empty, and whitespace-only system prompts must all be skipped:
// the output starts directly with the user turn, no <|im_start|>system block.
@ParameterizedTest
@NullAndEmptySource
@ValueSource(strings = [" ", "\t", "\n"])
fun shouldBuildStableCodePromptWithoutHistorySkippingBlankSystemPrompt(systemPrompt: String?) {
val prompt = STABLE_CODE.buildPrompt(systemPrompt, USER_PROMPT, listOf())
assertThat(prompt).isEqualTo("""
<|im_start|>user
TEST_USER_PROMPT<|im_end|>
<|im_start|>assistant
""".trimIndent())
}
// With history: each HISTORY entry becomes a closed user/assistant pair,
// followed by the new user prompt and an open assistant turn.
@Test
fun shouldBuildStableCodePromptWithHistory() {
val prompt = STABLE_CODE.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
assertThat(prompt).isEqualTo("""
<|im_start|>system
TEST_SYSTEM_PROMPT<|im_end|>
<|im_start|>user
TEST_PREV_PROMPT_1<|im_end|>
<|im_start|>assistant
TEST_PREV_RESPONSE_1<|im_end|>
<|im_start|>user
TEST_PREV_PROMPT_2<|im_end|>
<|im_start|>assistant
TEST_PREV_RESPONSE_2<|im_end|>
<|im_start|>user
TEST_USER_PROMPT<|im_end|>
<|im_start|>assistant
""".trimIndent())
}
// Blank system prompt combined with history: history pairs are kept,
// only the system block is omitted.
@ParameterizedTest
@NullAndEmptySource
@ValueSource(strings = [" ", "\t", "\n"])
fun shouldBuildStableCodePromptWithHistorySkippingBlankSystemPrompt(systemPrompt: String?) {
val prompt = STABLE_CODE.buildPrompt(systemPrompt, USER_PROMPT, HISTORY)
assertThat(prompt).isEqualTo("""
<|im_start|>user
TEST_PREV_PROMPT_1<|im_end|>
<|im_start|>assistant
TEST_PREV_RESPONSE_1<|im_end|>
<|im_start|>user
TEST_PREV_PROMPT_2<|im_end|>
<|im_start|>assistant
TEST_PREV_RESPONSE_2<|im_end|>
<|im_start|>user
TEST_USER_PROMPT<|im_end|>
<|im_start|>assistant
""".trimIndent())
}
@Test
fun shouldBuildAlpacaPromptWithHistory() {
val prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)