diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1ccefd10..90e3e521 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
+### Added
+
+- Support for custom OpenAI model configuration
+
## [2.3.0] - 2024-02-14
### Added
diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java
index 2070ec53..3644a49a 100644
--- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java
+++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java
@@ -40,9 +40,10 @@ public final class CompletionRequestService {
switch (GeneralSettings.getCurrentState().getSelectedService()) {
case OPENAI:
var openAISettings = OpenAISettings.getCurrentState();
+ var customModel = openAISettings.getCustomModel();
return CompletionClientProvider.getOpenAIClient().getChatCompletionAsync(
requestProvider.buildOpenAIChatCompletionRequest(
- openAISettings.getModel(),
+ customModel.trim().isEmpty() ? openAISettings.getModel() : customModel,
callParameters,
useContextualSearch,
openAISettings.isUsingCustomPath() ? openAISettings.getPath() : null),
diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/openai/OpenAISettingsForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/openai/OpenAISettingsForm.java
index c9860bb0..b42673e9 100644
--- a/src/main/java/ee/carlrobert/codegpt/settings/service/openai/OpenAISettingsForm.java
+++ b/src/main/java/ee/carlrobert/codegpt/settings/service/openai/OpenAISettingsForm.java
@@ -3,6 +3,7 @@ package ee.carlrobert.codegpt.settings.service.openai;
import static ee.carlrobert.codegpt.ui.UIUtil.withEmptyLeftBorder;
import com.intellij.openapi.ui.ComboBox;
+import com.intellij.ui.DocumentAdapter;
import com.intellij.ui.EnumComboBoxModel;
import com.intellij.ui.TitledSeparator;
import com.intellij.ui.components.JBPasswordField;
@@ -13,12 +14,17 @@ import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.credentials.OpenAICredentialManager;
import ee.carlrobert.codegpt.ui.UIUtil;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
+import java.awt.event.ActionEvent;
+import java.awt.event.ActionListener;
import javax.annotation.Nullable;
import javax.swing.JPanel;
+import javax.swing.event.DocumentEvent;
+import org.jetbrains.annotations.NotNull;
public class OpenAISettingsForm {
private final JBPasswordField openAIApiKeyField;
+ private final JBTextField openAICustomModelField;
private final JBTextField openAIBaseHostField;
private final JBTextField openAIPathField;
private final JBTextField openAIOrganizationField;
@@ -33,8 +39,16 @@ public class OpenAISettingsForm {
openAIOrganizationField = new JBTextField(settings.getOrganization(), 30);
openAICompletionModelComboBox = new ComboBox<>(
new EnumComboBoxModel<>(OpenAIChatCompletionModel.class));
+ openAICompletionModelComboBox.setEnabled(settings.getCustomModel().isEmpty());
openAICompletionModelComboBox.setSelectedItem(
OpenAIChatCompletionModel.findByCode(settings.getModel()));
+ openAICustomModelField = new JBTextField(settings.getCustomModel(), 20);
+ openAICustomModelField.getDocument().addDocumentListener(new DocumentAdapter() {
+ @Override
+ protected void textChanged(@NotNull DocumentEvent e) {
+ openAICompletionModelComboBox.setEnabled(openAICustomModelField.getText().isEmpty());
+ }
+ });
}
public JPanel getForm() {
@@ -43,6 +57,10 @@ public class OpenAISettingsForm {
.withLabel(CodeGPTBundle.get(
"settingsConfigurable.shared.model.label"))
.resizeX(false))
+ .add(UI.PanelFactory.panel(openAICustomModelField)
+ .withLabel(CodeGPTBundle.get(
+ "settingsConfigurable.service.openai.customModel.label"))
+ .resizeX(false))
.add(UI.PanelFactory.panel(openAIOrganizationField)
.withLabel(CodeGPTBundle.get(
"settingsConfigurable.service.openai.organization.label"))
@@ -94,6 +112,7 @@ public class OpenAISettingsForm {
state.setOrganization(openAIOrganizationField.getText());
state.setBaseHost(openAIBaseHostField.getText());
state.setPath(openAIPathField.getText());
+ state.setCustomModel(openAICustomModelField.getText());
state.setModel(getModel());
return state;
}
@@ -102,6 +121,7 @@ public class OpenAISettingsForm {
var state = OpenAISettings.getCurrentState();
openAIApiKeyField.setText(OpenAICredentialManager.getInstance().getCredential());
openAIOrganizationField.setText(state.getOrganization());
+ openAICustomModelField.setText(state.getCustomModel());
openAIBaseHostField.setText(state.getBaseHost());
openAIPathField.setText(state.getPath());
openAICompletionModelComboBox.setSelectedItem(
diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/openai/OpenAISettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/service/openai/OpenAISettingsState.java
index e321886f..975e2553 100644
--- a/src/main/java/ee/carlrobert/codegpt/settings/service/openai/OpenAISettingsState.java
+++ b/src/main/java/ee/carlrobert/codegpt/settings/service/openai/OpenAISettingsState.java
@@ -11,6 +11,7 @@ public class OpenAISettingsState {
private String baseHost = "https://api.openai.com";
private String path = BASE_PATH;
private String model = OpenAIChatCompletionModel.GPT_3_5.getCode();
+ private String customModel = "";
public boolean isUsingCustomPath() {
return !BASE_PATH.equals(path);
@@ -32,6 +33,14 @@ public class OpenAISettingsState {
this.baseHost = openAIBaseHost;
}
+ public String getPath() {
+ return path;
+ }
+
+ public void setPath(String path) {
+ this.path = path;
+ }
+
public String getModel() {
return model;
}
@@ -40,12 +49,12 @@ public class OpenAISettingsState {
this.model = model;
}
- public String getPath() {
- return path;
+ public String getCustomModel() {
+ return customModel;
}
- public void setPath(String path) {
- this.path = path;
+ public void setCustomModel(String customModel) {
+ this.customModel = customModel;
}
@Override
@@ -60,11 +69,12 @@ public class OpenAISettingsState {
return Objects.equals(organization, that.organization)
&& Objects.equals(baseHost, that.baseHost)
&& Objects.equals(path, that.path)
- && Objects.equals(model, that.model);
+ && Objects.equals(model, that.model)
+ && Objects.equals(customModel, that.customModel);
}
@Override
public int hashCode() {
- return Objects.hash(organization, baseHost, path, model);
+ return Objects.hash(organization, baseHost, path, model, customModel);
}
}
diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/ModelComboBoxAction.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/ModelComboBoxAction.java
index dbf6dbf6..f6df91cc 100644
--- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/ModelComboBoxAction.java
+++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/ModelComboBoxAction.java
@@ -56,13 +56,22 @@ public class ModelComboBoxAction extends ComboBoxAction {
var presentation = ((ComboBoxButton) button).getPresentation();
var actionGroup = new DefaultActionGroup();
actionGroup.addSeparator("OpenAI");
- List.of(
- OpenAIChatCompletionModel.GPT_4_0125_128k,
- OpenAIChatCompletionModel.GPT_3_5_0125_16k,
- OpenAIChatCompletionModel.GPT_4_32k,
- OpenAIChatCompletionModel.GPT_4,
- OpenAIChatCompletionModel.GPT_3_5)
- .forEach(model -> actionGroup.add(createOpenAIModelAction(model, presentation)));
+ var settings = OpenAISettings.getCurrentState();
+ if (settings.getCustomModel().isEmpty()) {
+ List.of(
+ OpenAIChatCompletionModel.GPT_4_0125_128k,
+ OpenAIChatCompletionModel.GPT_3_5_0125_16k,
+ OpenAIChatCompletionModel.GPT_4_32k,
+ OpenAIChatCompletionModel.GPT_4,
+ OpenAIChatCompletionModel.GPT_3_5)
+ .forEach(model -> actionGroup.add(createOpenAIModelAction(model, presentation)));
+ } else {
+ actionGroup.add(createModelAction(
+ ServiceType.OPENAI,
+ settings.getCustomModel(),
+ Icons.OpenAI,
+ presentation));
+ }
actionGroup.addSeparator();
actionGroup.add(
createModelAction(ServiceType.AZURE, "Azure OpenAI", Icons.Azure, presentation));
@@ -87,8 +96,7 @@ public class ModelComboBoxAction extends ComboBoxAction {
switch (selectedService) {
case OPENAI:
templatePresentation.setIcon(Icons.OpenAI);
- templatePresentation.setText(
- OpenAIChatCompletionModel.findByCode(openAISettings.getModel()).getDescription());
+ templatePresentation.setText(getOpenAiPresentationText());
break;
case AZURE:
templatePresentation.setIcon(Icons.Azure);
@@ -106,6 +114,14 @@ public class ModelComboBoxAction extends ComboBoxAction {
}
}
+ private String getOpenAiPresentationText() {
+ var settings = OpenAISettings.getCurrentState();
+ if (settings.getCustomModel().isEmpty()) {
+ return OpenAIChatCompletionModel.findByCode(openAISettings.getModel()).getDescription();
+ }
+ return settings.getCustomModel();
+ }
+
private String getLlamaCppPresentationText() {
var llamaSettingState = LlamaSettings.getCurrentState();
if (!llamaSettingState.isRunLocalServer()) {
diff --git a/src/main/resources/messages/codegpt.properties b/src/main/resources/messages/codegpt.properties
index 358a6c67..0cf3c608 100644
--- a/src/main/resources/messages/codegpt.properties
+++ b/src/main/resources/messages/codegpt.properties
@@ -18,6 +18,7 @@ settings.openaiQuotaExceeded=OpenAI quota exceeded.
settingsConfigurable.displayName.label=Display name:
settingsConfigurable.service.label=Service:
settingsConfigurable.service.openai.apiKey.comment=You can find your Secret API key in your User settings.
+settingsConfigurable.service.openai.customModel.label=Custom model:
settingsConfigurable.service.openai.organization.label=Organization:
settingsConfigurable.section.openai.organization.comment=Useful when you are part of multiple organizations optional
settingsConfigurable.service.azure.resourceName.label=Resource name: