Support additional command-line params for the server startup process

This commit is contained in:
Carl-Robert Linnupuu 2023-11-26 13:21:02 +02:00
parent 1df20ccb86
commit 01963e2faa
6 changed files with 105 additions and 43 deletions

View file

@ -20,6 +20,7 @@ import ee.carlrobert.codegpt.settings.service.LlamaServiceSelectionForm;
import ee.carlrobert.codegpt.settings.service.ServerProgressPanel;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import java.nio.charset.StandardCharsets;
import java.util.List;
import javax.swing.SwingConstants;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@ -80,14 +81,16 @@ public final class LlamaServerAgent implements Disposable {
serviceSelectionForm.getLlamaModelPreferencesForm().getActualModelPath(),
serviceSelectionForm.getContextSize(),
serviceSelectionForm.getThreads(),
serviceSelectionForm.getServerPort()));
serviceSelectionForm.getServerPort(),
serviceSelectionForm.getListOfAdditionalParameters()));
startServerProcessHandler.addProcessListener(getProcessListener(
serviceSelectionForm.getServerPort(),
serverProgressPanel,
onSuccess));
startServerProcessHandler.startNotify();
} catch (ExecutionException e) {
throw new RuntimeException(e);
} catch (ExecutionException ex) {
LOG.error("Unable to start the server", ex);
throw new RuntimeException(ex);
}
}
};
@ -140,7 +143,8 @@ public final class LlamaServerAgent implements Disposable {
String modelPath,
int contextLength,
int threads,
int port) {
int port,
List<String> additionalParameters) {
GeneralCommandLine commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8);
commandLine.setExePath("./server");
commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath());
@ -149,6 +153,7 @@ public final class LlamaServerAgent implements Disposable {
"-c", String.valueOf(contextLength),
"--port", String.valueOf(port),
"-t", String.valueOf(threads));
commandLine.addParameters(additionalParameters);
commandLine.setRedirectErrorStream(false);
return commandLine;
}

View file

