feat: extract llama request settings to its own state, improve UI/UX

This commit is contained in:
Carl-Robert Linnupuu 2023-12-21 14:46:45 +02:00
parent 9d83107dd5
commit e230640063
15 changed files with 446 additions and 267 deletions

View file

@ -188,7 +188,7 @@ public class IncludeFilesInContextAction extends AnAction {
new Dimension(480, component.getPreferredSize().height + 48));
dialogBuilder.setNorthPanel(FormBuilder.createFormBuilder()
.addLabeledComponent(
CodeGPTBundle.get("action.includeFilesInContext.dialog.promptTemplate.label"),
CodeGPTBundle.get("shared.promptTemplate"),
PanelFactory.panel(promptTemplateTextArea).withComment(
"<html><p>The template that will be used to create the final prompt. "
+ "The <strong>{REPEATABLE_CONTEXT}</strong> placeholder must be included "

View file

@ -109,10 +109,10 @@ public class CompletionRequestProvider {
return new LlamaCompletionRequest.Builder(prompt)
.setN_predict(configuration.getMaxTokens())
.setTemperature(configuration.getTemperature())
.setTop_k(configuration.getTopK())
.setTop_p(configuration.getTopP())
.setMin_p(configuration.getMinP())
.setRepeat_penalty(configuration.getRepeatPenalty())
.setTop_k(settings.getTopK())
.setTop_p(settings.getTopP())
.setMin_p(settings.getMinP())
.setRepeat_penalty(settings.getRepeatPenalty())
.build();
}

View file

@ -10,7 +10,6 @@ import com.intellij.execution.process.ProcessAdapter;
import com.intellij.execution.process.ProcessEvent;
import com.intellij.execution.process.ProcessListener;
import com.intellij.execution.process.ProcessOutputType;
import com.intellij.notification.NotificationType;
import com.intellij.openapi.Disposable;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
@ -20,7 +19,6 @@ import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.CodeGPTPlugin;
import ee.carlrobert.codegpt.settings.service.ServerProgressPanel;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.ui.OverlayUtil;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
@ -111,9 +109,7 @@ public final class LlamaServerAgent implements Disposable {
if (errorLines.isEmpty()) {
LOG.info(format("Server terminated with code %d", event.getExitCode()));
} else {
var error = String.join("", errorLines);
OverlayUtil.showNotification(error, NotificationType.ERROR);
LOG.error(error);
LOG.info(String.join("", errorLines));
}
onServerTerminated.run();

View file

@ -161,7 +161,7 @@ public class AdvancedSettingsComponent {
false);
var proxyPortPanel = UIUtil.createPanel(
proxyPortField,
CodeGPTBundle.get("advancedSettingsConfigurable.proxy.portField.label"),
CodeGPTBundle.get("shared.port"),
false);
UIUtil.setEqualLabelWidths(proxyTypePanel, proxyHostPanel);
UIUtil.setEqualLabelWidths(proxyPortPanel, proxyHostPanel);

View file

@ -50,10 +50,6 @@ public class ConfigurationComponent {
private final JTextArea commitMessagePromptTextArea;
private final IntegerField maxTokensField;
private final JBTextField temperatureField;
private final IntegerField topKField;
private final JBTextField topPField;
private final JBTextField minPField;
private final JBTextField repeatPenaltyField;
public ConfigurationComponent(Disposable parentDisposable, ConfigurationState configuration) {
table = new JBTable(new DefaultTableModel(
@ -72,19 +68,6 @@ public class ConfigurationComponent {
temperatureField = new JBTextField(12);
temperatureField.setText(String.valueOf(configuration.getTemperature()));
topKField = new IntegerField();
topKField.setColumns(12);
topKField.setValue(configuration.getTopK());
topPField = new JBTextField(12);
topPField.setText(String.valueOf(configuration.getTopP()));
minPField = new JBTextField(12);
minPField.setText(String.valueOf(configuration.getMinP()));
repeatPenaltyField = new JBTextField(12);
repeatPenaltyField.setText(String.valueOf(configuration.getRepeatPenalty()));
var temperatureFieldValidator = createInputValidator(parentDisposable, temperatureField);
temperatureField.getDocument().addDocumentListener(new DocumentListener() {
@Override
@ -148,9 +131,6 @@ public class ConfigurationComponent {
CodeGPTBundle.get("configurationConfigurable.section.assistant.title")))
.addComponent(createAssistantConfigurationForm())
.addComponentFillVertically(new JPanel(), 0)
.addComponent(new TitledSeparator(
CodeGPTBundle.get("configurationConfigurable.section.assistant.llamacppParams.title")))
.addComponent(createLlamaAssistantConfigurationForm())
.addComponent(new TitledSeparator(
CodeGPTBundle.get("configurationConfigurable.section.commitMessage.title")))
.addComponent(createCommitMessageConfigurationForm())
@ -230,34 +210,6 @@ public class ConfigurationComponent {
return form;
}
private JPanel createLlamaAssistantConfigurationForm() {
var formBuilder = FormBuilder.createFormBuilder();
addAssistantFormLabeledComponent(
formBuilder,
"configurationConfigurable.section.assistant.topKField.label",
"configurationConfigurable.section.assistant.topKField.comment",
topKField);
addAssistantFormLabeledComponent(
formBuilder,
"configurationConfigurable.section.assistant.topPField.label",
"configurationConfigurable.section.assistant.topPField.comment",
topPField);
addAssistantFormLabeledComponent(
formBuilder,
"configurationConfigurable.section.assistant.minPField.label",
"configurationConfigurable.section.assistant.minPField.comment",
minPField);
addAssistantFormLabeledComponent(
formBuilder,
"configurationConfigurable.section.assistant.repeatPenaltyField.label",
"configurationConfigurable.section.assistant.repeatPenaltyField.comment",
repeatPenaltyField);
var form = formBuilder.getPanel();
form.setBorder(JBUI.Borders.emptyLeft(16));
return form;
}
private JPanel createCommitMessageConfigurationForm() {
var formBuilder = FormBuilder.createFormBuilder();
addAssistantFormLabeledComponent(
@ -344,38 +296,6 @@ public class ConfigurationComponent {
maxTokensField.setValue(maxTokens);
}
public int getTopK() {
return topKField.getValue();
}
public void setTopK(int topK) {
topKField.setValue(topK);
}
public double getTopP() {
return Double.parseDouble(topPField.getText());
}
public void setTopP(double topP) {
topPField.setText(String.valueOf(topP));
}
public double getMinP() {
return Double.parseDouble(minPField.getText());
}
public void setMinP(double minP) {
minPField.setText(String.valueOf(minP));
}
public double getRepeatPenalty() {
return Double.parseDouble(repeatPenaltyField.getText());
}
public void setRepeatPenalty(double repeatPenalty) {
repeatPenaltyField.setText(String.valueOf(repeatPenalty));
}
public boolean isCheckForPluginUpdates() {
return checkForPluginUpdatesCheckBox.isSelected();
}

View file

@ -36,10 +36,6 @@ public class ConfigurationConfigurable implements Configurable {
return !configurationComponent.getTableData().equals(configuration.getTableData())
|| configurationComponent.getMaxTokens() != configuration.getMaxTokens()
|| configurationComponent.getTemperature() != configuration.getTemperature()
|| configurationComponent.getTopK() != configuration.getTopK()
|| configurationComponent.getTopP() != configuration.getTopP()
|| configurationComponent.getMinP() != configuration.getMinP()
|| configurationComponent.getRepeatPenalty() != configuration.getRepeatPenalty()
|| !configurationComponent.getSystemPrompt().equals(configuration.getSystemPrompt())
|| !configurationComponent.getCommitMessagePrompt()
.equals(configuration.getCommitMessagePrompt())
@ -59,10 +55,6 @@ public class ConfigurationConfigurable implements Configurable {
configuration.setTableData(configurationComponent.getTableData());
configuration.setMaxTokens(configurationComponent.getMaxTokens());
configuration.setTemperature(configurationComponent.getTemperature());
configuration.setTopK(configurationComponent.getTopK());
configuration.setTopP(configurationComponent.getTopP());
configuration.setMinP(configurationComponent.getMinP());
configuration.setRepeatPenalty(configurationComponent.getRepeatPenalty());
configuration.setSystemPrompt(configurationComponent.getSystemPrompt());
configuration.setCommitMessagePrompt(configurationComponent.getCommitMessagePrompt());
configuration.setCheckForPluginUpdates(configurationComponent.isCheckForPluginUpdates());
@ -80,10 +72,6 @@ public class ConfigurationConfigurable implements Configurable {
configurationComponent.setTableData(configuration.getTableData());
configurationComponent.setMaxTokens(configuration.getMaxTokens());
configurationComponent.setTemperature(configuration.getTemperature());
configurationComponent.setTopK(configuration.getTopK());
configurationComponent.setTopP(configuration.getTopP());
configurationComponent.setMinP(configuration.getMinP());
configurationComponent.setRepeatPenalty(configuration.getRepeatPenalty());
configurationComponent.setSystemPrompt(configuration.getSystemPrompt());
configurationComponent.setCommitMessagePrompt(configuration.getCommitMessagePrompt());
configurationComponent.setCheckForPluginUpdates(configuration.isCheckForPluginUpdates());

View file

@ -22,10 +22,6 @@ public class ConfigurationState implements PersistentStateComponent<Configuratio
private String commitMessagePrompt = COMPLETION_COMMIT_MESSAGE_PROMPT;
private int maxTokens = 1000;
private double temperature = 0.1;
private int topK = 40;
private double topP = 0.9;
private double minP = 0.05;
private double repeatPenalty = 1.1;
private boolean checkForPluginUpdates = true;
private boolean createNewChatOnEachAction;
private boolean ignoreGitCommitTokenLimit;
@ -80,38 +76,6 @@ public class ConfigurationState implements PersistentStateComponent<Configuratio
this.temperature = temperature;
}
public int getTopK() {
return topK;
}
public void setTopK(int topK) {
this.topK = topK;
}
public double getTopP() {
return topP;
}
public void setTopP(double topP) {
this.topP = topP;
}
public double getMinP() {
return minP;
}
public void setMinP(double minP) {
this.minP = minP;
}
public double getRepeatPenalty() {
return repeatPenalty;
}
public void setRepeatPenalty(double repeatPenalty) {
this.repeatPenalty = repeatPenalty;
}
public boolean isCreateNewChatOnEachAction() {
return createNewChatOnEachAction;
}

View file

@ -17,8 +17,8 @@ import com.intellij.openapi.ui.panel.ComponentPanelBuilder;
import com.intellij.openapi.util.io.FileUtil;
import com.intellij.ui.EnumComboBoxModel;
import com.intellij.ui.components.AnActionLink;
import com.intellij.ui.components.JBCheckBox;
import com.intellij.ui.components.JBLabel;
import com.intellij.ui.components.JBRadioButton;
import com.intellij.util.ui.FormBuilder;
import com.intellij.util.ui.JBUI;
import ee.carlrobert.codegpt.CodeGPTBundle;
@ -27,20 +27,30 @@ import ee.carlrobert.codegpt.completions.HuggingFaceModel;
import ee.carlrobert.codegpt.completions.llama.LlamaModel;
import ee.carlrobert.codegpt.completions.llama.LlamaServerAgent;
import ee.carlrobert.codegpt.completions.llama.PromptTemplate;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import java.awt.BorderLayout;
import java.awt.CardLayout;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.swing.Box;
import javax.swing.BoxLayout;
import javax.swing.ButtonGroup;
import javax.swing.DefaultComboBoxModel;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;
import org.apache.commons.text.StringEscapeUtils;
import org.jetbrains.annotations.NotNull;
public class LlamaModelPreferencesForm {
private static final String PREDEFINED_MODEL_FORM_CARD_CODE = "PredefinedModelSettings";
private static final String CUSTOM_MODEL_FORM_CARD_CODE = "CustomModelSettings";
private static final Map<Integer, Map<Integer, ModelDetails>> modelDetailsMap = Map.of(
7, Map.of(
3, new ModelDetails(3.30, 5.80),
@ -55,39 +65,32 @@ public class LlamaModelPreferencesForm {
4, new ModelDetails(20.22, 22.72),
5, new ModelDetails(23.84, 26.34)));
private final TextFieldWithBrowseButton customModelPathBrowserButton;
private final TextFieldWithBrowseButton browsableCustomModelTextField;
private final ComboBox<LlamaModel> modelComboBox;
private final ComboBox<ModelSize> modelSizeComboBox;
private final ComboBox<HuggingFaceModel> huggingFaceModelComboBox;
private final ComboBox<PromptTemplate> promptTemplateComboBox;
private final JBLabel modelExistsIcon;
private final DefaultComboBoxModel<HuggingFaceModel> huggingFaceComboBoxModel;
private final JBCheckBox useCustomModelCheckBox;
private final JBLabel helpIcon;
private final JBLabel promptTemplateHelpIcon;
private final JPanel downloadModelActionLinkWrapper;
private final JBLabel progressLabel;
private final JBLabel modelDetailsLabel;
public TextFieldWithBrowseButton getCustomModelPathBrowserButton() {
return customModelPathBrowserButton;
}
public ComboBox<HuggingFaceModel> getHuggingFaceModelComboBox() {
return huggingFaceModelComboBox;
}
private final ComboBox<PromptTemplate> promptTemplateComboBox;
private final CardLayout cardLayout;
private final JBRadioButton predefinedModelRadioButton;
private final JBRadioButton customModelRadioButton;
public LlamaModelPreferencesForm() {
var llamaServerAgent = ApplicationManager.getApplication().getService(LlamaServerAgent.class);
var llamaSettings = LlamaSettingsState.getInstance();
customModelPathBrowserButton = createCustomModelPathBrowseButton(
llamaSettings.isUseCustomModel() && !llamaServerAgent.isServerRunning());
customModelPathBrowserButton.setText(llamaSettings.getCustomLlamaModelPath());
cardLayout = new CardLayout();
progressLabel = new JBLabel("");
progressLabel.setBorder(JBUI.Borders.emptyLeft(2));
progressLabel.setFont(JBUI.Fonts.smallFont());
modelExistsIcon = new JBLabel(Actions.Checked);
var llamaSettings = LlamaSettingsState.getInstance();
modelExistsIcon.setVisible(isModelExists(llamaSettings.getHuggingFaceModel()));
helpIcon = new JBLabel(General.ContextHelp);
promptTemplateHelpIcon = new JBLabel(General.ContextHelp);
huggingFaceComboBoxModel = new DefaultComboBoxModel<>();
var llm = llamaSettings.getHuggingFaceModel();
var llamaModel = LlamaModel.findByHuggingFaceModel(llm);
@ -111,6 +114,7 @@ public class LlamaModelPreferencesForm {
modelExistsIcon,
modelDetailsLabel,
downloadModelActionLinkWrapper);
var llamaServerAgent = ApplicationManager.getApplication().getService(LlamaServerAgent.class);
huggingFaceModelComboBox.setEnabled(!llamaServerAgent.isServerRunning());
var modelSizeComboBoxModel = new DefaultComboBoxModel<ModelSize>();
var initialModelSizes = llamaModel.getSortedUniqueModelSizes().stream()
@ -127,30 +131,143 @@ public class LlamaModelPreferencesForm {
huggingFaceComboBoxModel);
modelSizeComboBox.setEnabled(
initialModelSizes.size() > 1 && !llamaServerAgent.isServerRunning());
browsableCustomModelTextField = createBrowsableCustomModelTextField(
!llamaServerAgent.isServerRunning());
browsableCustomModelTextField.setText(llamaSettings.getCustomLlamaModelPath());
promptTemplateComboBox = new ComboBox<>(new EnumComboBoxModel<>(PromptTemplate.class));
promptTemplateComboBox.setSelectedItem(llamaSettings.getPromptTemplate());
promptTemplateComboBox.setEnabled(
llamaSettings.isUseCustomModel() && !llamaServerAgent.isServerRunning());
promptTemplateComboBox.setPreferredSize(modelComboBox.getPreferredSize());
useCustomModelCheckBox = new JBCheckBox(CodeGPTBundle.get(
"settingsConfigurable.service.llama.useCustomModel.label"),
llamaSettings.isUseCustomModel());
useCustomModelCheckBox.setEnabled(!llamaServerAgent.isServerRunning());
useCustomModelCheckBox.addChangeListener(e -> {
var selected = ((JBCheckBox) e.getSource()).isSelected();
customModelPathBrowserButton.setEnabled(selected && !llamaServerAgent.isServerRunning());
promptTemplateComboBox.setEnabled(selected && !llamaServerAgent.isServerRunning());
modelComboBox.setEnabled(!selected);
modelSizeComboBox.setEnabled((!selected));
huggingFaceModelComboBox.setEnabled((!selected));
promptTemplateComboBox.setEnabled(!llamaServerAgent.isServerRunning());
promptTemplateComboBox.addItemListener(item -> {
var template = (PromptTemplate) item.getItem();
updatePromptTemplateHelpTooltip(template);
});
updatePromptTemplateHelpTooltip(llamaSettings.getPromptTemplate());
predefinedModelRadioButton = new JBRadioButton("Use pre-defined model",
!llamaSettings.isUseCustomModel());
customModelRadioButton = new JBRadioButton("Use custom model",
llamaSettings.isUseCustomModel());
}
public JPanel getForm() {
JPanel finalPanel = new JPanel(new BorderLayout());
finalPanel.add(createRadioButtonsPanel(), BorderLayout.NORTH);
finalPanel.add(createFormPanelCards(), BorderLayout.CENTER);
return finalPanel;
}
public void enableFields(boolean enabled) {
modelComboBox.setEnabled(enabled);
modelSizeComboBox.setEnabled(enabled);
huggingFaceModelComboBox.setEnabled(enabled);
}
public TextFieldWithBrowseButton getBrowsableCustomModelTextField() {
return browsableCustomModelTextField;
}
public ComboBox<HuggingFaceModel> getHuggingFaceModelComboBox() {
return huggingFaceModelComboBox;
}
public void setSelectedModel(HuggingFaceModel model) {
huggingFaceComboBoxModel.setSelectedItem(model);
}
public HuggingFaceModel getSelectedModel() {
return (HuggingFaceModel) huggingFaceComboBoxModel.getSelectedItem();
}
public void setCustomLlamaModelPath(String modelPath) {
browsableCustomModelTextField.setText(modelPath);
}
public String getCustomLlamaModelPath() {
return browsableCustomModelTextField.getText();
}
public void setUseCustomLlamaModel(boolean useCustomLlamaModel) {
customModelRadioButton.setSelected(useCustomLlamaModel);
}
public boolean isUseCustomLlamaModel() {
return customModelRadioButton.isSelected();
}
public void setPromptTemplate(PromptTemplate promptTemplate) {
promptTemplateComboBox.setSelectedItem(promptTemplate);
}
public PromptTemplate getPromptTemplate() {
return promptTemplateComboBox.getItem();
}
public String getActualModelPath() {
return isUseCustomLlamaModel()
? getCustomLlamaModelPath()
: CodeGPTPlugin.getLlamaModelsPath() + File.separator + getSelectedModel().getFileName();
}
private JPanel createFormPanelCards() {
var formPanelCards = new JPanel(cardLayout);
formPanelCards.setBorder(JBUI.Borders.emptyLeft(16));
formPanelCards.add(createPredefinedModelForm(), PREDEFINED_MODEL_FORM_CARD_CODE);
formPanelCards.add(createCustomModelForm(), CUSTOM_MODEL_FORM_CARD_CODE);
cardLayout.show(
formPanelCards,
predefinedModelRadioButton.isSelected()
? PREDEFINED_MODEL_FORM_CARD_CODE
: CUSTOM_MODEL_FORM_CARD_CODE);
predefinedModelRadioButton.addActionListener(e ->
cardLayout.show(formPanelCards, PREDEFINED_MODEL_FORM_CARD_CODE));
customModelRadioButton.addActionListener(e ->
cardLayout.show(formPanelCards, CUSTOM_MODEL_FORM_CARD_CODE));
return formPanelCards;
}
private JPanel createRadioButtonsPanel() {
var buttonGroup = new ButtonGroup();
buttonGroup.add(predefinedModelRadioButton);
buttonGroup.add(customModelRadioButton);
var radioPanel = new JPanel();
radioPanel.setLayout(new BoxLayout(radioPanel, BoxLayout.PAGE_AXIS));
radioPanel.add(predefinedModelRadioButton);
radioPanel.add(Box.createVerticalStrut(4));
radioPanel.add(customModelRadioButton);
radioPanel.add(Box.createVerticalStrut(8));
return radioPanel;
}
private JPanel createCustomModelForm() {
var customModelHelpText = ComponentPanelBuilder.createCommentComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.customModelPath.comment"),
true);
customModelHelpText.setBorder(JBUI.Borders.empty(0, 4));
var promptTemplateHelpText = ComponentPanelBuilder.createCommentComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.promptTemplate.comment"),
true);
promptTemplateHelpText.setBorder(JBUI.Borders.empty(0, 4));
var promptTemplateWrapper = new JPanel(new FlowLayout(FlowLayout.LEADING, 0, 0));
promptTemplateWrapper.add(promptTemplateComboBox);
promptTemplateWrapper.add(Box.createHorizontalStrut(8));
promptTemplateWrapper.add(promptTemplateHelpIcon);
return FormBuilder.createFormBuilder()
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.customModelPath.label"),
browsableCustomModelTextField)
.addComponentToRightColumn(customModelHelpText)
.addLabeledComponent(CodeGPTBundle.get("shared.promptTemplate"), promptTemplateWrapper)
.addComponentToRightColumn(promptTemplateHelpText)
.addVerticalGap(4)
.addComponentFillVertically(new JPanel(), 0)
.getPanel();
}
private JPanel createPredefinedModelForm() {
var quantizationHelpText = ComponentPanelBuilder.createCommentComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.quantization.comment"),
true);
@ -180,39 +297,11 @@ public class LlamaModelPreferencesForm {
.addComponentToRightColumn(quantizationHelpText)
.addComponentToRightColumn(downloadModelActionLinkWrapper)
.addComponentToRightColumn(progressLabel)
.addVerticalGap(8)
.addComponent(useCustomModelCheckBox)
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.promptTemplate.label"),
promptTemplateComboBox)
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.customModelPath.label"),
customModelPathBrowserButton)
.addComponentToRightColumn(customModelHelpText)
.addVerticalGap(4)
.addComponentFillVertically(new JPanel(), 0)
.getPanel();
}
public void enableFields(boolean enabled) {
modelComboBox.setEnabled(enabled);
modelSizeComboBox.setEnabled(enabled);
huggingFaceModelComboBox.setEnabled(enabled);
useCustomModelCheckBox.setEnabled(enabled);
promptTemplateComboBox.setEnabled(enabled && useCustomModelCheckBox.isSelected());
customModelPathBrowserButton.setEnabled(enabled && useCustomModelCheckBox.isSelected());
}
private static class ModelDetails {
double fileSize;
double maxRAMRequired;
public ModelDetails(double fileSize, double maxRAMRequired) {
this.fileSize = fileSize;
this.maxRAMRequired = maxRAMRequired;
}
}
private String getHuggingFaceModelDetailsHtml(HuggingFaceModel model) {
int parameterSize = model.getParameterSize();
int quantization = model.getQuantization();
@ -232,44 +321,6 @@ public class LlamaModelPreferencesForm {
+ "</html>", details.fileSize, details.maxRAMRequired);
}
public void setSelectedModel(HuggingFaceModel model) {
huggingFaceComboBoxModel.setSelectedItem(model);
}
public HuggingFaceModel getSelectedModel() {
return (HuggingFaceModel) huggingFaceComboBoxModel.getSelectedItem();
}
public void setCustomLlamaModelPath(String modelPath) {
customModelPathBrowserButton.setText(modelPath);
}
public String getCustomLlamaModelPath() {
return customModelPathBrowserButton.getText();
}
public void setUseCustomLlamaModel(boolean useCustomLlamaModel) {
useCustomModelCheckBox.setSelected(useCustomLlamaModel);
}
public boolean isUseCustomLlamaModel() {
return useCustomModelCheckBox.isSelected();
}
public void setPromptTemplate(PromptTemplate promptTemplate) {
promptTemplateComboBox.setSelectedItem(promptTemplate);
}
public PromptTemplate getPromptTemplate() {
return promptTemplateComboBox.getItem();
}
public String getActualModelPath() {
return isUseCustomLlamaModel()
? getCustomLlamaModelPath()
: CodeGPTPlugin.getLlamaModelsPath() + File.separator + getSelectedModel().getFileName();
}
private ComboBox<LlamaModel> createModelComboBox(
EnumComboBoxModel<LlamaModel> llamaModelEnumComboBoxModel,
LlamaModel llamaModel,
@ -345,7 +396,7 @@ public class LlamaModelPreferencesForm {
return comboBox;
}
private TextFieldWithBrowseButton createCustomModelPathBrowseButton(boolean enabled) {
private TextFieldWithBrowseButton createBrowsableCustomModelTextField(boolean enabled) {
var browseButton = new TextFieldWithBrowseButton();
browseButton.setEnabled(enabled);
@ -445,14 +496,42 @@ public class LlamaModelPreferencesForm {
var llamaModel = LlamaModel.findByHuggingFaceModel(model);
new HelpTooltip()
.setTitle(llamaModel.getLabel())
.setDescription("<html><p>" + llamaModel.getDescription() + "</p></html>")
.setDescription(llamaModel.getDescription())
.setBrowserLink(
CodeGPTBundle.get("settingsConfigurable.service.llama.linkToModel.label"),
model.getHuggingFaceURL())
.installOn(helpIcon);
}
static class ModelSize {
private void updatePromptTemplateHelpTooltip(PromptTemplate template) {
promptTemplateHelpIcon.setToolTipText(null);
var prompt = template.buildPrompt(
"SYSTEM_PROMPT",
"USER_PROMPT",
List.of(new Message("PREV_PROMPT", "PREV_RESPONSE")));
var htmlDescription = Arrays.stream(prompt.split("\n"))
.map(StringEscapeUtils::escapeHtml4)
.collect(Collectors.joining("<br>"));
new HelpTooltip()
.setTitle(template.toString())
.setDescription("<html><p>" + htmlDescription + "</p></html>")
.installOn(promptTemplateHelpIcon);
}
private static class ModelDetails {
double fileSize;
double maxRAMRequired;
public ModelDetails(double fileSize, double maxRAMRequired) {
this.fileSize = fileSize;
this.maxRAMRequired = maxRAMRequired;
}
}
private static class ModelSize {
private final int size;

View file

@ -0,0 +1,97 @@
package ee.carlrobert.codegpt.settings.service;
import com.intellij.openapi.ui.panel.ComponentPanelBuilder;
import com.intellij.ui.components.JBTextField;
import com.intellij.ui.components.fields.IntegerField;
import com.intellij.util.ui.FormBuilder;
import com.intellij.util.ui.JBUI;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import javax.swing.JLabel;
import javax.swing.JPanel;
/**
 * Settings form for the llama.cpp request (sampling) parameters: top-k, top-p,
 * min-p and repeat penalty. All fields are pre-populated from the persisted
 * {@link LlamaSettingsState} when the form is constructed.
 */
public class LlamaRequestPreferencesForm {

  private final IntegerField topKField;
  private final JBTextField topPField;
  private final JBTextField minPField;
  private final JBTextField repeatPenaltyField;

  public LlamaRequestPreferencesForm() {
    var settings = LlamaSettingsState.getInstance();

    // top-k is an integer, so it gets a validating IntegerField.
    topKField = new IntegerField();
    topKField.setColumns(12);
    topKField.setValue(settings.getTopK());

    // The remaining parameters are decimals rendered into plain text fields.
    topPField = createDecimalField(settings.getTopP());
    minPField = createDecimalField(settings.getMinP());
    repeatPenaltyField = createDecimalField(settings.getRepeatPenalty());
  }

  /**
   * Builds the panel: each labeled field followed by its explanatory comment,
   * with remaining vertical space filled by an empty panel.
   */
  public JPanel getForm() {
    var builder = FormBuilder.createFormBuilder();
    builder.addLabeledComponent(
        CodeGPTBundle.get("settingsConfigurable.service.llama.topK.label"),
        topKField);
    builder.addComponentToRightColumn(
        createComment("settingsConfigurable.service.llama.topK.comment"));
    builder.addLabeledComponent(
        CodeGPTBundle.get("settingsConfigurable.service.llama.topP.label"),
        topPField);
    builder.addComponentToRightColumn(
        createComment("settingsConfigurable.service.llama.topP.comment"));
    builder.addLabeledComponent(
        CodeGPTBundle.get("settingsConfigurable.service.llama.minP.label"),
        minPField);
    builder.addComponentToRightColumn(
        createComment("settingsConfigurable.service.llama.minP.comment"));
    builder.addLabeledComponent(
        CodeGPTBundle.get("settingsConfigurable.service.llama.repeatPenalty.label"),
        repeatPenaltyField);
    builder.addComponentToRightColumn(
        createComment("settingsConfigurable.service.llama.repeatPenalty.comment"));
    builder.addComponentFillVertically(new JPanel(), 0);
    return builder.getPanel();
  }

  public int getTopK() {
    return topKField.getValue();
  }

  public void setTopK(int topK) {
    topKField.setValue(topK);
  }

  public double getTopP() {
    return parse(topPField);
  }

  public void setTopP(double topP) {
    topPField.setText(Double.toString(topP));
  }

  public double getMinP() {
    return parse(minPField);
  }

  public void setMinP(double minP) {
    minPField.setText(Double.toString(minP));
  }

  public double getRepeatPenalty() {
    return parse(repeatPenaltyField);
  }

  public void setRepeatPenalty(double repeatPenalty) {
    repeatPenaltyField.setText(Double.toString(repeatPenalty));
  }

  // Creates a 12-column text field initialized with the given decimal value.
  private static JBTextField createDecimalField(double value) {
    var field = new JBTextField(12);
    field.setText(Double.toString(value));
    return field;
  }

  // Reads a field's text back as a double; NumberFormatException on bad input
  // is unchanged from the original behavior.
  private static double parse(JBTextField field) {
    return Double.parseDouble(field.getText());
  }

  // Builds a small indented comment label from a message-bundle key.
  private JLabel createComment(String messageKey) {
    var label = ComponentPanelBuilder.createCommentComponent(
        CodeGPTBundle.get(messageKey), true);
    label.setBorder(JBUI.Borders.empty(0, 4));
    return label;
  }
}

View file

@ -35,6 +35,7 @@ import javax.swing.SwingConstants;
public class LlamaServiceSelectionForm extends JPanel {
private final LlamaModelPreferencesForm llamaModelPreferencesForm;
private final LlamaRequestPreferencesForm llamaRequestPreferencesForm;
private final PortField portField;
private final IntegerField maxTokensField;
private final IntegerField threadsField;
@ -48,6 +49,7 @@ public class LlamaServiceSelectionForm extends JPanel {
portField.setEnabled(!serverRunning);
llamaModelPreferencesForm = new LlamaModelPreferencesForm();
llamaRequestPreferencesForm = new LlamaRequestPreferencesForm();
var llamaSettings = LlamaSettingsState.getInstance();
maxTokensField = new IntegerField("max_tokens", 256, 4096);
@ -78,6 +80,10 @@ public class LlamaServiceSelectionForm extends JPanel {
return llamaModelPreferencesForm;
}
public LlamaRequestPreferencesForm getLlamaRequestPreferencesForm() {
return llamaRequestPreferencesForm;
}
private JComponent withEmptyLeftBorder(JComponent component) {
component.setBorder(JBUI.Borders.emptyLeft(16));
return component;
@ -119,6 +125,7 @@ public class LlamaServiceSelectionForm extends JPanel {
private void init(LlamaServerAgent llamaServerAgent) {
var serverProgressPanel = new ServerProgressPanel();
serverProgressPanel.setBorder(JBUI.Borders.emptyRight(16));
setLayout(new BorderLayout());
add(FormBuilder.createFormBuilder()
.addComponent(new TitledSeparator(
@ -127,6 +134,14 @@ public class LlamaServiceSelectionForm extends JPanel {
.addComponent(new TitledSeparator(
CodeGPTBundle.get("settingsConfigurable.service.llama.serverPreferences.title")))
.addComponent(withEmptyLeftBorder(FormBuilder.createFormBuilder()
.addLabeledComponent(
CodeGPTBundle.get("shared.port"),
JBUI.Panels.simplePanel()
.addToLeft(portField)
.addToRight(JBUI.Panels.simplePanel()
.addToCenter(serverProgressPanel)
.addToRight(getServerButton(llamaServerAgent, serverProgressPanel))))
.addVerticalGap(4)
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.contextSize.label"),
maxTokensField)
@ -142,14 +157,10 @@ public class LlamaServiceSelectionForm extends JPanel {
additionalParametersField)
.addComponentToRightColumn(
createComment("settingsConfigurable.service.llama.additionalParameters.comment"))
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.port.label"),
JBUI.Panels.simplePanel()
.addToLeft(portField)
.addToRight(getServerButton(llamaServerAgent, serverProgressPanel)))
.addVerticalGap(8)
.getPanel()))
.addVerticalGap(4)
.addComponent(withEmptyLeftBorder(serverProgressPanel))
.addComponent(new TitledSeparator("Request Preferences"))
.addComponent(withEmptyLeftBorder(llamaRequestPreferencesForm.getForm()))
.addComponentFillVertically(new JPanel(), 0)
.getPanel());
}
@ -221,7 +232,7 @@ public class LlamaServiceSelectionForm extends JPanel {
OverlayUtil.showBalloon(
CodeGPTBundle.get("validation.error.fieldRequired"),
MessageType.ERROR,
llamaModelPreferencesForm.getCustomModelPathBrowserButton());
llamaModelPreferencesForm.getBrowsableCustomModelTextField());
return false;
}
}
@ -229,8 +240,8 @@ public class LlamaServiceSelectionForm extends JPanel {
}
private boolean validateSelectedModel() {
if (!llamaModelPreferencesForm.isUseCustomLlamaModel() && !isModelExists(
llamaModelPreferencesForm.getSelectedModel())) {
if (!llamaModelPreferencesForm.isUseCustomLlamaModel()
&& !isModelExists(llamaModelPreferencesForm.getSelectedModel())) {
OverlayUtil.showBalloon(
CodeGPTBundle.get("settingsConfigurable.service.llama.overlay.modelNotDownloaded.text"),
MessageType.ERROR,

View file

@ -12,10 +12,8 @@ public class ServerProgressPanel extends JPanel {
private final JBLabel label = new JBLabel();
public ServerProgressPanel() {
super(new FlowLayout(FlowLayout.LEADING, 0, 0));
setVisible(false);
add(new AsyncProcessIcon("sign_in_spinner"));
add(Box.createHorizontalStrut(4));
add(label);
}

View file

@ -365,6 +365,10 @@ public class ServiceSelectionForm {
return llamaServiceSectionPanel.getLlamaModelPreferencesForm();
}
public LlamaRequestPreferencesForm getLlamaRequestPreferencesForm() {
return llamaServiceSectionPanel.getLlamaRequestPreferencesForm();
}
public void setOpenAIPath(String path) {
openAIPathField.setText(path);
}

View file

@ -23,6 +23,10 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
private int contextSize = 2048;
private int threads = 8;
private String additionalParameters = "";
private int topK = 40;
private double topP = 0.9;
private double minP = 0.05;
private double repeatPenalty = 1.1;
public LlamaSettingsState() {
}
@ -43,11 +47,16 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
public boolean isModified(ServiceSelectionForm serviceSelectionForm) {
var modelPreferencesForm = serviceSelectionForm.getLlamaModelPreferencesForm();
var requestPreferencesForm = serviceSelectionForm.getLlamaRequestPreferencesForm();
return serverPort != serviceSelectionForm.getLlamaServerPort()
|| contextSize != serviceSelectionForm.getContextSize()
|| threads != serviceSelectionForm.getThreads()
|| !additionalParameters.equals(serviceSelectionForm.getAdditionalParameters())
|| huggingFaceModel != modelPreferencesForm.getSelectedModel()
|| topK != requestPreferencesForm.getTopK()
|| topP != requestPreferencesForm.getTopP()
|| minP != requestPreferencesForm.getMinP()
|| repeatPenalty != requestPreferencesForm.getRepeatPenalty()
|| !promptTemplate.equals(modelPreferencesForm.getPromptTemplate())
|| useCustomModel != modelPreferencesForm.isUseCustomLlamaModel()
|| !customLlamaModelPath.equals(modelPreferencesForm.getCustomLlamaModelPath());
@ -59,6 +68,11 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
huggingFaceModel = modelPreferencesForm.getSelectedModel();
useCustomModel = modelPreferencesForm.isUseCustomLlamaModel();
promptTemplate = modelPreferencesForm.getPromptTemplate();
var requestPreferencesForm = serviceSelectionForm.getLlamaRequestPreferencesForm();
topK = requestPreferencesForm.getTopK();
topP = requestPreferencesForm.getTopP();
minP = requestPreferencesForm.getMinP();
repeatPenalty = requestPreferencesForm.getRepeatPenalty();
serverPort = serviceSelectionForm.getLlamaServerPort();
contextSize = serviceSelectionForm.getContextSize();
threads = serviceSelectionForm.getThreads();
@ -71,6 +85,11 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
modelPreferencesForm.setCustomLlamaModelPath(customLlamaModelPath);
modelPreferencesForm.setUseCustomLlamaModel(useCustomModel);
modelPreferencesForm.setPromptTemplate(promptTemplate);
var requestPreferencesForm = serviceSelectionForm.getLlamaRequestPreferencesForm();
requestPreferencesForm.setTopK(topK);
requestPreferencesForm.setTopP(topP);
requestPreferencesForm.setMinP(minP);
requestPreferencesForm.setRepeatPenalty(repeatPenalty);
serviceSelectionForm.setLlamaServerPort(serverPort);
serviceSelectionForm.setContextSize(contextSize);
serviceSelectionForm.setThreads(threads);
@ -141,6 +160,38 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
this.additionalParameters = additionalParameters;
}
public int getTopK() {
return topK;
}
public void setTopK(int topK) {
this.topK = topK;
}
public double getTopP() {
return topP;
}
public void setTopP(double topP) {
this.topP = topP;
}
public double getMinP() {
return minP;
}
public void setMinP(double minP) {
this.minP = minP;
}
public double getRepeatPenalty() {
return repeatPenalty;
}
public void setRepeatPenalty(double repeatPenalty) {
this.repeatPenalty = repeatPenalty;
}
private static Integer getRandomAvailablePortOrDefault() {
try (ServerSocket socket = new ServerSocket(0)) {
return socket.getLocalPort();

View file

@ -7,7 +7,6 @@ action.generateCommitMessage.missingCredentials=Credentials not provided
action.includeFilesInContext.title=Include In Context...
action.includeFilesInContext.dialog.title=Include In Context
action.includeFilesInContext.dialog.description=Choose the files that you wish to include in the final prompt
action.includeFilesInContext.dialog.promptTemplate.label=Prompt template:
action.includeFilesInContext.dialog.repeatableContext.label=Repeatable context:
action.includeFilesInContext.dialog.restoreToDefaults.label=Restore to Defaults
settings.displayName=CodeGPT: Settings
@ -39,10 +38,9 @@ settingsConfigurable.service.llama.serverPreferences.title=Server Preferences
settingsConfigurable.service.llama.modelSize.label=Model size:
settingsConfigurable.service.llama.quantization.label=Quantization:
settingsConfigurable.service.llama.quantization.comment=Quantization is a technique to reduce the computational and memory costs of running inference. <a href="https://huggingface.co/docs/optimum/concept_guides/quantization">Learn more</a>
settingsConfigurable.service.llama.promptTemplate.label=Prompt template:
settingsConfigurable.service.llama.useCustomModel.label=Use custom model
settingsConfigurable.service.llama.customModelPath.label=Model path:
settingsConfigurable.service.llama.customModelPath.comment=Only .gguf files are supported
settingsConfigurable.service.llama.promptTemplate.comment=Choose the template to use during interactions with the language model. Make sure it matches the custom model you're working with.
settingsConfigurable.service.llama.downloadModelLink.label=Download Model
settingsConfigurable.service.llama.cancelDownloadLink.label=Cancel Downloading
settingsConfigurable.service.llama.linkToModel.label=Link to model
@ -52,7 +50,6 @@ settingsConfigurable.service.llama.threads.label=Threads:
settingsConfigurable.service.llama.threads.comment=The number of threads available to execute the model. It is not recommended to specify a number greater than the number of processor cores.
settingsConfigurable.service.llama.additionalParameters.label=Additional parameters:
settingsConfigurable.service.llama.additionalParameters.comment=<html>Additional command-line parameters for the server startup process, separated by commas. See the full <a href="https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md">list of options</a>.<p><i>Example: "--n-gpu-layers, 1, --no-mmap, --mlock"</i></p></html>
settingsConfigurable.service.llama.port.label=Port:
settingsConfigurable.service.llama.startServer.label=Start server
settingsConfigurable.service.llama.stopServer.label=Stop server
settingsConfigurable.service.llama.progress.serverRunning=Server running
@ -87,14 +84,14 @@ configurationConfigurable.section.assistant.temperatureField.comment=The value o
configurationConfigurable.section.assistant.maxTokensField.label=Max completion tokens:
configurationConfigurable.section.assistant.maxTokensField.comment=The maximum capacity for completion.
configurationConfigurable.section.assistant.llamacppParams.title=Configuration Options for llama.cpp
configurationConfigurable.section.assistant.topKField.label=Top-k:
configurationConfigurable.section.assistant.topKField.comment=Limit the next token selection to the K most probable tokens (default: 40)
configurationConfigurable.section.assistant.topPField.label=Top-p:
configurationConfigurable.section.assistant.topPField.comment=Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9)
configurationConfigurable.section.assistant.minPField.label=Min-p:
configurationConfigurable.section.assistant.minPField.comment=Sets a minimum base probability threshold for token selection (default: 0.05)
configurationConfigurable.section.assistant.repeatPenaltyField.label=Repeat penalty:
configurationConfigurable.section.assistant.repeatPenaltyField.comment=Control the repetition of token sequences in the generated text (default: 1.1)
settingsConfigurable.service.llama.topK.label=Top K:
settingsConfigurable.service.llama.topK.comment=Limit the next token selection to the K most probable tokens (default: 40)
settingsConfigurable.service.llama.topP.label=Top P:
settingsConfigurable.service.llama.topP.comment=Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9)
settingsConfigurable.service.llama.minP.label=Min P:
settingsConfigurable.service.llama.minP.comment=Sets a minimum base probability threshold for token selection (default: 0.05)
settingsConfigurable.service.llama.repeatPenalty.label=Repeat penalty:
settingsConfigurable.service.llama.repeatPenalty.comment=Control the repetition of token sequences in the generated text (default: 1.1)
configurationConfigurable.section.commitMessage.title=Commit Message
configurationConfigurable.section.commitMessage.systemPromptField.label=Prompt:
configurationConfigurable.section.commitMessage.systemPromptField.comment=Custom system prompt used for commit message generation.
@ -102,7 +99,6 @@ advancedSettingsConfigurable.displayName=CodeGPT: Advanced Settings
advancedSettingsConfigurable.proxy.title=HTTP/SOCKS Proxy
advancedSettingsConfigurable.proxy.typeComboBoxField.label=Proxy:
advancedSettingsConfigurable.proxy.hostField.label=Host name:
advancedSettingsConfigurable.proxy.portField.label=Port:
advancedSettingsConfigurable.proxy.authCheckBoxField.label=Proxy authentication
advancedSettingsConfigurable.proxy.usernameField.label=Username:
advancedSettingsConfigurable.proxy.passwordField.label=Password:
@ -156,4 +152,6 @@ checkForUpdatesTask.notification.message=An update for CodeGPT is available.
checkForUpdatesTask.notification.installButton=Install update
checkForUpdatesTask.notification.hideButton=Do not show again
llamaServerAgent.buildingProject.description=Building llama.cpp...
llamaServerAgent.serverBootup.description=Booting up server...
llamaServerAgent.serverBootup.description=Booting up server...
shared.promptTemplate=Prompt template:
shared.port=Port:

View file

@ -1,6 +1,8 @@
package ee.carlrobert.codegpt.toolwindow.chat;
import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA;
import static ee.carlrobert.llm.client.util.JSONUtil.e;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonArray;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMap;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse;
@ -11,13 +13,14 @@ import static org.awaitility.Awaitility.await;
import ee.carlrobert.codegpt.CodeGPTKeys;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.completions.HuggingFaceModel;
import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationState;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.toolwindow.chat.standard.StandardChatToolWindowTabPanel;
import ee.carlrobert.embedding.CheckedFile;
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import testsupport.IntegrationTest;
@ -166,4 +169,74 @@ public class StandardChatToolWindowTabPanelTest extends IntegrationTest {
message.getResponse(),
List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"));
}
public void testSendingLlamaMessage() {
useLlamaService();
var configurationState = ConfigurationState.getInstance();
configurationState.setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
configurationState.setMaxTokens(1000);
configurationState.setTemperature(0.1);
var llamaSettings = LlamaSettingsState.getInstance();
llamaSettings.setUseCustomModel(false);
llamaSettings.setHuggingFaceModel(HuggingFaceModel.CODE_LLAMA_7B_Q4);
llamaSettings.setTopK(30);
llamaSettings.setTopP(0.8);
llamaSettings.setMinP(0.03);
llamaSettings.setRepeatPenalty(1.3);
var message = new Message("TEST_PROMPT");
var conversation = ConversationService.getInstance().startConversation();
var panel = new StandardChatToolWindowTabPanel(getProject(), conversation);
expectLlama((StreamHttpExchange) request -> {
assertThat(request.getUri().getPath()).isEqualTo("/completion");
assertThat(request.getBody())
.extracting(
"prompt",
"n_predict",
"stream",
"temperature",
"top_k",
"top_p",
"min_p",
"repeat_penalty")
.containsExactly(
LLAMA.buildPrompt(
COMPLETION_SYSTEM_PROMPT,
"TEST_PROMPT",
conversation.getMessages()),
configurationState.getMaxTokens(),
true,
configurationState.getTemperature(),
llamaSettings.getTopK(),
llamaSettings.getTopP(),
llamaSettings.getMinP(),
llamaSettings.getRepeatPenalty());
return List.of(
jsonMapResponse("content", "Hel"),
jsonMapResponse("content", "lo!"),
jsonMapResponse(
e("content", ""),
e("stop", true)));
});
panel.sendMessage(message);
await().atMost(5, SECONDS)
.until(() -> {
var messages = conversation.getMessages();
return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse());
});
assertThat(panel.getConversation())
.isNotNull()
.extracting("id", "model", "clientCode", "discardTokenLimit")
.containsExactly(
conversation.getId(),
conversation.getModel(),
conversation.getClientCode(),
false);
var messages = panel.getConversation().getMessages();
assertThat(messages.size()).isOne();
assertThat(messages.get(0))
.extracting("id", "prompt", "response")
.containsExactly(message.getId(), message.getPrompt(), message.getResponse());
}
}