feat: add field for environment variables for Llama server (#550)

Co-authored-by: Carl-Robert <carlrobertoh@gmail.com>
This commit is contained in:
Phil 2024-05-23 11:55:51 +02:00 committed by GitHub
parent ee6b2d3350
commit 08b592f7e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 65 additions and 10 deletions

View file

@ -50,7 +50,7 @@ public final class LlamaServerAgent implements Disposable {
serverProgressPanel.displayText(
CodeGPTBundle.get("llamaServerAgent.buildingProject.description"));
makeProcessHandler = new OSProcessHandler(
getMakeCommandLine(params.additionalBuildParameters()));
getMakeCommandLine(params));
makeProcessHandler.addProcessListener(
getMakeProcessListener(params, onSuccess, onServerStopped));
makeProcessHandler.startNotify();
@ -177,12 +177,13 @@ public final class LlamaServerAgent implements Disposable {
OverlayUtil.showClosableBalloon(errorText, MessageType.ERROR, activeServerProgressPanel);
}
private static GeneralCommandLine getMakeCommandLine(List<String> additionalCompileParameters) {
private static GeneralCommandLine getMakeCommandLine(LlamaServerStartupParams params) {
GeneralCommandLine commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8);
commandLine.setExePath("make");
commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath());
commandLine.addParameters("-j");
commandLine.addParameters(additionalCompileParameters);
commandLine.addParameters(params.additionalBuildParameters());
commandLine.withEnvironment(params.additionalEnvironmentVariables());
commandLine.setRedirectErrorStream(false);
return commandLine;
}
@ -197,6 +198,7 @@ public final class LlamaServerAgent implements Disposable {
"--port", String.valueOf(params.port()),
"-t", String.valueOf(params.threads()));
commandLine.addParameters(params.additionalRunParameters());
commandLine.withEnvironment(params.additionalEnvironmentVariables());
commandLine.setRedirectErrorStream(false);
return commandLine;
}

View file

@ -1,8 +1,10 @@
package ee.carlrobert.codegpt.completions.llama;
import java.util.List;
import java.util.Map;
public record LlamaServerStartupParams(String modelPath, int contextLength, int threads, int port,
List<String> additionalRunParameters,
List<String> additionalBuildParameters) {
List<String> additionalBuildParameters,
Map<String, String> additionalEnvironmentVariables) {
}

View file

@ -20,6 +20,8 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
@ -106,4 +108,13 @@ public class LlamaSettings implements PersistentStateComponent<LlamaSettingsStat
.toList();
}
public static Map<String, String> getAdditionalEnvironmentVariablesMap(
String additionalEnvironmentVariables) {
return Arrays.stream(additionalEnvironmentVariables.split(" "))
.map(String::trim)
.filter(s -> !s.isBlank() && s.contains("="))
.collect(Collectors.toMap(item -> item.split("=")[0].trim(),
item -> item.split("=")[1].trim()));
}
}

View file

@ -24,6 +24,7 @@ public class LlamaSettingsState {
private int threads = 8;
private String additionalParameters = "";
private String additionalBuildParameters = "";
private String additionalEnvironmentVariables = "";
private int topK = 40;
private double topP = 0.9;
private double minP = 0.05;
@ -146,6 +147,14 @@ public class LlamaSettingsState {
this.additionalBuildParameters = additionalBuildParameters;
}
public String getAdditionalEnvironmentVariables() {
return additionalEnvironmentVariables;
}
public void setAdditionalEnvironmentVariables(String additionalEnvironmentVariables) {
this.additionalEnvironmentVariables = additionalEnvironmentVariables;
}
public int getTopK() {
return topK;
}
@ -221,6 +230,7 @@ public class LlamaSettingsState {
&& Objects.equals(serverPort, that.serverPort)
&& Objects.equals(additionalParameters, that.additionalParameters)
&& Objects.equals(additionalBuildParameters, that.additionalBuildParameters)
&& Objects.equals(additionalEnvironmentVariables, that.additionalEnvironmentVariables)
&& codeCompletionsEnabled == that.codeCompletionsEnabled;
}
@ -229,7 +239,8 @@ public class LlamaSettingsState {
return Objects.hash(runLocalServer, useCustomModel, customLlamaModelPath, huggingFaceModel,
localModelPromptTemplate, remoteModelPromptTemplate, localModelInfillPromptTemplate,
remoteModelInfillPromptTemplate, baseHost, serverPort, contextSize, threads,
additionalParameters, additionalBuildParameters, topK, topP, minP, repeatPenalty,
additionalParameters, additionalBuildParameters, additionalEnvironmentVariables, topK, topP,
minP, repeatPenalty,
codeCompletionsEnabled);
}
}