@ -1,5 +1,7 @@
package ee.carlrobert.codegpt.settings.service;
import static java.util.stream.Collectors.toList;
import com.intellij.icons.AllIcons.Actions;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.ui.MessageType;
@ -8,6 +10,7 @@ import com.intellij.openapi.util.io.FileUtil;
import com.intellij.ui.PortField;
import com.intellij.ui.TitledSeparator;
import com.intellij.ui.components.JBLabel;
import com.intellij.ui.components.JBTextField;
import com.intellij.ui.components.fields.IntegerField;
import com.intellij.util.ui.FormBuilder;
import com.intellij.util.ui.JBUI;
@ -19,8 +22,11 @@ import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.util.OverlayUtil;
import java.awt.BorderLayout;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import javax.swing.JButton;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.SwingConstants;
@ -30,6 +36,7 @@ public class LlamaServiceSelectionForm extends JPanel {
private final PortField portField;
private final IntegerField maxTokensField;
private final IntegerField threadsField;
private final JBTextField additionalParametersField;
public LlamaServiceSelectionForm() {
var llamaServerAgent =
@ -40,52 +47,21 @@ public class LlamaServiceSelectionForm extends JPanel {
llamaModelPreferencesForm = new LlamaModelPreferencesForm();
var llamaSettings = LlamaSettingsState.getInstance();
maxTokensField = new IntegerField("max_tokens", 256, 4096);
maxTokensField.setColumns(12);
maxTokensField.setValue(2048);
maxTokensField.setValue(llamaSettings.getContextSize());
maxTokensField.setEnabled(!serverRunning);
threadsField = new IntegerField("threads", 1, 256);
threadsField.setColumns(12);
threadsField.setValue(8);
threadsField.setValue(llamaSettings.getThreads());
threadsField.setEnabled(!serverRunning);
var serverProgressPanel = new ServerProgressPanel();
var serverButton = getServerButton(llamaServerAgent, serverProgressPanel);
var contextSizeHelpText = ComponentPanelBuilder.createCommentComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.contextSize.comment"),
true);
contextSizeHelpText.setBorder(JBUI.Borders.empty(0, 4));
var threadsHelpText = ComponentPanelBuilder.createCommentComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.threads.comment"),
true);
additionalParametersField = new JBTextField(llamaSettings.getAdditionalParameters(), 30);
additionalParametersField.setEnabled(!serverRunning);
setLayout(new BorderLayout());
add(FormBuilder.createFormBuilder()
.addComponent(new TitledSeparator(
CodeGPTBundle.get("settingsConfigurable.service.llama.modelPreferences.title")))
.addComponent(withEmptyLeftBorder(llamaModelPreferencesForm.getForm()))
.addComponent(new TitledSeparator(
CodeGPTBundle.get("settingsConfigurable.service.llama.serverPreferences.title")))
.addComponent(withEmptyLeftBorder(FormBuilder.createFormBuilder()
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.contextSize.label"),
maxTokensField)
.addComponentToRightColumn(contextSizeHelpText)
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.threads.label"),
threadsField)
.addComponentToRightColumn(threadsHelpText)
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.port.label"),
JBUI.Panels.simplePanel()
.addToLeft(portField)
.addToRight(serverButton))
.getPanel()))
.addVerticalGap(4)
.addComponent(withEmptyLeftBorder(serverProgressPanel))
.addComponentFillVertically(new JPanel(), 0)
.getPanel());
init(llamaServerAgent);
}
public void setServerPort(int serverPort) {
@ -121,6 +97,65 @@ public class LlamaServiceSelectionForm extends JPanel {
return threadsField.getValue();
}
/**
 * Populates the additional-parameters text field, e.g. when restoring
 * persisted settings from {@code LlamaSettingsState}.
 *
 * @param additionalParameters comma-separated extra server CLI parameters
 */
public void setAdditionalParameters(String additionalParameters) {
additionalParametersField.setText(additionalParameters);
}
/**
 * Returns the raw, comma-separated additional-parameters string exactly as
 * typed by the user (untrimmed, unsplit).
 */
public String getAdditionalParameters() {
return additionalParametersField.getText();
}
/**
 * Splits the comma-separated additional-parameters field into individual
 * command-line arguments for the server startup command.
 *
 * <p>Blank entries are dropped: without the filter, an empty field would
 * yield {@code [""]} ({@code "".split(",")} returns a single empty string),
 * which would be appended to the command line as a bogus empty argument.
 *
 * @return trimmed, non-empty parameters; an empty list when the field is blank
 */
public List<String> getListOfAdditionalParameters() {
var parameters = additionalParametersField.getText().split(",");
return Arrays.stream(parameters)
    .map(String::trim)
    .filter(parameter -> !parameter.isEmpty())
    .collect(toList());
}
/**
 * Builds the form layout: model preferences section on top, followed by the
 * server preferences section (context size, threads, additional parameters,
 * port + start/stop button) and the server progress panel.
 *
 * <p>The FormBuilder calls are order-sensitive — each addComponent /
 * addLabeledComponent call appends a row in sequence.
 *
 * @param llamaServerAgent agent used by the start/stop server button
 */
private void init(LlamaServerAgent llamaServerAgent) {
var serverProgressPanel = new ServerProgressPanel();
setLayout(new BorderLayout());
add(FormBuilder.createFormBuilder()
.addComponent(new TitledSeparator(
CodeGPTBundle.get("settingsConfigurable.service.llama.modelPreferences.title")))
.addComponent(withEmptyLeftBorder(llamaModelPreferencesForm.getForm()))
.addComponent(new TitledSeparator(
CodeGPTBundle.get("settingsConfigurable.service.llama.serverPreferences.title")))
// Nested builder keeps the server-preference rows in their own grid so the
// labels align independently of the model-preferences form above.
.addComponent(withEmptyLeftBorder(FormBuilder.createFormBuilder()
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.contextSize.label"),
maxTokensField)
.addComponentToRightColumn(
createComment("settingsConfigurable.service.llama.contextSize.comment"))
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.threads.label"),
threadsField)
.addComponentToRightColumn(
createComment("settingsConfigurable.service.llama.threads.comment"))
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.additionalParameters.label"),
additionalParametersField)
.addComponentToRightColumn(
createComment("settingsConfigurable.service.llama.additionalParameters.comment"))
// Port field and server button share one row: field left, button right.
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.port.label"),
JBUI.Panels.simplePanel()
.addToLeft(portField)
.addToRight(getServerButton(llamaServerAgent, serverProgressPanel)))
.getPanel()))
.addVerticalGap(4)
.addComponent(withEmptyLeftBorder(serverProgressPanel))
// Filler panel absorbs extra vertical space so rows stay pinned to the top.
.addComponentFillVertically(new JPanel(), 0)
.getPanel());
}
/**
 * Creates a small help-text label for the given resource-bundle key, indented
 * slightly so it lines up under the field it describes.
 *
 * @param messageKey bundle key of the comment text
 * @return the configured comment label
 */
