Use enum value to store selected service (#265)

This commit is contained in:
Carl-Robert 2023-11-08 19:17:25 +02:00 committed by GitHub
parent ff60d1eab5
commit cfa5ff7776
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 166 additions and 233 deletions

View file

@ -4,6 +4,7 @@ import com.intellij.openapi.diagnostic.Logger;
import ee.carlrobert.codegpt.completions.you.YouUserManager;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.state.AzureSettingsState;
import ee.carlrobert.codegpt.settings.state.OpenAISettingsState;
import ee.carlrobert.codegpt.settings.state.SettingsState;
@ -80,12 +81,12 @@ public class CompletionRequestHandler {
var requestProvider = new CompletionRequestProvider(conversation);
try {
if (settings.isUseLlamaService()) {
if (settings.getSelectedService() == ServiceType.LLAMA_CPP) {
return CompletionClientProvider.getLlamaClient()
.getChatCompletion(requestProvider.buildLlamaCompletionRequest(message), eventListener);
}
if (settings.isUseYouService()) {
if (settings.getSelectedService() == ServiceType.YOU) {
var sessionId = "";
var accessToken = "";
var youUserManager = YouUserManager.getInstance();
@ -103,7 +104,7 @@ public class CompletionRequestHandler {
.getChatCompletion(request, eventListener);
}
if (settings.isUseAzureService()) {
if (settings.getSelectedService() == ServiceType.AZURE) {
var azureSettings = AzureSettingsState.getInstance();
return CompletionClientProvider.getAzureClient().getChatCompletion(
requestProvider.buildOpenAIChatCompletionRequest(
@ -151,7 +152,7 @@ public class CompletionRequestHandler {
conversation,
message,
isRetry,
settings.isUseYouService() ?
settings.getSelectedService() == ServiceType.YOU ?
new YouRequestCompletionEventListener() :
new BaseCompletionEventListener());
} catch (TotalUsageExceededException e) {
@ -212,20 +213,10 @@ public class CompletionRequestHandler {
}
private void sendInfo(SettingsState settings) {
var service = "openai";
if (settings.isUseAzureService()) {
service = "azure";
}
if (settings.isUseYouService()) {
service = "you";
}
if (settings.isUseLlamaService()) {
service = "llama";
}
TelemetryAction.COMPLETION.createActionMessage()
.property("conversationId", conversation.getId().toString())
.property("model", conversation.getModel())
.property("service", service)
.property("service", settings.getSelectedService().getCode().toLowerCase())
.send();
}

View file

@ -11,6 +11,7 @@ import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.ConversationsState;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationState;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.settings.state.SettingsState;
import ee.carlrobert.codegpt.settings.state.YouSettingsState;
@ -149,7 +150,7 @@ public class CompletionRequestProvider {
messages.add(new OpenAIChatCompletionMessage("user", message.getPrompt()));
}
if (SettingsState.getInstance().isUseYouService()) {
if (SettingsState.getInstance().getSelectedService() == ServiceType.YOU) {
return messages;
}

View file

@ -5,6 +5,7 @@ import static java.util.stream.Collectors.toList;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.state.AzureSettingsState;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.settings.state.OpenAISettingsState;
@ -43,17 +44,9 @@ public final class ConversationService {
var conversation = new Conversation();
conversation.setId(UUID.randomUUID());
conversation.setClientCode(clientCode);
if (settings.isUseYouService()) {
conversation.setModel("YouCode");
} else if (settings.isUseAzureService()) {
conversation.setModel(AzureSettingsState.getInstance().getModel());
} else if (settings.isUseOpenAIService()) {
conversation.setModel(OpenAISettingsState.getInstance().getModel());
} else {
conversation.setModel(LlamaSettingsState.getInstance().getHuggingFaceModel().getCode());
}
conversation.setCreatedOn(LocalDateTime.now());
conversation.setUpdatedOn(LocalDateTime.now());
conversation.setModel(getModelForSelectedService(settings.getSelectedService()));
return conversation;
}
@ -121,22 +114,9 @@ public final class ConversationService {
conversationState.setCurrentConversation(conversation);
}
private String getClientCode() {
var settings = SettingsState.getInstance();
if (settings.isUseOpenAIService()) {
return "chat.completion";
}
if (settings.isUseAzureService()) {
return "azure.chat.completion";
}
if (settings.isUseLlamaService()) {
return "llama.chat.completion";
}
return "you.chat.completion";
}
public Conversation startConversation() {
var conversation = createConversation(getClientCode());
var completionCode = SettingsState.getInstance().getSelectedService().getCompletionCode();
var conversation = createConversation(completionCode);
conversationState.setCurrentConversation(conversation);
addConversation(conversation);
return conversation;
@ -147,32 +127,6 @@ public final class ConversationService {
conversationState.setCurrentConversation(null);
}
public Optional<Conversation> getPreviousConversation() {
return tryGetNextOrPreviousConversation(true);
}
public Optional<Conversation> getNextConversation() {
return tryGetNextOrPreviousConversation(false);
}
private Optional<Conversation> tryGetNextOrPreviousConversation(boolean isPrevious) {
var currentConversation = ConversationsState.getCurrentConversation();
if (currentConversation != null) {
var sortedConversations = getSortedConversations();
for (int i = 0; i < sortedConversations.size(); i++) {
var conversation = sortedConversations.get(i);
if (conversation != null && conversation.getId().equals(currentConversation.getId())) {
// higher index indicates older conversation
var previousIndex = isPrevious ? i + 1 : i - 1;
if (isPrevious ? previousIndex < sortedConversations.size() : previousIndex != -1) {
return Optional.of(sortedConversations.get(previousIndex));
}
}
}
}
return Optional.empty();
}
public void deleteConversation(Conversation conversation) {
var iterator = conversationState.getConversationsMapping()
.get(conversation.getClientCode())
@ -205,4 +159,48 @@ public final class ConversationService {
conversation.discardTokenLimits();
saveConversation(conversation);
}
public Optional<Conversation> getPreviousConversation() {
return tryGetNextOrPreviousConversation(true);
}
public Optional<Conversation> getNextConversation() {
return tryGetNextOrPreviousConversation(false);
}
private Optional<Conversation> tryGetNextOrPreviousConversation(boolean isPrevious) {
var currentConversation = ConversationsState.getCurrentConversation();
if (currentConversation != null) {
var sortedConversations = getSortedConversations();
for (int i = 0; i < sortedConversations.size(); i++) {
var conversation = sortedConversations.get(i);
if (conversation != null && conversation.getId().equals(currentConversation.getId())) {
// higher index indicates older conversation
var previousIndex = isPrevious ? i + 1 : i - 1;
if (isPrevious ? previousIndex < sortedConversations.size() : previousIndex != -1) {
return Optional.of(sortedConversations.get(previousIndex));
}
}
}
}
return Optional.empty();
}
private static String getModelForSelectedService(ServiceType serviceType) {
switch (serviceType) {
case OPENAI:
return OpenAISettingsState.getInstance().getModel();
case AZURE:
return AzureSettingsState.getInstance().getModel();
case YOU:
return "YouCode";
case LLAMA_CPP:
var llamaSettings = LlamaSettingsState.getInstance();
return llamaSettings.isUseCustomModel() ?
llamaSettings.getCustomLlamaModelPath() :
llamaSettings.getHuggingFaceModel().getCode();
default:
throw new RuntimeException("Could not find corresponding service mapping");
}
}
}

View file

@ -39,7 +39,7 @@ public class SettingsComponent {
cards.add(serviceSelectionForm.getLlamaServiceSectionPanel(), ServiceType.LLAMA_CPP.getCode());
var serviceComboBoxModel = new DefaultComboBoxModel<ServiceType>();
serviceComboBoxModel.addAll(Arrays.stream(ServiceType.values())
.filter(it -> !"LLAMA_CPP".equals(it.getCode()) || SystemInfoRt.isUnix)
.filter(it -> ServiceType.LLAMA_CPP != it || SystemInfoRt.isUnix)
.collect(toList()));
serviceComboBox = new ComboBox<>(serviceComboBoxModel);
serviceComboBox.setSelectedItem(ServiceType.OPENAI);

View file

@ -1,10 +1,5 @@
package ee.carlrobert.codegpt.settings;
import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE;
import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU;
import com.intellij.openapi.Disposable;
import com.intellij.openapi.options.Configurable;
import com.intellij.openapi.util.Disposer;
@ -12,7 +7,6 @@ import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.conversations.ConversationsState;
import ee.carlrobert.codegpt.credentials.AzureCredentialsManager;
import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.state.AzureSettingsState;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.settings.state.OpenAISettingsState;
@ -98,13 +92,7 @@ public class SettingsConfigurable implements Configurable {
.setAzureActiveDirectoryToken(serviceSelectionForm.getAzureActiveDirectoryToken());
settings.setDisplayName(settingsComponent.getDisplayName());
// TODO: Store as single enum value
settings.setUseOpenAIService(settingsComponent.getSelectedService() == OPENAI);
settings.setUseAzureService(settingsComponent.getSelectedService() == ServiceType.AZURE);
settings.setUseYouService(settingsComponent.getSelectedService() == ServiceType.YOU);
YouSettingsState.getInstance()
.setDisplayWebSearchResults(serviceSelectionForm.isDisplayWebSearchResults());
settings.setUseLlamaService(settingsComponent.getSelectedService() == ServiceType.LLAMA_CPP);
settings.setSelectedService(settingsComponent.getSelectedService());
var llamaModelPreferencesForm = serviceSelectionForm.getLlamaModelPreferencesForm();
llamaSettings.setCustomLlamaModelPath(llamaModelPreferencesForm.getCustomLlamaModelPath());
@ -116,12 +104,14 @@ public class SettingsConfigurable implements Configurable {
openAISettings.apply(serviceSelectionForm);
azureSettings.apply(serviceSelectionForm);
YouSettingsState.getInstance()
.setDisplayWebSearchResults(serviceSelectionForm.isDisplayWebSearchResults());
if (serviceChanged || modelChanged) {
resetActiveTab();
if (serviceChanged) {
TelemetryAction.SETTINGS_CHANGED.createActionMessage()
.property("service", getServiceCode())
.property("service", settingsComponent.getSelectedService().getCode().toLowerCase())
.send();
}
}
@ -137,20 +127,8 @@ public class SettingsConfigurable implements Configurable {
// settingsComponent.setEmail(settings.getEmail());
settingsComponent.setDisplayName(settings.getDisplayName());
settingsComponent.setSelectedService(settings.getSelectedService());
// TODO
if (settings.isUseOpenAIService()) {
settingsComponent.setSelectedService(OPENAI);
}
if (settings.isUseAzureService()) {
settingsComponent.setSelectedService(ServiceType.AZURE);
}
if (settings.isUseYouService()) {
settingsComponent.setSelectedService(ServiceType.YOU);
}
if (settings.isUseLlamaService()) {
settingsComponent.setSelectedService(ServiceType.LLAMA_CPP);
}
var llamaModelPreferencesForm = serviceSelectionForm.getLlamaModelPreferencesForm();
llamaModelPreferencesForm.setSelectedModel(llamaSettings.getHuggingFaceModel());
llamaModelPreferencesForm.setCustomLlamaModelPath(llamaSettings.getCustomLlamaModelPath());
@ -174,10 +152,7 @@ public class SettingsConfigurable implements Configurable {
}
private boolean isServiceChanged(SettingsState settings) {
return (settingsComponent.getSelectedService() == OPENAI) != settings.isUseOpenAIService() ||
(settingsComponent.getSelectedService() == AZURE) != settings.isUseAzureService() ||
(settingsComponent.getSelectedService() == YOU) != settings.isUseYouService() ||
(settingsComponent.getSelectedService() == LLAMA_CPP) != settings.isUseLlamaService();
return settingsComponent.getSelectedService() != settings.getSelectedService();
}
private void resetActiveTab() {
@ -189,20 +164,4 @@ public class SettingsConfigurable implements Configurable {
project.getService(StandardChatToolWindowContentManager.class).resetActiveTab();
}
private String getServiceCode() {
if (settingsComponent.getSelectedService() == OPENAI) {
return "openai";
}
if (settingsComponent.getSelectedService() == AZURE) {
return "azure";
}
if (settingsComponent.getSelectedService() == YOU) {
return "you";
}
if (settingsComponent.getSelectedService() == LLAMA_CPP) {
return "llama.cpp";
}
return null;
}
}

View file

@ -3,17 +3,19 @@ package ee.carlrobert.codegpt.settings.service;
import ee.carlrobert.codegpt.CodeGPTBundle;
public enum ServiceType {
OPENAI("OPENAI", CodeGPTBundle.get("service.openai.title")),
AZURE("AZURE", CodeGPTBundle.get("service.azure.title")),
YOU("YOU", CodeGPTBundle.get("service.you.title")),
LLAMA_CPP("LLAMA_CPP", CodeGPTBundle.get("service.llama.title"));
OPENAI("OPENAI", CodeGPTBundle.get("service.openai.title"), "chat.completion"),
AZURE("AZURE", CodeGPTBundle.get("service.azure.title"), "azure.chat.completion"),
YOU("YOU", CodeGPTBundle.get("service.you.title"), "you.chat.completion"),
LLAMA_CPP("LLAMA_CPP", CodeGPTBundle.get("service.llama.title"), "llama.chat.completion");
private final String code;
private final String label;
private final String completionCode;
ServiceType(String code, String label) {
ServiceType(String code, String label, String completionCode) {
this.code = code;
this.label = label;
this.completionCode = completionCode;
}
public String getCode() {
@ -24,6 +26,10 @@ public enum ServiceType {
return label;
}
public String getCompletionCode() {
return completionCode;
}
@Override
public String toString() {
return label;

View file

@ -7,6 +7,7 @@ import com.intellij.openapi.components.Storage;
import com.intellij.util.xmlb.XmlSerializerUtil;
import ee.carlrobert.codegpt.completions.HuggingFaceModel;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import org.jetbrains.annotations.NotNull;
@State(name = "CodeGPT_GeneralSettings_210", storages = @Storage("CodeGPT_GeneralSettings_210.xml"))
@ -15,10 +16,7 @@ public class SettingsState implements PersistentStateComponent<SettingsState> {
private String email = "";
private String displayName = "";
private boolean previouslySignedIn;
private boolean useOpenAIService = true;
private boolean useAzureService;
private boolean useYouService;
private boolean useLlamaService;
private ServiceType selectedService = ServiceType.OPENAI;
public SettingsState() {
}
@ -40,20 +38,28 @@ public class SettingsState implements PersistentStateComponent<SettingsState> {
public void sync(Conversation conversation) {
var clientCode = conversation.getClientCode();
if ("chat.completion".equals(clientCode)) {
setSelectedService(ServiceType.OPENAI);
OpenAISettingsState.getInstance().setModel(conversation.getModel());
}
if ("azure.chat.completion".equals(clientCode)) {
setSelectedService(ServiceType.AZURE);
AzureSettingsState.getInstance().setModel(conversation.getModel());
}
if ("llama.chat.completion".equals(clientCode)) {
LlamaSettingsState.getInstance().setHuggingFaceModel(
HuggingFaceModel.valueOf(conversation.getModel()));
setSelectedService(ServiceType.LLAMA_CPP);
var llamaSettings = LlamaSettingsState.getInstance();
try {
var huggingFaceModel = HuggingFaceModel.valueOf(conversation.getModel());
llamaSettings.setHuggingFaceModel(huggingFaceModel);
llamaSettings.setUseCustomModel(false);
} catch (IllegalArgumentException ignore) {
llamaSettings.setCustomLlamaModelPath(conversation.getModel());
llamaSettings.setUseCustomModel(true);
}
}
if ("you.chat.completion".equals(clientCode)) {
setSelectedService(ServiceType.YOU);
}
setUseOpenAIService("chat.completion".equals(clientCode));
setUseAzureService("azure.chat.completion".equals(clientCode));
setUseYouService("you.chat.completion".equals(clientCode));
setUseLlamaService("llama.chat.completion".equals(clientCode));
}
public String getEmail() {
@ -87,35 +93,11 @@ public class SettingsState implements PersistentStateComponent<SettingsState> {
this.previouslySignedIn = previouslySignedIn;
}
public boolean isUseOpenAIService() {
return useOpenAIService;
public ServiceType getSelectedService() {
return selectedService;
}
public void setUseOpenAIService(boolean useOpenAIService) {
this.useOpenAIService = useOpenAIService;
}
public boolean isUseAzureService() {
return useAzureService;
}
public void setUseAzureService(boolean useAzureService) {
this.useAzureService = useAzureService;
}
public boolean isUseYouService() {
return useYouService;
}
public void setUseYouService(boolean useYouService) {
this.useYouService = useYouService;
}
public boolean isUseLlamaService() {
return useLlamaService;
}
public void setUseLlamaService(boolean useLlamaService) {
this.useLlamaService = useLlamaService;
public void setSelectedService(ServiceType selectedService) {
this.selectedService = selectedService;
}
}

View file

@ -29,6 +29,7 @@ import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.credentials.AzureCredentialsManager;
import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.state.AzureSettingsState;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.settings.state.OpenAISettingsState;
@ -115,7 +116,7 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
scrollablePanel.removeAll();
scrollablePanel.add(getLandingView());
var youUserManager = YouUserManager.getInstance();
if (SettingsState.getInstance().isUseYouService() &&
if (SettingsState.getInstance().getSelectedService() == ServiceType.YOU &&
(!youUserManager.isAuthenticated() || !youUserManager.isSubscribed())) {
scrollablePanel.add(new ResponsePanel().addContent(createTextPane()));
}
@ -171,10 +172,10 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
private boolean isRequestAllowed() {
var settings = SettingsState.getInstance();
if (settings.isUseAzureService()) {
if (settings.getSelectedService() == ServiceType.AZURE) {
return AzureCredentialsManager.getInstance().isCredentialSet();
}
if (settings.isUseOpenAIService()) {
if (settings.getSelectedService() == ServiceType.OPENAI) {
return OpenAICredentialsManager.getInstance().isApiKeySet();
}
return true;
@ -244,7 +245,7 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
requestHandler.addErrorListener((error, ex) -> {
try {
if ("insufficient_quota".equals(error.getCode())) {
if (SettingsState.getInstance().isUseOpenAIService()) {
if (SettingsState.getInstance().getSelectedService() == ServiceType.OPENAI) {
OpenAISettingsState.getInstance().setOpenAIQuotaExceeded(true);
}
responseContainer.displayQuotaExceeded();
@ -377,7 +378,11 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
var model = getModel();
var modelIconWrapper = JBUI.Panels
.simplePanel(new ModelIconLabel(getClientCode(), model))
.simplePanel(new ModelIconLabel(
SettingsState.getInstance()
.getSelectedService()
.getCompletionCode(),
model))
.withBorder(Borders.emptyRight(4))
.withBackground(getPanelBackgroundColor());
@ -388,20 +393,18 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
wrapper.setBackground(getPanelBackgroundColor());
wrapper.add(userPromptTextArea, BorderLayout.SOUTH);
if (model != null) {
var header = new JPanel(new BorderLayout());
header.setBackground(getPanelBackgroundColor());
header.setBorder(JBUI.Borders.emptyBottom(8));
if ("YouCode".equals(model)) {
var messageBusConnection = ApplicationManager.getApplication().getMessageBus().connect();
subscribeToYouModelChangeTopic();
subscribeToYouSubscriptionTopic(messageBusConnection);
subscribeToSignedOutTopic(messageBusConnection);
header.add(gpt4CheckBox, BorderLayout.LINE_START);
}
header.add(modelIconWrapper, BorderLayout.LINE_END);
wrapper.add(header);
var header = new JPanel(new BorderLayout());
header.setBackground(getPanelBackgroundColor());
header.setBorder(JBUI.Borders.emptyBottom(8));
if ("YouCode".equals(model)) {
var messageBusConnection = ApplicationManager.getApplication().getMessageBus().connect();
subscribeToYouModelChangeTopic();
subscribeToYouSubscriptionTopic(messageBusConnection);
subscribeToSignedOutTopic(messageBusConnection);
header.add(gpt4CheckBox, BorderLayout.LINE_START);
}
header.add(modelIconWrapper, BorderLayout.LINE_END);
wrapper.add(header);
rootPanel.add(wrapper, gbc);
userPromptTextArea.requestFocusInWindow();
@ -459,35 +462,18 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
return CodeGPTBundle.get("toolwindow.chat.youProCheckBox.notAllowed");
}
private String getClientCode() {
private String getModel() {
var settings = SettingsState.getInstance();
if (settings.isUseOpenAIService()) {
return "chat.completion";
}
if (settings.isUseAzureService()) {
return "azure.chat.completion";
}
if (settings.isUseYouService()) {
return "you.chat.completion";
}
if (settings.isUseLlamaService()) {
return "llama.chat.completion";
}
return null;
}
private @Nullable String getModel() {
var settings = SettingsState.getInstance();
if (settings.isUseOpenAIService()) {
if (settings.getSelectedService() == ServiceType.OPENAI) {
return OpenAISettingsState.getInstance().getModel();
}
if (settings.isUseAzureService()) {
if (settings.getSelectedService() == ServiceType.AZURE) {
return AzureSettingsState.getInstance().getModel();
}
if (settings.isUseYouService()) {
if (settings.getSelectedService() == ServiceType.YOU) {
return "YouCode";
}
if (settings.isUseLlamaService()) {
if (settings.getSelectedService() == ServiceType.LLAMA_CPP) {
var llamaSettings = LlamaSettingsState.getInstance();
if (llamaSettings.isUseCustomModel()) {
var filePath = llamaSettings.getCustomLlamaModelPath();