From 9d83107dd5788fdf4e9923dff8b16d4522b5f682 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aliet=20Exp=C3=B3sito=20Garc=C3=ADa?= Date: Mon, 18 Dec 2023 04:53:23 -0500 Subject: [PATCH] Add support for some extended parameters of llama.cpp(top_k, top_p, min_p, and repeat_penalty) (#311) * Add support for some extended parameters of llama.cpp(top_k, top_p, min_p, and repeat_penalty) Added 'top_k,' 'top_p,' 'min_p,' and 'repeat_penalty' fields to the llama.cpp request configuration. The default values for these fields match the defaults of llama.cpp. If left untouched, they do not affect the model's response to the request. * Bump llm-client --------- Co-authored-by: Carl-Robert Linnupuu --- .../codegpt.java-conventions.gradle.kts | 2 +- .../CompletionRequestProvider.java | 4 + .../configuration/ConfigurationComponent.java | 80 +++++++++++++++++++ .../ConfigurationConfigurable.java | 12 +++ .../configuration/ConfigurationState.java | 36 +++++++++ .../resources/messages/codegpt.properties | 9 +++ 6 files changed, 142 insertions(+), 1 deletion(-) diff --git a/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts b/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts index 49380b3d..ee5cb1d4 100644 --- a/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts +++ b/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts @@ -23,7 +23,7 @@ checkstyle { } dependencies { - implementation("ee.carlrobert:llm-client:0.1.2") + implementation("ee.carlrobert:llm-client:0.1.3") } tasks { diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java index c63bdc64..c89e75f2 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java @@ -109,6 +109,10 @@ public class CompletionRequestProvider { return new LlamaCompletionRequest.Builder(prompt) .setN_predict(configuration.getMaxTokens()) .setTemperature(configuration.getTemperature()) + .setTop_k(configuration.getTopK()) + .setTop_p(configuration.getTopP()) + .setMin_p(configuration.getMinP()) + .setRepeat_penalty(configuration.getRepeatPenalty()) .build(); } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java index ab51ac2d..8ecb3d36 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java @@ -50,6 +50,10 @@ public class ConfigurationComponent { private final JTextArea commitMessagePromptTextArea; private final IntegerField maxTokensField; private final JBTextField temperatureField; + private final IntegerField topKField; + private final JBTextField topPField; + private final JBTextField minPField; + private final JBTextField repeatPenaltyField; public ConfigurationComponent(Disposable parentDisposable, ConfigurationState configuration) { table = new JBTable(new DefaultTableModel( @@ -68,6 +72,19 @@ public class ConfigurationComponent { temperatureField = new JBTextField(12); temperatureField.setText(String.valueOf(configuration.getTemperature())); + topKField = new IntegerField(); + topKField.setColumns(12); + topKField.setValue(configuration.getTopK()); + + topPField = new JBTextField(12); + topPField.setText(String.valueOf(configuration.getTopP())); + + minPField = new JBTextField(12); + minPField.setText(String.valueOf(configuration.getMinP())); + + repeatPenaltyField = new JBTextField(12); + repeatPenaltyField.setText(String.valueOf(configuration.getRepeatPenalty())); + var temperatureFieldValidator = createInputValidator(parentDisposable, temperatureField); temperatureField.getDocument().addDocumentListener(new DocumentListener() { @Override @@ -131,6 +148,9 @@ public class ConfigurationComponent { CodeGPTBundle.get("configurationConfigurable.section.assistant.title"))) .addComponent(createAssistantConfigurationForm()) .addComponentFillVertically(new JPanel(), 0) + .addComponent(new TitledSeparator( + CodeGPTBundle.get("configurationConfigurable.section.assistant.llamacppParams.title"))) + .addComponent(createLlamaAssistantConfigurationForm()) .addComponent(new TitledSeparator( CodeGPTBundle.get("configurationConfigurable.section.commitMessage.title"))) .addComponent(createCommitMessageConfigurationForm()) @@ -210,6 +230,34 @@ public class ConfigurationComponent { return form; } + private JPanel createLlamaAssistantConfigurationForm() { + var formBuilder = FormBuilder.createFormBuilder(); + addAssistantFormLabeledComponent( + formBuilder, + "configurationConfigurable.section.assistant.topKField.label", + "configurationConfigurable.section.assistant.topKField.comment", + topKField); + addAssistantFormLabeledComponent( + formBuilder, + "configurationConfigurable.section.assistant.topPField.label", + "configurationConfigurable.section.assistant.topPField.comment", + topPField); + addAssistantFormLabeledComponent( + formBuilder, + "configurationConfigurable.section.assistant.minPField.label", + "configurationConfigurable.section.assistant.minPField.comment", + minPField); + addAssistantFormLabeledComponent( + formBuilder, + "configurationConfigurable.section.assistant.repeatPenaltyField.label", + "configurationConfigurable.section.assistant.repeatPenaltyField.comment", + repeatPenaltyField); + + var form = formBuilder.getPanel(); + form.setBorder(JBUI.Borders.emptyLeft(16)); + return form; + } + private JPanel createCommitMessageConfigurationForm() { var formBuilder = FormBuilder.createFormBuilder(); addAssistantFormLabeledComponent( @@ -296,6 +344,38 @@ public class ConfigurationComponent { maxTokensField.setValue(maxTokens); } + public int getTopK() { + return topKField.getValue(); + } + + public void setTopK(int topK) { + topKField.setValue(topK); + } + + public double getTopP() { + return Double.parseDouble(topPField.getText()); + } + + public void setTopP(double topP) { + topPField.setText(String.valueOf(topP)); + } + + public double getMinP() { + return Double.parseDouble(minPField.getText()); + } + + public void setMinP(double minP) { + minPField.setText(String.valueOf(minP)); + } + + public double getRepeatPenalty() { + return Double.parseDouble(repeatPenaltyField.getText()); + } + + public void setRepeatPenalty(double repeatPenalty) { + repeatPenaltyField.setText(String.valueOf(repeatPenalty)); + } + public boolean isCheckForPluginUpdates() { return checkForPluginUpdatesCheckBox.isSelected(); } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationConfigurable.java b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationConfigurable.java index baa62886..0e9e23a6 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationConfigurable.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationConfigurable.java @@ -36,6 +36,10 @@ public class ConfigurationConfigurable implements Configurable { return !configurationComponent.getTableData().equals(configuration.getTableData()) || configurationComponent.getMaxTokens() != configuration.getMaxTokens() || configurationComponent.getTemperature() != configuration.getTemperature() + || configurationComponent.getTopK() != configuration.getTopK() + || configurationComponent.getTopP() != configuration.getTopP() + || configurationComponent.getMinP() != configuration.getMinP() + || configurationComponent.getRepeatPenalty() != configuration.getRepeatPenalty() || !configurationComponent.getSystemPrompt().equals(configuration.getSystemPrompt()) || !configurationComponent.getCommitMessagePrompt() .equals(configuration.getCommitMessagePrompt()) @@ -55,6 +59,10 @@ public class ConfigurationConfigurable implements Configurable { configuration.setTableData(configurationComponent.getTableData()); configuration.setMaxTokens(configurationComponent.getMaxTokens()); configuration.setTemperature(configurationComponent.getTemperature()); + configuration.setTopK(configurationComponent.getTopK()); + configuration.setTopP(configurationComponent.getTopP()); + configuration.setMinP(configurationComponent.getMinP()); + configuration.setRepeatPenalty(configurationComponent.getRepeatPenalty()); configuration.setSystemPrompt(configurationComponent.getSystemPrompt()); configuration.setCommitMessagePrompt(configurationComponent.getCommitMessagePrompt()); configuration.setCheckForPluginUpdates(configurationComponent.isCheckForPluginUpdates()); @@ -72,6 +80,10 @@ public class ConfigurationConfigurable implements Configurable { configurationComponent.setTableData(configuration.getTableData()); configurationComponent.setMaxTokens(configuration.getMaxTokens()); configurationComponent.setTemperature(configuration.getTemperature()); + configurationComponent.setTopK(configuration.getTopK()); + configurationComponent.setTopP(configuration.getTopP()); + configurationComponent.setMinP(configuration.getMinP()); + configurationComponent.setRepeatPenalty(configuration.getRepeatPenalty()); configurationComponent.setSystemPrompt(configuration.getSystemPrompt()); configurationComponent.setCommitMessagePrompt(configuration.getCommitMessagePrompt()); configurationComponent.setCheckForPluginUpdates(configuration.isCheckForPluginUpdates()); diff --git a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationState.java b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationState.java index 00a2d289..2e08c93a 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationState.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationState.java @@ -22,6 +22,10 @@ public class ConfigurationState implements PersistentStateComponent