feat: Implement Ollama as a high-level service (#510)

* Initial implementation of Ollama as a service

* Fix model selector in tool window

* Enable image attachment

* Rewrite OllamaSettingsForm in Kt

* Create OllamaInlineCompletionModel and use it for building completion template

* Add support for blocking code completion on models that we don't know support it

* Allow disabling code completion settings

* Disable code completion settings when an unsupported model is entered

* Track FIM template in settings as a derived state

* Update llm-client

* Initial implementation of model combo box

* Add Ollama icon and display models as list

* Make OllamaSettingsState immutable & convert OllamaSettings to Kotlin

* Add refresh models button

* Distinguish between empty/needs refresh/loading

* Avoid storing any model if the combo box is empty

* Fix icon size

* Back to mutable settings
There were some bugs with immutable settings

* Store available models in settings state

* Expose available models in model dropdown

* Add dark icon

* Cleanups for CompletionRequestProvider

* Fix checkstyle issues

* refactor: migrate to SimplePersistentStateComponent

* fix: add code completion stop tokens

* fix: display only one item in the model popup action group

* fix: add back multi model selection

---------

Co-authored-by: Carl-Robert Linnupuu <carlrobertoh@gmail.com>
This commit is contained in:
Jack Boswell 2024-05-08 10:11:13 +12:00 committed by GitHub
parent 7f7b35d3be
commit e40630d796
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 505 additions and 39 deletions

View file

@ -16,6 +16,7 @@ public final class Icons {
public static final Icon Sparkle = IconLoader.getIcon("/icons/sparkle.svg", Icons.class);
public static final Icon You = IconLoader.getIcon("/icons/you.svg", Icons.class);
public static final Icon YouSmall = IconLoader.getIcon("/icons/you_small.png", Icons.class);
public static final Icon Ollama = IconLoader.getIcon("/icons/ollama.svg", Icons.class);
public static final Icon User = IconLoader.getIcon("/icons/user.svg", Icons.class);
public static final Icon Upload = IconLoader.getIcon("/icons/upload.svg", Icons.class);
}

View file

@ -1,5 +1,6 @@
package ee.carlrobert.codegpt.completions;
import com.intellij.openapi.application.ApplicationManager;
import ee.carlrobert.codegpt.CodeGPTPlugin;
import ee.carlrobert.codegpt.completions.you.YouUserManager;
import ee.carlrobert.codegpt.credentials.CredentialsStore;
@ -8,11 +9,13 @@ import ee.carlrobert.codegpt.settings.advanced.AdvancedSettings;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.llm.client.anthropic.ClaudeClient;
import ee.carlrobert.llm.client.azure.AzureClient;
import ee.carlrobert.llm.client.azure.AzureCompletionRequestParams;
import ee.carlrobert.llm.client.llama.LlamaClient;
import ee.carlrobert.llm.client.ollama.OllamaClient;
import ee.carlrobert.llm.client.openai.OpenAIClient;
import ee.carlrobert.llm.client.you.UTMParameters;
import ee.carlrobert.llm.client.you.YouClient;
@ -92,6 +95,16 @@ public class CompletionClientProvider {
return builder.build(getDefaultClientBuilder());
}
public static OllamaClient getOllamaClient() {
var host = ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.getHost();
return new OllamaClient.Builder()
.setHost(host)
.build(getDefaultClientBuilder());
}
public static OkHttpClient.Builder getDefaultClientBuilder() {
OkHttpClient.Builder builder = new OkHttpClient.Builder();
var advancedSettings = AdvancedSettings.getCurrentState();

View file

@ -26,8 +26,8 @@ import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceState;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.codegpt.settings.service.you.YouSettings;
import ee.carlrobert.codegpt.telemetry.core.configuration.TelemetryConfiguration;
@ -41,6 +41,8 @@ import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionStandardMes
import ee.carlrobert.llm.client.anthropic.completion.ClaudeMessageImageContent;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeMessageTextContent;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionDetailedMessage;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage;
@ -56,6 +58,7 @@ import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
@ -140,7 +143,8 @@ public class CompletionRequestProvider {
public static Request buildCustomOpenAILookupCompletionRequest(String context) {
return buildCustomOpenAIChatCompletionRequest(
ApplicationManager.getApplication().getService(CustomServiceState.class)
ApplicationManager.getApplication().getService(CustomServiceSettings.class)
.getState()
.getChatCompletionSettings(),
List.of(
new OpenAIChatCompletionStandardMessage(
@ -210,7 +214,7 @@ public class CompletionRequestProvider {
@Nullable String model,
CallParameters callParameters) {
var configuration = ConfigurationSettings.getCurrentState();
return new OpenAIChatCompletionRequest.Builder(buildMessages(model, callParameters))
return new OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters))
.setModel(model)
.setMaxTokens(configuration.getMaxTokens())
.setStream(true)
@ -222,7 +226,7 @@ public class CompletionRequestProvider {
CallParameters callParameters) {
return buildCustomOpenAIChatCompletionRequest(
settings,
buildMessages(callParameters),
buildOpenAIMessages(callParameters),
true);
}
@ -307,7 +311,68 @@ public class CompletionRequestProvider {
return request;
}
private List<OpenAIChatCompletionMessage> buildMessages(CallParameters callParameters) {
public OllamaChatCompletionRequest buildOllamaChatCompletionRequest(
CallParameters callParameters
) {
var settings = ApplicationManager.getApplication().getService(OllamaSettings.class).getState();
return new OllamaChatCompletionRequest
.Builder(settings.getModel(), buildOllamaMessages(callParameters))
.build();
}
private List<OllamaChatCompletionMessage> buildOllamaMessages(CallParameters callParameters) {
var message = callParameters.getMessage();
var messages = new ArrayList<OllamaChatCompletionMessage>();
if (callParameters.getConversationType() == ConversationType.DEFAULT) {
String systemPrompt = ConfigurationSettings.getCurrentState().getSystemPrompt();
messages.add(new OllamaChatCompletionMessage("system", systemPrompt, null));
}
if (callParameters.getConversationType() == ConversationType.FIX_COMPILE_ERRORS) {
messages.add(
new OllamaChatCompletionMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT, null)
);
}
for (var prevMessage : conversation.getMessages()) {
if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) {
break;
}
var prevMessageImageFilePath = prevMessage.getImageFilePath();
if (prevMessageImageFilePath != null && !prevMessageImageFilePath.isEmpty()) {
try {
var imageFilePath = Path.of(prevMessageImageFilePath);
var imageBytes = Files.readAllBytes(imageFilePath);
var imageBase64 = Base64.getEncoder().encodeToString(imageBytes);
messages.add(
new OllamaChatCompletionMessage(
"user", prevMessage.getPrompt(), List.of(imageBase64)
)
);
} catch (IOException e) {
throw new RuntimeException(e);
}
} else {
messages.add(
new OllamaChatCompletionMessage("user", prevMessage.getPrompt(), null)
);
}
messages.add(
new OllamaChatCompletionMessage("assistant", prevMessage.getResponse(), null)
);
}
if (callParameters.getImageMediaType() != null && callParameters.getImageData().length > 0) {
var imageBase64 = Base64.getEncoder().encodeToString(callParameters.getImageData());
messages.add(
new OllamaChatCompletionMessage("user", message.getPrompt(), List.of(imageBase64))
);
} else {
messages.add(new OllamaChatCompletionMessage("user", message.getPrompt(), null));
}
return messages;
}
private List<OpenAIChatCompletionMessage> buildOpenAIMessages(CallParameters callParameters) {
var message = callParameters.getMessage();
var messages = new ArrayList<OpenAIChatCompletionMessage>();
if (callParameters.getConversationType() == ConversationType.DEFAULT) {
@ -339,7 +404,9 @@ public class CompletionRequestProvider {
} else {
messages.add(new OpenAIChatCompletionStandardMessage("user", prevMessage.getPrompt()));
}
messages.add(new OpenAIChatCompletionStandardMessage("assistant", prevMessage.getResponse()));
messages.add(
new OpenAIChatCompletionStandardMessage("assistant", prevMessage.getResponse())
);
}
if (callParameters.getImageMediaType() != null && callParameters.getImageData().length > 0) {
@ -355,10 +422,10 @@ public class CompletionRequestProvider {
return messages;
}
private List<OpenAIChatCompletionMessage> buildMessages(
private List<OpenAIChatCompletionMessage> buildOpenAIMessages(
@Nullable String model,
CallParameters callParameters) {
var messages = buildMessages(callParameters);
var messages = buildOpenAIMessages(callParameters);
if (model == null
|| GeneralSettings.getCurrentState().getSelectedService() == ServiceType.YOU) {

View file

@ -21,11 +21,14 @@ import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionStandardMessage;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
@ -104,6 +107,9 @@ public final class CompletionRequestService {
callParameters.getMessage(),
callParameters.getConversationType()),
eventListener);
case OLLAMA -> CompletionClientProvider.getOllamaClient().getChatCompletionAsync(
requestProvider.buildOllamaChatCompletionRequest(callParameters),
eventListener);
};
}
@ -123,6 +129,9 @@ public final class CompletionRequestService {
.getInfillAsync(
CodeCompletionRequestFactory.buildLlamaRequest(requestDetails),
eventListener);
case OLLAMA -> CompletionClientProvider.getOllamaClient().getCompletionAsync(
CodeCompletionRequestFactory.INSTANCE.buildOllamaRequest(requestDetails),
eventListener);
default ->
throw new IllegalArgumentException("Code completion not supported for selected service");
};
@ -189,6 +198,20 @@ public final class CompletionRequestService {
.setRepeat_penalty(settings.getRepeatPenalty())
.build(), eventListener);
break;
case OLLAMA:
var model = ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.getModel();
var request = new OllamaChatCompletionRequest.Builder(
model,
List.of(
new OllamaChatCompletionMessage("system", systemPrompt, null),
new OllamaChatCompletionMessage("user", gitDiff, null)
)
).build();
CompletionClientProvider.getOllamaClient().getChatCompletionAsync(request, eventListener);
break;
default:
LOG.debug("Unknown service: {}", selectedService);
break;
@ -228,9 +251,9 @@ public final class CompletionRequestService {
case OPENAI -> CredentialsStore.INSTANCE.isCredentialSet(CredentialKey.OPENAI_API_KEY);
case AZURE -> CredentialsStore.INSTANCE.isCredentialSet(
AzureSettings.getCurrentState().isUseAzureApiKeyAuthentication()
? CredentialKey.AZURE_OPENAI_API_KEY
: CredentialKey.AZURE_ACTIVE_DIRECTORY_TOKEN);
case CUSTOM_OPENAI, ANTHROPIC, LLAMA_CPP -> true;
? CredentialKey.AZURE_OPENAI_API_KEY
: CredentialKey.AZURE_ACTIVE_DIRECTORY_TOKEN);
case CUSTOM_OPENAI, ANTHROPIC, LLAMA_CPP, OLLAMA -> true;
case YOU -> false;
};
}

View file

@ -9,6 +9,7 @@ import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import java.time.LocalDateTime;
import java.util.ArrayList;
@ -195,9 +196,13 @@ public final class ConversationService {
case LLAMA_CPP -> {
var llamaSettings = LlamaSettings.getCurrentState();
yield llamaSettings.isUseCustomModel()
? llamaSettings.getCustomLlamaModelPath()
: llamaSettings.getHuggingFaceModel().getCode();
? llamaSettings.getCustomLlamaModelPath()
: llamaSettings.getHuggingFaceModel().getCode();
}
case OLLAMA -> ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.getModel();
};
}
}

View file

@ -11,6 +11,7 @@ import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import org.jetbrains.annotations.NotNull;
@ -69,6 +70,9 @@ public class GeneralSettings implements PersistentStateComponent<GeneralSettings
if ("you.chat.completion".equals(clientCode)) {
state.setSelectedService(ServiceType.YOU);
}
if ("ollama.chat.completion".equals(clientCode)) {
state.setSelectedService(ServiceType.OLLAMA);
}
}
public String getModel() {
@ -98,6 +102,11 @@ public class GeneralSettings implements PersistentStateComponent<GeneralSettings
llamaModel.getLabel(),
huggingFaceModel.getParameterSize(),
huggingFaceModel.getQuantization());
case OLLAMA:
return ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.getModel();
default:
return "Unknown";
}

View file

@ -4,6 +4,7 @@ import static ee.carlrobert.codegpt.settings.service.ServiceType.ANTHROPIC;
import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE;
import static ee.carlrobert.codegpt.settings.service.ServiceType.CUSTOM_OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OLLAMA;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU;
@ -20,6 +21,8 @@ import ee.carlrobert.codegpt.settings.service.azure.AzureSettingsForm;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceForm;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.llama.form.LlamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettingsForm;
import ee.carlrobert.codegpt.settings.service.you.YouSettings;
@ -45,6 +48,7 @@ public class GeneralSettingsComponent {
private final AzureSettingsForm azureSettingsForm;
private final YouSettingsForm youSettingsForm;
private final LlamaSettingsForm llamaSettingsForm;
private final OllamaSettingsForm ollamaSettingsForm;
public GeneralSettingsComponent(Disposable parentDisposable, GeneralSettings settings) {
displayNameField = new JBTextField(settings.getState().getDisplayName(), 20);
@ -54,6 +58,7 @@ public class GeneralSettingsComponent {
azureSettingsForm = new AzureSettingsForm(AzureSettings.getCurrentState());
youSettingsForm = new YouSettingsForm(YouSettings.getCurrentState(), parentDisposable);
llamaSettingsForm = new LlamaSettingsForm(LlamaSettings.getCurrentState());
ollamaSettingsForm = new OllamaSettingsForm();
var cardLayout = new DynamicCardLayout();
var cards = new JPanel(cardLayout);
@ -63,6 +68,7 @@ public class GeneralSettingsComponent {
cards.add(azureSettingsForm.getForm(), AZURE.getCode());
cards.add(youSettingsForm, YOU.getCode());
cards.add(llamaSettingsForm, LLAMA_CPP.getCode());
cards.add(ollamaSettingsForm.getForm(), OLLAMA.getCode());
var serviceComboBoxModel = new DefaultComboBoxModel<ServiceType>();
serviceComboBoxModel.addAll(Arrays.stream(ServiceType.values()).toList());
serviceComboBox = new ComboBox<>(serviceComboBoxModel);
@ -106,6 +112,10 @@ public class GeneralSettingsComponent {
return youSettingsForm;
}
public OllamaSettingsForm getOllamaSettingsForm() {
return ollamaSettingsForm;
}
public ServiceType getSelectedService() {
return serviceComboBox.getItem();
}
@ -137,6 +147,7 @@ public class GeneralSettingsComponent {
azureSettingsForm.resetForm();
youSettingsForm.resetForm();
llamaSettingsForm.resetForm();
ollamaSettingsForm.resetForm();
}
static class DynamicCardLayout extends CardLayout {

View file

@ -20,6 +20,8 @@ import ee.carlrobert.codegpt.settings.service.azure.AzureSettingsForm;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceForm;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.llama.form.LlamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettingsForm;
import ee.carlrobert.codegpt.settings.service.you.YouSettings;
@ -68,7 +70,8 @@ public class GeneralSettingsConfigurable implements Configurable {
|| AnthropicSettings.getInstance().isModified(component.getAnthropicSettingsForm())
|| AzureSettings.getInstance().isModified(component.getAzureSettingsForm())
|| YouSettings.getInstance().isModified(component.getYouSettingsForm())
|| LlamaSettings.getInstance().isModified(component.getLlamaSettingsForm());
|| LlamaSettings.getInstance().isModified(component.getLlamaSettingsForm())
|| component.getOllamaSettingsForm().isModified();
}
@Override
@ -84,6 +87,7 @@ public class GeneralSettingsConfigurable implements Configurable {
applyAzureSettings(component.getAzureSettingsForm());
applyYouSettings(component.getYouSettingsForm());
applyLlamaSettings(component.getLlamaSettingsForm());
component.getOllamaSettingsForm().applyChanges();
var serviceChanged = component.getSelectedService() != settings.getSelectedService();
var modelChanged = !OpenAISettings.getCurrentState().getModel()
@ -133,6 +137,10 @@ public class GeneralSettingsConfigurable implements Configurable {
form.getActiveDirectoryToken());
}
private void applyOllamaSettings(OllamaSettingsForm form) {
form.applyChanges();
}
@Override
public void reset() {
var settings = GeneralSettings.getCurrentState();

View file

@ -8,7 +8,8 @@ public enum ServiceType {
ANTHROPIC("ANTHROPIC", "service.anthropic.title", "anthropic.chat.completion"),
AZURE("AZURE", "service.azure.title", "azure.chat.completion"),
YOU("YOU", "service.you.title", "you.chat.completion"),
LLAMA_CPP("LLAMA_CPP", "service.llama.title", "llama.chat.completion");
LLAMA_CPP("LLAMA_CPP", "service.llama.title", "llama.chat.completion"),
OLLAMA("OLLAMA", "service.ollama.title", "ollama.chat.completion");
private final String code;
private final String label;

View file

@ -22,7 +22,8 @@ public class LlamaSettingsForm extends JPanel {
llamaRequestPreferencesForm = new LlamaRequestPreferencesForm(settings);
codeCompletionConfigurationForm = new CodeCompletionConfigurationForm(
settings.isCodeCompletionsEnabled(),
settings.getCodeCompletionMaxTokens());
settings.getCodeCompletionMaxTokens(),
null);
init();
}

View file

@ -36,7 +36,8 @@ public class OpenAISettingsForm {
OpenAIChatCompletionModel.findByCode(settings.getModel()));
codeCompletionConfigurationForm = new CodeCompletionConfigurationForm(
settings.isCodeCompletionsEnabled(),
settings.getCodeCompletionMaxTokens());
settings.getCodeCompletionMaxTokens(),
null);
}
public JPanel getForm() {

View file

@ -1,6 +1,7 @@
package ee.carlrobert.codegpt.toolwindow.chat.ui.textarea;
import static ee.carlrobert.codegpt.settings.service.ServiceType.CUSTOM_OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OLLAMA;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU;
import static java.lang.String.format;
@ -23,6 +24,8 @@ import ee.carlrobert.codegpt.settings.GeneralSettingsState;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettingsState;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettingsState;
import ee.carlrobert.codegpt.settings.service.you.YouSettings;
@ -41,12 +44,16 @@ public class ModelComboBoxAction extends ComboBoxAction {
private final GeneralSettingsState settings;
private final OpenAISettingsState openAISettings;
private final YouSettingsState youSettings;
private final OllamaSettingsState ollamaSettings;
public ModelComboBoxAction(Runnable onModelChange, ServiceType selectedService) {
this.onModelChange = onModelChange;
settings = GeneralSettings.getCurrentState();
openAISettings = OpenAISettings.getCurrentState();
youSettings = YouSettings.getCurrentState();
ollamaSettings = ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState();
updateTemplatePresentation(selectedService);
subscribeToYouSignedOutTopic(ApplicationManager.getApplication().getMessageBus().connect());
@ -103,6 +110,9 @@ public class ModelComboBoxAction extends ComboBoxAction {
getLlamaCppPresentationText(),
Icons.Llama,
presentation));
actionGroup.addSeparator("Ollama");
ollamaSettings.getAvailableModels().forEach(model ->
actionGroup.add(createOllamaModelAction(model, presentation)));
if (YouUserManager.getInstance().isSubscribed()) {
actionGroup.addSeparator("You.com");
@ -179,7 +189,12 @@ public class ModelComboBoxAction extends ComboBoxAction {
templatePresentation.setText(getLlamaCppPresentationText());
templatePresentation.setIcon(Icons.Llama);
break;
case OLLAMA:
templatePresentation.setIcon(Icons.Ollama);
templatePresentation.setText(ollamaSettings.getModel());
break;
default:
break;
}
}
@ -235,6 +250,34 @@ public class ModelComboBoxAction extends ComboBoxAction {
onModelChange.run();
}
private AnAction createOllamaModelAction(
String model,
Presentation comboBoxPresentation
) {
return new DumbAwareAction(model, "", Icons.Ollama) {
@Override
public void update(@NotNull AnActionEvent event) {
var presentation = event.getPresentation();
presentation.setEnabled(!presentation.getText().equals(comboBoxPresentation.getText()));
}
@Override
public void actionPerformed(@NotNull AnActionEvent e) {
ollamaSettings.setModel(model);
handleModelChange(
OLLAMA,
model,
Icons.Ollama,
comboBoxPresentation);
}
@Override
public @NotNull ActionUpdateThread getActionUpdateThread() {
return ActionUpdateThread.BGT;
}
};
}
private AnAction createOpenAIModelAction(
OpenAIChatCompletionModel model,
Presentation comboBoxPresentation) {

View file

@ -1,6 +1,7 @@
package ee.carlrobert.codegpt.toolwindow.chat.ui.textarea;
import static ee.carlrobert.codegpt.settings.service.ServiceType.ANTHROPIC;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OLLAMA;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI;
import static ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel.GPT_4_VISION_PREVIEW;
@ -192,6 +193,7 @@ public class UserPromptTextArea extends JPanel {
}));
var selectedService = GeneralSettings.getCurrentState().getSelectedService();
if (selectedService == ANTHROPIC
|| selectedService == OLLAMA
|| (selectedService == OPENAI
&& GPT_4_VISION_PREVIEW.getCode().equals(OpenAISettings.getCurrentState().getModel()))) {
iconsPanel.add(new IconActionButton(new AttachImageAction()));