private JLabel createComment(String messageKey) {
JLabel helpLabel =
    ComponentPanelBuilder.createCommentComponent(CodeGPTBundle.get(messageKey), true);
helpLabel.setBorder(JBUI.Borders.empty(0, 4));
return helpLabel;
}
private JButton getServerButton(
LlamaServerAgent llamaServerAgent,
ServerProgressPanel serverProgressPanel) {
@ -208,5 +243,6 @@ public class LlamaServiceSelectionForm extends JPanel {
portField.setEnabled(enabled);
maxTokensField.setEnabled(enabled);
threadsField.setEnabled(enabled);
additionalParametersField.setEnabled(enabled);
}
}

View file

@ -420,4 +420,12 @@ public class ServiceSelectionForm {
/**
 * Delegates the thread-count setting to the llama service section panel.
 *
 * @param threads number of threads for model execution
 */
public void setThreads(int threads) {
llamaServiceSectionPanel.setThreads(threads);
}
/**
 * Returns the comma-separated additional server parameters from the llama
 * service section panel.
 */
public String getAdditionalParameters() {
return llamaServiceSectionPanel.getAdditionalParameters();
}
/**
 * Forwards the persisted additional-parameters string to the llama service
 * section panel.
 *
 * @param additionalParameters comma-separated extra server CLI parameters
 */
public void setAdditionalParameters(String additionalParameters) {
llamaServiceSectionPanel.setAdditionalParameters(additionalParameters);
}
}

View file

@ -21,8 +21,8 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
private PromptTemplate promptTemplate = PromptTemplate.LLAMA;
private Integer serverPort = getRandomAvailablePortOrDefault();
private int contextSize = 2048;
private int threads = 8;
private String additionalParameters = "";
public LlamaSettingsState() {
}
@ -46,6 +46,7 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
return serverPort != serviceSelectionForm.getLlamaServerPort()
|| contextSize != serviceSelectionForm.getContextSize()
|| threads != serviceSelectionForm.getThreads()
|| !additionalParameters.equals(serviceSelectionForm.getAdditionalParameters())
|| huggingFaceModel != modelPreferencesForm.getSelectedModel()
|| !promptTemplate.equals(modelPreferencesForm.getPromptTemplate())
|| useCustomModel != modelPreferencesForm.isUseCustomLlamaModel()
@ -61,6 +62,7 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
serverPort = serviceSelectionForm.getLlamaServerPort();
contextSize = serviceSelectionForm.getContextSize();
threads = serviceSelectionForm.getThreads();
additionalParameters = serviceSelectionForm.getAdditionalParameters();
}
public void reset(ServiceSelectionForm serviceSelectionForm) {
@ -72,6 +74,7 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
serviceSelectionForm.setLlamaServerPort(serverPort);
serviceSelectionForm.setContextSize(contextSize);
serviceSelectionForm.setThreads(threads);
serviceSelectionForm.setAdditionalParameters(additionalParameters);
}
public boolean isUseCustomModel() {
@ -130,6 +133,14 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
this.threads = threads;
}
// Comma-separated extra server CLI parameters persisted with the settings state.
public String getAdditionalParameters() {
return additionalParameters;
}
// Setter required by the PersistentStateComponent XML serializer.
public void setAdditionalParameters(String additionalParameters) {
this.additionalParameters = additionalParameters;
}
private static Integer getRandomAvailablePortOrDefault() {
try (ServerSocket socket = new ServerSocket(0)) {
return socket.getLocalPort();

View file

@ -20,7 +20,7 @@ public class UserPromptTextAreaHeader extends JPanel {
Runnable onAddNewTab) {
super(new BorderLayout());
setOpaque(false);
setBorder(JBUI.Borders.emptyBottom(4));
setBorder(JBUI.Borders.emptyBottom(8));
switch (selectedService) {
case OPENAI:
case AZURE:

View file

@ -43,6 +43,8 @@ settingsConfigurable.service.llama.contextSize.label=Prompt context size:
settingsConfigurable.service.llama.contextSize.comment=The size of the prompt context. LLaMA models were built with a context of 2048, which will provide better results for longer input/inference
settingsConfigurable.service.llama.threads.label=Threads:
settingsConfigurable.service.llama.threads.comment=The number of threads available to execute the model. It is not recommended to specify a number greater than the number of processor cores.
settingsConfigurable.service.llama.additionalParameters.label=Additional parameters:
settingsConfigurable.service.llama.additionalParameters.comment=<html>Additional command-line parameters for the server startup process, separated by commas. See the full <a href="https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md">list of options</a><p>Example: "--n-gpu-layers, 1, --no-mmap, --mlock"</p></html>
settingsConfigurable.service.llama.port.label=Port:
settingsConfigurable.service.llama.startServer.label=Start server
settingsConfigurable.service.llama.stopServer.label=Stop server