Added option to set the number of threads for local LLM models (#282)

* Added option to set the number of threads for local LLM models

* Refactoring

---------

Co-authored-by: Viktor <viktor.hoshyi@gg4l.com>
This commit is contained in:
Viktor 2023-11-20 22:36:19 +00:00 committed by GitHub
parent fc6d085b61
commit 73870cca40
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 63 additions and 4 deletions

View file

@ -34,6 +34,7 @@ public final class LlamaServerAgent implements Disposable {
public void startAgent(
String modelPath,
int contextLength,
int threads,
int port,
ServerProgressPanel serverProgressPanel,
Runnable onSuccess) {
@ -42,7 +43,13 @@ public final class LlamaServerAgent implements Disposable {
serverProgressPanel.updateText("Building llama.cpp...");
makeProcessHandler = new OSProcessHandler(getMakeCommandLinde());
makeProcessHandler.addProcessListener(
getMakeProcessListener(modelPath, contextLength, port, serverProgressPanel, onSuccess));
getMakeProcessListener(
modelPath,
contextLength,
threads,
port,
serverProgressPanel,
onSuccess));
makeProcessHandler.startNotify();
} catch (ExecutionException e) {
throw new RuntimeException(e);
@ -65,6 +72,7 @@ public final class LlamaServerAgent implements Disposable {
private ProcessListener getMakeProcessListener(
String modelPath,
int contextLength,
int threads,
int port,
ServerProgressPanel serverProgressPanel,
Runnable onSuccess) {
@ -79,7 +87,7 @@ public final class LlamaServerAgent implements Disposable {
try {
serverProgressPanel.updateText("Booting up server...");
startServerProcessHandler = new OSProcessHandler(
getServerCommandLine(modelPath, contextLength, port));
getServerCommandLine(modelPath, contextLength, threads, port));
startServerProcessHandler.addProcessListener(
getProcessListener(port, serverProgressPanel, onSuccess));
startServerProcessHandler.startNotify();
@ -133,14 +141,19 @@ public final class LlamaServerAgent implements Disposable {
return commandLine;
}
private GeneralCommandLine getServerCommandLine(String modelPath, int contextLength, int port) {
private GeneralCommandLine getServerCommandLine(
String modelPath,
int contextLength,
int threads,
int port) {
GeneralCommandLine commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8);
commandLine.setExePath("./server");
commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath());
commandLine.addParameters(
"-m", modelPath,
"-c", String.valueOf(contextLength),
"--port", String.valueOf(port));
"--port", String.valueOf(port),
"-t", String.valueOf(threads));
commandLine.setRedirectErrorStream(false);
return commandLine;
}

View file

@ -64,6 +64,7 @@ public class SettingsConfigurable implements Configurable {
|| llamaSettings.isUseCustomModel() != llamaModelPreferencesForm.isUseCustomLlamaModel()
|| llamaSettings.getServerPort() != serviceSelectionForm.getLlamaServerPort()
|| llamaSettings.getContextSize() != serviceSelectionForm.getContextSize()
|| llamaSettings.getThreads() != serviceSelectionForm.getThreads()
|| llamaSettings.getHuggingFaceModel() != llamaModelPreferencesForm.getSelectedModel()
|| !llamaSettings.getPromptTemplate().equals(llamaModelPreferencesForm.getPromptTemplate())
|| !llamaSettings.getCustomLlamaModelPath()
@ -96,6 +97,7 @@ public class SettingsConfigurable implements Configurable {
llamaSettings.setPromptTemplate(llamaModelPreferencesForm.getPromptTemplate());
llamaSettings.setServerPort(serviceSelectionForm.getLlamaServerPort());
llamaSettings.setContextSize(serviceSelectionForm.getContextSize());
llamaSettings.setThreads(serviceSelectionForm.getThreads());
var azureSettings = AzureSettingsState.getInstance();
var openAISettings = OpenAISettingsState.getInstance();
@ -133,6 +135,7 @@ public class SettingsConfigurable implements Configurable {
llamaModelPreferencesForm.setPromptTemplate(llamaSettings.getPromptTemplate());
serviceSelectionForm.setLlamaServerPort(llamaSettings.getServerPort());
serviceSelectionForm.setContextSize(llamaSettings.getContextSize());
serviceSelectionForm.setThreads(llamaSettings.getThreads());
OpenAISettingsState.getInstance().reset(serviceSelectionForm);
AzureSettingsState.getInstance().reset(serviceSelectionForm);

View file

@ -29,6 +29,7 @@ public class LlamaServiceSelectionForm extends JPanel {
private final LlamaModelPreferencesForm llamaModelPreferencesForm;
private final PortField portField;
private final IntegerField maxTokensField;
private final IntegerField threadsField;
public LlamaServiceSelectionForm() {
var llamaServerAgent =
@ -44,12 +45,20 @@ public class LlamaServiceSelectionForm extends JPanel {
maxTokensField.setValue(2048);
maxTokensField.setEnabled(!serverRunning);
threadsField = new IntegerField("threads", 1, 256);
threadsField.setColumns(12);
threadsField.setValue(8);
threadsField.setEnabled(!serverRunning);
var serverProgressPanel = new ServerProgressPanel();
var serverButton = getServerButton(serverRunning, llamaServerAgent, serverProgressPanel);
var contextSizeHelpText = ComponentPanelBuilder.createCommentComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.contextSize.comment"),
true);
contextSizeHelpText.setBorder(JBUI.Borders.empty(0, 4));
var threadsHelpText = ComponentPanelBuilder.createCommentComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.threads.comment"),
true);
setLayout(new BorderLayout());
add(FormBuilder.createFormBuilder()
@ -63,6 +72,10 @@ public class LlamaServiceSelectionForm extends JPanel {
CodeGPTBundle.get("settingsConfigurable.service.llama.contextSize.label"),
maxTokensField)
.addComponentToRightColumn(contextSizeHelpText)
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.threads.label"),
threadsField)
.addComponentToRightColumn(threadsHelpText)
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.port.label"),
JBUI.Panels.simplePanel()
@ -100,6 +113,14 @@ public class LlamaServiceSelectionForm extends JPanel {
maxTokensField.setValue(contextSize);
}
/**
 * Puts the given thread count into the form's threads input field.
 *
 * @param threads value to display in {@code threadsField}
 */
public void setThreads(int threads) {
  this.threadsField.setValue(threads);
}
/**
 * Reads the thread count currently held by the form's threads input field.
 *
 * @return the current value of {@code threadsField}
 */
public int getThreads() {
  final int enteredThreads = threadsField.getValue();
  return enteredThreads;
}
private JButton getServerButton(boolean serverRunning, LlamaServerAgent llamaServerAgent,
ServerProgressPanel serverProgressPanel) {
var serverButton = new JButton();
@ -153,6 +174,7 @@ public class LlamaServiceSelectionForm extends JPanel {
llamaServerAgent.startAgent(
modelPath,
maxTokensField.getValue(),
threadsField.getValue(),
portField.getNumber(),
serverProgressPanel,
() -> {
@ -174,5 +196,6 @@ public class LlamaServiceSelectionForm extends JPanel {
llamaModelPreferencesForm.enableFields(enabled);
portField.setEnabled(enabled);
maxTokensField.setEnabled(enabled);
threadsField.setEnabled(enabled);
}
}

View file

@ -412,4 +412,12 @@ public class ServiceSelectionForm {
/**
 * Forwards the context size to the Llama service section panel.
 *
 * @param contextSize value handed to {@code llamaServiceSectionPanel}
 */
public void setContextSize(int contextSize) {
  this.llamaServiceSectionPanel.setContextSize(contextSize);
}
/**
 * Fetches the thread count from the Llama service section panel.
 *
 * @return the value reported by {@code llamaServiceSectionPanel.getThreads()}
 */
public int getThreads() {
  final int panelThreads = llamaServiceSectionPanel.getThreads();
  return panelThreads;
}
/**
 * Forwards the thread count to the Llama service section panel.
 *
 * @param threads value handed to {@code llamaServiceSectionPanel}
 */
public void setThreads(int threads) {
  this.llamaServiceSectionPanel.setThreads(threads);
}
}

View file

@ -21,6 +21,8 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
private Integer serverPort = getRandomAvailablePortOrDefault();
private int contextSize = 2048;
private int threads = 8;
public LlamaSettingsState() {
}
@ -86,6 +88,14 @@ public class LlamaSettingsState implements PersistentStateComponent<LlamaSetting
this.contextSize = contextSize;
}
/**
 * Returns the thread count held in this settings state.
 *
 * @return the current value of the {@code threads} field
 */
public int getThreads() {
  return this.threads;
}
/**
 * Stores the thread count in this settings state.
 *
 * <p>The value is assigned as-is; no range check is performed here.
 *
 * @param threads new value for the {@code threads} field
 */
public void setThreads(int threads) {
this.threads = threads;
}
private static Integer getRandomAvailablePortOrDefault() {
try (ServerSocket socket = new ServerSocket(0)) {
return socket.getLocalPort();