feat: add input field for llama server build parameters and improve error handling (#481)

This commit is contained in:
Phil 2024-04-20 22:18:43 +02:00 committed by GitHub
parent 67dc425a94
commit c8181a62e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 138 additions and 49 deletions

View file

@ -14,14 +14,17 @@ import com.intellij.openapi.Disposable;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.ui.MessageType;
import com.intellij.openapi.util.Key;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.CodeGPTPlugin;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.llama.form.ServerProgressPanel;
import ee.carlrobert.codegpt.ui.OverlayUtil;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@ -32,65 +35,94 @@ public final class LlamaServerAgent implements Disposable {
private @Nullable OSProcessHandler makeProcessHandler;
private @Nullable OSProcessHandler startServerProcessHandler;
private ServerProgressPanel activeServerProgressPanel;
private boolean stoppedByUser;
public void startAgent(
LlamaServerStartupParams params,
ServerProgressPanel serverProgressPanel,
Runnable onSuccess,
Runnable onServerTerminated) {
Consumer<ServerProgressPanel> onServerTerminated) {
this.activeServerProgressPanel = serverProgressPanel;
ApplicationManager.getApplication().invokeLater(() -> {
try {
serverProgressPanel.updateText(
stoppedByUser = false;
serverProgressPanel.displayText(
CodeGPTBundle.get("llamaServerAgent.buildingProject.description"));
makeProcessHandler = new OSProcessHandler(getMakeCommandLinde());
makeProcessHandler = new OSProcessHandler(
getMakeCommandLine(params.additionalBuildParameters()));
makeProcessHandler.addProcessListener(
getMakeProcessListener(params, serverProgressPanel, onSuccess, onServerTerminated));
getMakeProcessListener(params, onSuccess, onServerTerminated));
makeProcessHandler.startNotify();
} catch (ExecutionException e) {
throw new RuntimeException(e);
showServerError(e.getMessage(), onServerTerminated);
}
});
}
public void stopAgent() {
stoppedByUser = true;
if (makeProcessHandler != null) {
makeProcessHandler.destroyProcess();
}
if (startServerProcessHandler != null) {
startServerProcessHandler.destroyProcess();
}
}
public boolean isServerRunning() {
return startServerProcessHandler != null
return (makeProcessHandler != null
&& makeProcessHandler.isStartNotified()
&& !makeProcessHandler.isProcessTerminated())
|| (startServerProcessHandler != null
&& startServerProcessHandler.isStartNotified()
&& !startServerProcessHandler.isProcessTerminated();
&& !startServerProcessHandler.isProcessTerminated());
}
private ProcessListener getMakeProcessListener(
LlamaServerStartupParams params,
ServerProgressPanel serverProgressPanel,
Runnable onSuccess,
Runnable onServerTerminated) {
Consumer<ServerProgressPanel> onServerTerminated) {
LOG.info("Building llama project");
return new ProcessAdapter() {
private final List<String> errorLines = new CopyOnWriteArrayList<>();
@Override
public void onTextAvailable(@NotNull ProcessEvent event, @NotNull Key outputType) {
if (ProcessOutputType.isStderr(outputType)) {
errorLines.add(event.getText());
return;
}
LOG.info(event.getText());
}
@Override
public void processTerminated(@NotNull ProcessEvent event) {
int exitCode = event.getExitCode();
LOG.info(format("Server build exited with code %d", exitCode));
if (stoppedByUser) {
onServerTerminated.accept(activeServerProgressPanel);
return;
}
if (exitCode != 0) {
showServerError(String.join(",", errorLines), onServerTerminated);
return;
}
try {
LOG.info("Booting up llama server");
serverProgressPanel.updateText(
activeServerProgressPanel.displayText(
CodeGPTBundle.get("llamaServerAgent.serverBootup.description"));
startServerProcessHandler = new OSProcessHandler.Silent(getServerCommandLine(params));
startServerProcessHandler.addProcessListener(
getProcessListener(params.port(), onSuccess, onServerTerminated));
getProcessListener(params.port(), onSuccess,
onServerTerminated));
startServerProcessHandler.startNotify();
} catch (ExecutionException ex) {
LOG.error("Unable to start llama server", ex);
throw new RuntimeException(ex);
showServerError(ex.getMessage(), onServerTerminated);
}
}
};
@ -99,27 +131,25 @@ public final class LlamaServerAgent implements Disposable {
private ProcessListener getProcessListener(
int port,
Runnable onSuccess,
Runnable onServerTerminated) {
Consumer<ServerProgressPanel> onServerTerminated) {
return new ProcessAdapter() {
private final ObjectMapper objectMapper = new ObjectMapper();
private final List<String> errorLines = new CopyOnWriteArrayList<>();
@Override
public void processTerminated(@NotNull ProcessEvent event) {
if (errorLines.isEmpty()) {
LOG.info(format("Server terminated with code %d", event.getExitCode()));
LOG.info(format("Server terminated with code %d", event.getExitCode()));
if (stoppedByUser) {
onServerTerminated.accept(activeServerProgressPanel);
} else {
LOG.info(String.join("", errorLines));
showServerError(String.join(",", errorLines), onServerTerminated);
}
onServerTerminated.run();
}
@Override
public void onTextAvailable(@NotNull ProcessEvent event, @NotNull Key outputType) {
if (ProcessOutputType.isStderr(outputType)) {
errorLines.add(event.getText());
return;
}
if (ProcessOutputType.isStdout(outputType)) {
@ -141,11 +171,18 @@ public final class LlamaServerAgent implements Disposable {
};
}
private static GeneralCommandLine getMakeCommandLinde() {
private void showServerError(String errorText, Consumer<ServerProgressPanel> onServerTerminated) {
onServerTerminated.accept(activeServerProgressPanel);
LOG.info("Unable to start llama server:\n" + errorText);
OverlayUtil.showClosableBalloon(errorText, MessageType.ERROR, activeServerProgressPanel);
}
private static GeneralCommandLine getMakeCommandLine(List<String> additionalCompileParameters) {
GeneralCommandLine commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8);
commandLine.setExePath("make");
commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath());
commandLine.addParameters("-j");
commandLine.addParameters(additionalCompileParameters);
commandLine.setRedirectErrorStream(false);
return commandLine;
}
@ -159,11 +196,16 @@ public final class LlamaServerAgent implements Disposable {
"-c", String.valueOf(params.contextLength()),
"--port", String.valueOf(params.port()),
"-t", String.valueOf(params.threads()));
commandLine.addParameters(params.additionalParameters());
commandLine.addParameters(params.additionalRunParameters());
commandLine.setRedirectErrorStream(false);
return commandLine;
}
public void setActiveServerProgressPanel(
ServerProgressPanel activeServerProgressPanel) {
this.activeServerProgressPanel = activeServerProgressPanel;
}
@Override
public void dispose() {
if (makeProcessHandler != null && !makeProcessHandler.isProcessTerminated()) {

View file

@ -3,5 +3,6 @@ package ee.carlrobert.codegpt.completions.llama;
import java.util.List;
public record LlamaServerStartupParams(String modelPath, int contextLength, int threads, int port,
List<String> additionalParameters) {
List<String> additionalRunParameters,
List<String> additionalBuildParameters) {
}

View file

@ -23,6 +23,7 @@ public class LlamaSettingsState {
private int contextSize = 2048;
private int threads = 8;
private String additionalParameters = "";
private String additionalBuildParameters = "";
private int topK = 40;
private double topP = 0.9;
private double minP = 0.05;
@ -138,6 +139,14 @@ public class LlamaSettingsState {
this.additionalParameters = additionalParameters;
}
public String getAdditionalBuildParameters() {
return additionalBuildParameters;
}
public void setAdditionalBuildParameters(String additionalBuildParameters) {
this.additionalBuildParameters = additionalBuildParameters;
}
public int getTopK() {
return topK;
}
@ -220,6 +229,7 @@ public class LlamaSettingsState {
&& Objects.equals(baseHost, that.baseHost)
&& Objects.equals(serverPort, that.serverPort)
&& Objects.equals(additionalParameters, that.additionalParameters)
&& Objects.equals(additionalBuildParameters, that.additionalBuildParameters)
&& codeCompletionsEnabled == that.codeCompletionsEnabled
&& codeCompletionMaxTokens == that.codeCompletionMaxTokens;
}
@ -229,7 +239,7 @@ public class LlamaSettingsState {
return Objects.hash(runLocalServer, useCustomModel, customLlamaModelPath, huggingFaceModel,
localModelPromptTemplate, remoteModelPromptTemplate, localModelInfillPromptTemplate,
remoteModelInfillPromptTemplate, baseHost, serverPort, contextSize, threads,
additionalParameters, topK, topP, minP, repeatPenalty, codeCompletionsEnabled,
codeCompletionMaxTokens);
additionalParameters, additionalBuildParameters, topK, topP, minP, repeatPenalty,
codeCompletionsEnabled, codeCompletionMaxTokens);
}
}

View file

@ -57,6 +57,7 @@ public class LlamaServerPreferencesForm {
private final IntegerField maxTokensField;
private final IntegerField threadsField;
private final JBTextField additionalParametersField;
private final JBTextField additionalBuildParametersField;
private final ChatPromptTemplatePanel remotePromptTemplatePanel;
private final InfillPromptTemplatePanel infillPromptTemplatePanel;
@ -79,6 +80,9 @@ public class LlamaServerPreferencesForm {
additionalParametersField = new JBTextField(settings.getAdditionalParameters(), 30);
additionalParametersField.setEnabled(!serverRunning);
additionalBuildParametersField = new JBTextField(settings.getAdditionalBuildParameters(), 30);
additionalBuildParametersField.setEnabled(!serverRunning);
baseHostField = new JBTextField(settings.getBaseHost(), 30);
apiKeyField = new JBPasswordField();
apiKeyField.setColumns(30);
@ -124,6 +128,7 @@ public class LlamaServerPreferencesForm {
maxTokensField.setValue(state.getContextSize());
threadsField.setValue(state.getThreads());
additionalParametersField.setText(state.getAdditionalParameters());
additionalBuildParametersField.setText(state.getAdditionalBuildParameters());
remotePromptTemplatePanel.setPromptTemplate(state.getRemoteModelPromptTemplate()); // ?
infillPromptTemplatePanel.setPromptTemplate(state.getRemoteModelInfillPromptTemplate());
apiKeyField.setText(CredentialsStore.INSTANCE.getCredential(LLAMA_API_KEY));
@ -184,9 +189,17 @@ public class LlamaServerPreferencesForm {
createComment("settingsConfigurable.service.llama.threads.comment"))
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.additionalParameters.label"),
additionalParametersField)
.addComponentToRightColumn(
createComment("settingsConfigurable.service.llama.additionalParameters.comment"))
additionalParametersField)
.addComponentToRightColumn(
createComment(
"settingsConfigurable.service.llama.additionalParameters.comment"))
.addLabeledComponent(
CodeGPTBundle.get(
"settingsConfigurable.service.llama.additionalBuildParameters.label"),
additionalBuildParametersField)
.addComponentToRightColumn(
createComment(
"settingsConfigurable.service.llama.additionalBuildParameters.comment"))
.addVerticalGap(4)
.addComponentFillVertically(new JPanel(), 0)
.getPanel()))
@ -196,6 +209,7 @@ public class LlamaServerPreferencesForm {
private JButton getServerButton(
LlamaServerAgent llamaServerAgent,
ServerProgressPanel serverProgressPanel) {
llamaServerAgent.setActiveServerProgressPanel(serverProgressPanel);
var serverRunning = llamaServerAgent.isServerRunning();
var serverButton = new JButton();
serverButton.setText(serverRunning
@ -218,7 +232,9 @@ public class LlamaServerPreferencesForm {
getContextSize(),
getThreads(),
getServerPort(),
getListOfAdditionalParameters()),
getListOfAdditionalParameters(),
getListOfAdditionalBuildParameters()
),
serverProgressPanel,
() -> {
setFormEnabled(false);
@ -227,12 +243,12 @@ public class LlamaServerPreferencesForm {
Actions.Checked,
SwingConstants.LEADING));
},
() -> {
(activeServerProgressPanel) -> {
setFormEnabled(true);
serverButton.setText(
CodeGPTBundle.get("settingsConfigurable.service.llama.startServer.label"));
serverButton.setIcon(Actions.Execute);
serverProgressPanel.displayComponent(new JBLabel(
activeServerProgressPanel.displayComponent(new JBLabel(
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverTerminated"),
Actions.Cancel,
SwingConstants.LEADING));
@ -282,7 +298,7 @@ public class LlamaServerPreferencesForm {
serverButton.setText(
CodeGPTBundle.get("settingsConfigurable.service.llama.startServer.label"));
serverButton.setIcon(Actions.Execute);
progressPanel.updateText(
progressPanel.displayText(
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.stoppingServer"));
}
@ -291,7 +307,7 @@ public class LlamaServerPreferencesForm {
serverButton.setText(
CodeGPTBundle.get("settingsConfigurable.service.llama.stopServer.label"));
serverButton.setIcon(Actions.Suspend);
progressPanel.startProgress(
progressPanel.displayText(
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.startingServer"));
}
@ -301,6 +317,7 @@ public class LlamaServerPreferencesForm {
maxTokensField.setEnabled(enabled);
threadsField.setEnabled(enabled);
additionalParametersField.setEnabled(enabled);
additionalBuildParametersField.setEnabled(enabled);
}
public boolean isRunLocalServer() {
@ -337,9 +354,20 @@ public class LlamaServerPreferencesForm {
public List<String> getListOfAdditionalParameters() {
return Arrays.stream(additionalParametersField.getText().split(","))
.map(String::trim)
.filter(s -> !s.isBlank())
.toList();
.map(String::trim)
.filter(s -> !s.isBlank())
.toList();
}
public String getAdditionalBuildParameters() {
return additionalBuildParametersField.getText();
}
public List<String> getListOfAdditionalBuildParameters() {
return Arrays.stream(additionalBuildParametersField.getText().split(","))
.map(String::trim)
.filter(s -> !s.isBlank())
.toList();
}
public PromptTemplate getPromptTemplate() {

View file

@ -41,6 +41,7 @@ public class LlamaSettingsForm extends JPanel {
state.setContextSize(llamaServerPreferencesForm.getContextSize());
state.setThreads(llamaServerPreferencesForm.getThreads());
state.setAdditionalParameters(llamaServerPreferencesForm.getAdditionalParameters());
state.setAdditionalBuildParameters(llamaServerPreferencesForm.getAdditionalBuildParameters());
var modelPreferencesForm = llamaServerPreferencesForm.getLlamaModelPreferencesForm();
state.setCustomLlamaModelPath(modelPreferencesForm.getCustomLlamaModelPath());

View file

@ -8,20 +8,15 @@ import javax.swing.JPanel;
public class ServerProgressPanel extends JPanel {
private final JBLabel label = new JBLabel();
private final AsyncProcessIcon loadingSpinner = new AsyncProcessIcon("sign_in_spinner");
public ServerProgressPanel() {
setVisible(false);
add(new AsyncProcessIcon("sign_in_spinner"));
add(label);
}
public void startProgress(String text) {
setVisible(true);
updateText(text);
}
public void updateText(String text) {
public void displayText(String text) {
label.setText(text);
removeAll();
add(loadingSpinner);
add(label);
revalidate();
repaint();
}
public void displayComponent(JComponent component) {

View file

@ -149,4 +149,13 @@ public class OverlayUtil {
.createBalloon()
.show(RelativePoint.getSouthOf(component), Position.below);
}
public static void showClosableBalloon(String content, MessageType messageType,
JComponent component) {
JBPopupFactory.getInstance()
.createHtmlTextBalloonBuilder(content, messageType, null)
.setCloseButtonEnabled(true)
.createBalloon()
.show(RelativePoint.getSouthOf(component), Position.below);
}
}