View file

@ -58,6 +58,7 @@ public class LlamaServerPreferencesForm {
private final IntegerField threadsField;
private final JBTextField additionalParametersField;
private final JBTextField additionalBuildParametersField;
private final JBTextField additionalEnvironmentVariablesField;
private final ChatPromptTemplatePanel remotePromptTemplatePanel;
private final InfillPromptTemplatePanel infillPromptTemplatePanel;
@ -83,6 +84,10 @@ public class LlamaServerPreferencesForm {
additionalBuildParametersField = new JBTextField(settings.getAdditionalBuildParameters(), 30);
additionalBuildParametersField.setEnabled(!serverRunning);
additionalEnvironmentVariablesField = new JBTextField(
settings.getAdditionalEnvironmentVariables(), 30);
additionalEnvironmentVariablesField.setEnabled(!serverRunning);
baseHostField = new URLTextField(settings.getBaseHost(), 30);
apiKeyField = new JBPasswordField();
apiKeyField.setColumns(30);
@ -132,6 +137,7 @@ public class LlamaServerPreferencesForm {
threadsField.setValue(state.getThreads());
additionalParametersField.setText(state.getAdditionalParameters());
additionalBuildParametersField.setText(state.getAdditionalBuildParameters());
additionalEnvironmentVariablesField.setText(state.getAdditionalEnvironmentVariables());
remotePromptTemplatePanel.setPromptTemplate(state.getRemoteModelPromptTemplate()); // ?
infillPromptTemplatePanel.setPromptTemplate(state.getRemoteModelInfillPromptTemplate());
apiKeyField.setText(CredentialsStore.getCredential(LLAMA_API_KEY));
@ -204,6 +210,14 @@ public class LlamaServerPreferencesForm {
createComment(
"settingsConfigurable.service.llama.additionalBuildParameters.comment"))
.addVerticalGap(4)
.addLabeledComponent(
CodeGPTBundle.get(
"settingsConfigurable.service.llama.additionalEnvironmentVariables.label"),
additionalEnvironmentVariablesField)
.addComponentToRightColumn(
createComment(
"settingsConfigurable.service.llama.additionalEnvironmentVariables.comment"))
.addVerticalGap(4)
.addComponentFillVertically(new JPanel(), 0)
.getPanel()))
.getPanel());
@ -236,7 +250,8 @@ public class LlamaServerPreferencesForm {
getThreads(),
getServerPort(),
getListOfAdditionalParameters(),
getListOfAdditionalBuildParameters()
getListOfAdditionalBuildParameters(),
getMapOfAdditionalEnvironmentVariables()
),
serverProgressPanel,
() -> {
@ -316,6 +331,7 @@ public class LlamaServerPreferencesForm {
threadsField.setEnabled(enabled);
additionalParametersField.setEnabled(enabled);
additionalBuildParametersField.setEnabled(enabled);
additionalEnvironmentVariablesField.setEnabled(enabled);
}
public boolean isRunLocalServer() {
@ -362,6 +378,15 @@ public class LlamaServerPreferencesForm {
return LlamaSettings.getAdditionalParametersList(additionalBuildParametersField.getText());
}
public String getAdditionalEnvironmentVariables() {
return additionalEnvironmentVariablesField.getText();
}
public Map<String, String> getMapOfAdditionalEnvironmentVariables() {
return LlamaSettings.getAdditionalEnvironmentVariablesMap(
additionalEnvironmentVariablesField.getText());
}
public PromptTemplate getPromptTemplate() {
return isRunLocalServer() ? llamaModelPreferencesForm.getPromptTemplate()
: remotePromptTemplatePanel.getPromptTemplate();

View file

@ -42,6 +42,8 @@ public class LlamaSettingsForm extends JPanel {
state.setThreads(llamaServerPreferencesForm.getThreads());
state.setAdditionalParameters(llamaServerPreferencesForm.getAdditionalParameters());
state.setAdditionalBuildParameters(llamaServerPreferencesForm.getAdditionalBuildParameters());
state.setAdditionalEnvironmentVariables(
llamaServerPreferencesForm.getAdditionalEnvironmentVariables());
var modelPreferencesForm = llamaServerPreferencesForm.getLlamaModelPreferencesForm();
state.setCustomLlamaModelPath(modelPreferencesForm.getCustomLlamaModelPath());