diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java index abdf8bba..316b3f8b 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java @@ -4,6 +4,7 @@ import com.intellij.openapi.diagnostic.Logger; import ee.carlrobert.codegpt.completions.you.YouUserManager; import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.message.Message; +import ee.carlrobert.codegpt.settings.service.ServiceType; import ee.carlrobert.codegpt.settings.state.AzureSettingsState; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; import ee.carlrobert.codegpt.settings.state.SettingsState; @@ -80,12 +81,12 @@ public class CompletionRequestHandler { var requestProvider = new CompletionRequestProvider(conversation); try { - if (settings.isUseLlamaService()) { + if (settings.getSelectedService() == ServiceType.LLAMA_CPP) { return CompletionClientProvider.getLlamaClient() .getChatCompletion(requestProvider.buildLlamaCompletionRequest(message), eventListener); } - if (settings.isUseYouService()) { + if (settings.getSelectedService() == ServiceType.YOU) { var sessionId = ""; var accessToken = ""; var youUserManager = YouUserManager.getInstance(); @@ -103,7 +104,7 @@ public class CompletionRequestHandler { .getChatCompletion(request, eventListener); } - if (settings.isUseAzureService()) { + if (settings.getSelectedService() == ServiceType.AZURE) { var azureSettings = AzureSettingsState.getInstance(); return CompletionClientProvider.getAzureClient().getChatCompletion( requestProvider.buildOpenAIChatCompletionRequest( @@ -151,7 +152,7 @@ public class CompletionRequestHandler { conversation, message, isRetry, - settings.isUseYouService() ? + settings.getSelectedService() == ServiceType.YOU ? new YouRequestCompletionEventListener() : new BaseCompletionEventListener()); } catch (TotalUsageExceededException e) { @@ -212,20 +213,10 @@ public class CompletionRequestHandler { } private void sendInfo(SettingsState settings) { - var service = "openai"; - if (settings.isUseAzureService()) { - service = "azure"; - } - if (settings.isUseYouService()) { - service = "you"; - } - if (settings.isUseLlamaService()) { - service = "llama"; - } TelemetryAction.COMPLETION.createActionMessage() .property("conversationId", conversation.getId().toString()) .property("model", conversation.getModel()) - .property("service", service) + .property("service", settings.getSelectedService().getCode().toLowerCase()) .send(); } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java index 9fc0adb0..1df8281f 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java @@ -11,6 +11,7 @@ import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.ConversationsState; import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.settings.configuration.ConfigurationState; +import ee.carlrobert.codegpt.settings.service.ServiceType; import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; import ee.carlrobert.codegpt.settings.state.SettingsState; import ee.carlrobert.codegpt.settings.state.YouSettingsState; @@ -149,7 +150,7 @@ public class CompletionRequestProvider { messages.add(new OpenAIChatCompletionMessage("user", message.getPrompt())); } - if (SettingsState.getInstance().isUseYouService()) { + if (SettingsState.getInstance().getSelectedService() == ServiceType.YOU) { return messages; } diff --git a/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java b/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java index bac5ce76..2ec2dc89 100644 --- a/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java +++ b/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java @@ -5,6 +5,7 @@ import static java.util.stream.Collectors.toList; import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.Service; import ee.carlrobert.codegpt.conversations.message.Message; +import ee.carlrobert.codegpt.settings.service.ServiceType; import ee.carlrobert.codegpt.settings.state.AzureSettingsState; import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; @@ -43,17 +44,9 @@ public final class ConversationService { var conversation = new Conversation(); conversation.setId(UUID.randomUUID()); conversation.setClientCode(clientCode); - if (settings.isUseYouService()) { - conversation.setModel("YouCode"); - } else if (settings.isUseAzureService()) { - conversation.setModel(AzureSettingsState.getInstance().getModel()); - } else if (settings.isUseOpenAIService()) { - conversation.setModel(OpenAISettingsState.getInstance().getModel()); - } else { - conversation.setModel(LlamaSettingsState.getInstance().getHuggingFaceModel().getCode()); - } conversation.setCreatedOn(LocalDateTime.now()); conversation.setUpdatedOn(LocalDateTime.now()); + conversation.setModel(getModelForSelectedService(settings.getSelectedService())); return conversation; } @@ -121,22 +114,9 @@ public final class ConversationService { conversationState.setCurrentConversation(conversation); } - private String getClientCode() { - var settings = SettingsState.getInstance(); - if (settings.isUseOpenAIService()) { - return "chat.completion"; - } - if (settings.isUseAzureService()) { - return "azure.chat.completion"; - } - if (settings.isUseLlamaService()) { - return "llama.chat.completion"; - } - return "you.chat.completion"; - } - public Conversation startConversation() { - var conversation = createConversation(getClientCode()); + var completionCode = SettingsState.getInstance().getSelectedService().getCompletionCode(); + var conversation = createConversation(completionCode); conversationState.setCurrentConversation(conversation); addConversation(conversation); return conversation; @@ -147,32 +127,6 @@ public final class ConversationService { conversationState.setCurrentConversation(null); } - public Optional getPreviousConversation() { - return tryGetNextOrPreviousConversation(true); - } - - public Optional getNextConversation() { - return tryGetNextOrPreviousConversation(false); - } - - private Optional tryGetNextOrPreviousConversation(boolean isPrevious) { - var currentConversation = ConversationsState.getCurrentConversation(); - if (currentConversation != null) { - var sortedConversations = getSortedConversations(); - for (int i = 0; i < sortedConversations.size(); i++) { - var conversation = sortedConversations.get(i); - if (conversation != null && conversation.getId().equals(currentConversation.getId())) { - // higher index indicates older conversation - var previousIndex = isPrevious ? i + 1 : i - 1; - if (isPrevious ? previousIndex < sortedConversations.size() : previousIndex != -1) { - return Optional.of(sortedConversations.get(previousIndex)); - } - } - } - } - return Optional.empty(); - } - public void deleteConversation(Conversation conversation) { var iterator = conversationState.getConversationsMapping() .get(conversation.getClientCode()) @@ -205,4 +159,48 @@ public final class ConversationService { conversation.discardTokenLimits(); saveConversation(conversation); } + + public Optional getPreviousConversation() { + return tryGetNextOrPreviousConversation(true); + } + + public Optional getNextConversation() { + return tryGetNextOrPreviousConversation(false); + } + + private Optional tryGetNextOrPreviousConversation(boolean isPrevious) { + var currentConversation = ConversationsState.getCurrentConversation(); + if (currentConversation != null) { + var sortedConversations = getSortedConversations(); + for (int i = 0; i < sortedConversations.size(); i++) { + var conversation = sortedConversations.get(i); + if (conversation != null && conversation.getId().equals(currentConversation.getId())) { + // higher index indicates older conversation + var previousIndex = isPrevious ? i + 1 : i - 1; + if (isPrevious ? previousIndex < sortedConversations.size() : previousIndex != -1) { + return Optional.of(sortedConversations.get(previousIndex)); + } + } + } + } + return Optional.empty(); + } + + private static String getModelForSelectedService(ServiceType serviceType) { + switch (serviceType) { + case OPENAI: + return OpenAISettingsState.getInstance().getModel(); + case AZURE: + return AzureSettingsState.getInstance().getModel(); + case YOU: + return "YouCode"; + case LLAMA_CPP: + var llamaSettings = LlamaSettingsState.getInstance(); + return llamaSettings.isUseCustomModel() ? + llamaSettings.getCustomLlamaModelPath() : + llamaSettings.getHuggingFaceModel().getCode(); + default: + throw new RuntimeException("Could not find corresponding service mapping"); + } + } } \ No newline at end of file diff --git a/src/main/java/ee/carlrobert/codegpt/settings/SettingsComponent.java b/src/main/java/ee/carlrobert/codegpt/settings/SettingsComponent.java index b016af88..3fa06753 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/SettingsComponent.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/SettingsComponent.java @@ -39,7 +39,7 @@ public class SettingsComponent { cards.add(serviceSelectionForm.getLlamaServiceSectionPanel(), ServiceType.LLAMA_CPP.getCode()); var serviceComboBoxModel = new DefaultComboBoxModel(); serviceComboBoxModel.addAll(Arrays.stream(ServiceType.values()) - .filter(it -> !"LLAMA_CPP".equals(it.getCode()) || SystemInfoRt.isUnix) + .filter(it -> ServiceType.LLAMA_CPP != it || SystemInfoRt.isUnix) .collect(toList())); serviceComboBox = new ComboBox<>(serviceComboBoxModel); serviceComboBox.setSelectedItem(ServiceType.OPENAI); diff --git a/src/main/java/ee/carlrobert/codegpt/settings/SettingsConfigurable.java b/src/main/java/ee/carlrobert/codegpt/settings/SettingsConfigurable.java index 641c1951..ee9cc41f 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/SettingsConfigurable.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/SettingsConfigurable.java @@ -1,10 +1,5 @@ package ee.carlrobert.codegpt.settings; -import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE; -import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP; -import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI; -import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU; - import com.intellij.openapi.Disposable; import com.intellij.openapi.options.Configurable; import com.intellij.openapi.util.Disposer; @@ -12,7 +7,6 @@ import ee.carlrobert.codegpt.CodeGPTBundle; import ee.carlrobert.codegpt.conversations.ConversationsState; import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; -import ee.carlrobert.codegpt.settings.service.ServiceType; import ee.carlrobert.codegpt.settings.state.AzureSettingsState; import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; @@ -98,13 +92,7 @@ public class SettingsConfigurable implements Configurable { .setAzureActiveDirectoryToken(serviceSelectionForm.getAzureActiveDirectoryToken()); settings.setDisplayName(settingsComponent.getDisplayName()); - // TODO: Store as single enum value - settings.setUseOpenAIService(settingsComponent.getSelectedService() == OPENAI); - settings.setUseAzureService(settingsComponent.getSelectedService() == ServiceType.AZURE); - settings.setUseYouService(settingsComponent.getSelectedService() == ServiceType.YOU); - YouSettingsState.getInstance() - .setDisplayWebSearchResults(serviceSelectionForm.isDisplayWebSearchResults()); - settings.setUseLlamaService(settingsComponent.getSelectedService() == ServiceType.LLAMA_CPP); + settings.setSelectedService(settingsComponent.getSelectedService()); var llamaModelPreferencesForm = serviceSelectionForm.getLlamaModelPreferencesForm(); llamaSettings.setCustomLlamaModelPath(llamaModelPreferencesForm.getCustomLlamaModelPath()); @@ -116,12 +104,14 @@ public class SettingsConfigurable implements Configurable { openAISettings.apply(serviceSelectionForm); azureSettings.apply(serviceSelectionForm); + YouSettingsState.getInstance() + .setDisplayWebSearchResults(serviceSelectionForm.isDisplayWebSearchResults()); if (serviceChanged || modelChanged) { resetActiveTab(); if (serviceChanged) { TelemetryAction.SETTINGS_CHANGED.createActionMessage() - .property("service", getServiceCode()) + .property("service", settingsComponent.getSelectedService().getCode().toLowerCase()) .send(); } } @@ -137,20 +127,8 @@ public class SettingsConfigurable implements Configurable { // settingsComponent.setEmail(settings.getEmail()); settingsComponent.setDisplayName(settings.getDisplayName()); + settingsComponent.setSelectedService(settings.getSelectedService()); - // TODO - if (settings.isUseOpenAIService()) { - settingsComponent.setSelectedService(OPENAI); - } - if (settings.isUseAzureService()) { - settingsComponent.setSelectedService(ServiceType.AZURE); - } - if (settings.isUseYouService()) { - settingsComponent.setSelectedService(ServiceType.YOU); - } - if (settings.isUseLlamaService()) { - settingsComponent.setSelectedService(ServiceType.LLAMA_CPP); - } var llamaModelPreferencesForm = serviceSelectionForm.getLlamaModelPreferencesForm(); llamaModelPreferencesForm.setSelectedModel(llamaSettings.getHuggingFaceModel()); llamaModelPreferencesForm.setCustomLlamaModelPath(llamaSettings.getCustomLlamaModelPath()); @@ -174,10 +152,7 @@ public class SettingsConfigurable implements Configurable { } private boolean isServiceChanged(SettingsState settings) { - return (settingsComponent.getSelectedService() == OPENAI) != settings.isUseOpenAIService() || - (settingsComponent.getSelectedService() == AZURE) != settings.isUseAzureService() || - (settingsComponent.getSelectedService() == YOU) != settings.isUseYouService() || - (settingsComponent.getSelectedService() == LLAMA_CPP) != settings.isUseLlamaService(); + return settingsComponent.getSelectedService() != settings.getSelectedService(); } private void resetActiveTab() { @@ -189,20 +164,4 @@ public class SettingsConfigurable implements Configurable { project.getService(StandardChatToolWindowContentManager.class).resetActiveTab(); } - - private String getServiceCode() { - if (settingsComponent.getSelectedService() == OPENAI) { - return "openai"; - } - if (settingsComponent.getSelectedService() == AZURE) { - return "azure"; - } - if (settingsComponent.getSelectedService() == YOU) { - return "you"; - } - if (settingsComponent.getSelectedService() == LLAMA_CPP) { - return "llama.cpp"; - } - return null; - } } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceType.java b/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceType.java index 764ed33e..ec83a57a 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceType.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceType.java @@ -3,17 +3,19 @@ package ee.carlrobert.codegpt.settings.service; import ee.carlrobert.codegpt.CodeGPTBundle; public enum ServiceType { - OPENAI("OPENAI", CodeGPTBundle.get("service.openai.title")), - AZURE("AZURE", CodeGPTBundle.get("service.azure.title")), - YOU("YOU", CodeGPTBundle.get("service.you.title")), - LLAMA_CPP("LLAMA_CPP", CodeGPTBundle.get("service.llama.title")); + OPENAI("OPENAI", CodeGPTBundle.get("service.openai.title"), "chat.completion"), + AZURE("AZURE", CodeGPTBundle.get("service.azure.title"), "azure.chat.completion"), + YOU("YOU", CodeGPTBundle.get("service.you.title"), "you.chat.completion"), + LLAMA_CPP("LLAMA_CPP", CodeGPTBundle.get("service.llama.title"), "llama.chat.completion"); private final String code; private final String label; + private final String completionCode; - ServiceType(String code, String label) { + ServiceType(String code, String label, String completionCode) { this.code = code; this.label = label; + this.completionCode = completionCode; } public String getCode() { @@ -24,6 +26,10 @@ public enum ServiceType { return label; } + public String getCompletionCode() { + return completionCode; + } + @Override public String toString() { return label; diff --git a/src/main/java/ee/carlrobert/codegpt/settings/state/SettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/state/SettingsState.java index feb1183a..6fa4d4c4 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/state/SettingsState.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/state/SettingsState.java @@ -7,6 +7,7 @@ import com.intellij.openapi.components.Storage; import com.intellij.util.xmlb.XmlSerializerUtil; import ee.carlrobert.codegpt.completions.HuggingFaceModel; import ee.carlrobert.codegpt.conversations.Conversation; +import ee.carlrobert.codegpt.settings.service.ServiceType; import org.jetbrains.annotations.NotNull; @State(name = "CodeGPT_GeneralSettings_210", storages = @Storage("CodeGPT_GeneralSettings_210.xml")) @@ -15,10 +16,7 @@ public class SettingsState implements PersistentStateComponent { private String email = ""; private String displayName = ""; private boolean previouslySignedIn; - private boolean useOpenAIService = true; - private boolean useAzureService; - private boolean useYouService; - private boolean useLlamaService; + private ServiceType selectedService = ServiceType.OPENAI; public SettingsState() { } @@ -40,20 +38,28 @@ public class SettingsState implements PersistentStateComponent { public void sync(Conversation conversation) { var clientCode = conversation.getClientCode(); if ("chat.completion".equals(clientCode)) { + setSelectedService(ServiceType.OPENAI); OpenAISettingsState.getInstance().setModel(conversation.getModel()); } if ("azure.chat.completion".equals(clientCode)) { + setSelectedService(ServiceType.AZURE); AzureSettingsState.getInstance().setModel(conversation.getModel()); } if ("llama.chat.completion".equals(clientCode)) { - LlamaSettingsState.getInstance().setHuggingFaceModel( - HuggingFaceModel.valueOf(conversation.getModel())); + setSelectedService(ServiceType.LLAMA_CPP); + var llamaSettings = LlamaSettingsState.getInstance(); + try { + var huggingFaceModel = HuggingFaceModel.valueOf(conversation.getModel()); + llamaSettings.setHuggingFaceModel(huggingFaceModel); + llamaSettings.setUseCustomModel(false); + } catch (IllegalArgumentException ignore) { + llamaSettings.setCustomLlamaModelPath(conversation.getModel()); + llamaSettings.setUseCustomModel(true); + } + } + if ("you.chat.completion".equals(clientCode)) { + setSelectedService(ServiceType.YOU); } - - setUseOpenAIService("chat.completion".equals(clientCode)); - setUseAzureService("azure.chat.completion".equals(clientCode)); - setUseYouService("you.chat.completion".equals(clientCode)); - setUseLlamaService("llama.chat.completion".equals(clientCode)); } public String getEmail() { @@ -87,35 +93,11 @@ public class SettingsState implements PersistentStateComponent { this.previouslySignedIn = previouslySignedIn; } - public boolean isUseOpenAIService() { - return useOpenAIService; + public ServiceType getSelectedService() { + return selectedService; } - public void setUseOpenAIService(boolean useOpenAIService) { - this.useOpenAIService = useOpenAIService; - } - - public boolean isUseAzureService() { - return useAzureService; - } - - public void setUseAzureService(boolean useAzureService) { - this.useAzureService = useAzureService; - } - - public boolean isUseYouService() { - return useYouService; - } - - public void setUseYouService(boolean useYouService) { - this.useYouService = useYouService; - } - - public boolean isUseLlamaService() { - return useLlamaService; - } - - public void setUseLlamaService(boolean useLlamaService) { - this.useLlamaService = useLlamaService; + public void setSelectedService(ServiceType selectedService) { + this.selectedService = selectedService; } } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/BaseChatToolWindowTabPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/BaseChatToolWindowTabPanel.java index 1609218a..ec96a670 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/BaseChatToolWindowTabPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/BaseChatToolWindowTabPanel.java @@ -29,6 +29,7 @@ import ee.carlrobert.codegpt.conversations.ConversationService; import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; +import ee.carlrobert.codegpt.settings.service.ServiceType; import ee.carlrobert.codegpt.settings.state.AzureSettingsState; import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; @@ -115,7 +116,7 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan scrollablePanel.removeAll(); scrollablePanel.add(getLandingView()); var youUserManager = YouUserManager.getInstance(); - if (SettingsState.getInstance().isUseYouService() && + if (SettingsState.getInstance().getSelectedService() == ServiceType.YOU && (!youUserManager.isAuthenticated() || !youUserManager.isSubscribed())) { scrollablePanel.add(new ResponsePanel().addContent(createTextPane())); } @@ -171,10 +172,10 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan private boolean isRequestAllowed() { var settings = SettingsState.getInstance(); - if (settings.isUseAzureService()) { + if (settings.getSelectedService() == ServiceType.AZURE) { return AzureCredentialsManager.getInstance().isCredentialSet(); } - if (settings.isUseOpenAIService()) { + if (settings.getSelectedService() == ServiceType.OPENAI) { return OpenAICredentialsManager.getInstance().isApiKeySet(); } return true; @@ -244,7 +245,7 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan requestHandler.addErrorListener((error, ex) -> { try { if ("insufficient_quota".equals(error.getCode())) { - if (SettingsState.getInstance().isUseOpenAIService()) { + if (SettingsState.getInstance().getSelectedService() == ServiceType.OPENAI) { OpenAISettingsState.getInstance().setOpenAIQuotaExceeded(true); } responseContainer.displayQuotaExceeded(); @@ -377,7 +378,11 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan var model = getModel(); var modelIconWrapper = JBUI.Panels - .simplePanel(new ModelIconLabel(getClientCode(), model)) + .simplePanel(new ModelIconLabel( + SettingsState.getInstance() + .getSelectedService() + .getCompletionCode(), + model)) .withBorder(Borders.emptyRight(4)) .withBackground(getPanelBackgroundColor()); @@ -388,20 +393,18 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan wrapper.setBackground(getPanelBackgroundColor()); wrapper.add(userPromptTextArea, BorderLayout.SOUTH); - if (model != null) { - var header = new JPanel(new BorderLayout()); - header.setBackground(getPanelBackgroundColor()); - header.setBorder(JBUI.Borders.emptyBottom(8)); - if ("YouCode".equals(model)) { - var messageBusConnection = ApplicationManager.getApplication().getMessageBus().connect(); - subscribeToYouModelChangeTopic(); - subscribeToYouSubscriptionTopic(messageBusConnection); - subscribeToSignedOutTopic(messageBusConnection); - header.add(gpt4CheckBox, BorderLayout.LINE_START); - } - header.add(modelIconWrapper, BorderLayout.LINE_END); - wrapper.add(header); + var header = new JPanel(new BorderLayout()); + header.setBackground(getPanelBackgroundColor()); + header.setBorder(JBUI.Borders.emptyBottom(8)); + if ("YouCode".equals(model)) { + var messageBusConnection = ApplicationManager.getApplication().getMessageBus().connect(); + subscribeToYouModelChangeTopic(); + subscribeToYouSubscriptionTopic(messageBusConnection); + subscribeToSignedOutTopic(messageBusConnection); + header.add(gpt4CheckBox, BorderLayout.LINE_START); } + header.add(modelIconWrapper, BorderLayout.LINE_END); + wrapper.add(header); rootPanel.add(wrapper, gbc); userPromptTextArea.requestFocusInWindow(); @@ -459,35 +462,18 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan return CodeGPTBundle.get("toolwindow.chat.youProCheckBox.notAllowed"); } - private String getClientCode() { + private String getModel() { var settings = SettingsState.getInstance(); - if (settings.isUseOpenAIService()) { - return "chat.completion"; - } - if (settings.isUseAzureService()) { - return "azure.chat.completion"; - } - if (settings.isUseYouService()) { - return "you.chat.completion"; - } - if (settings.isUseLlamaService()) { - return "llama.chat.completion"; - } - return null; - } - - private @Nullable String getModel() { - var settings = SettingsState.getInstance(); - if (settings.isUseOpenAIService()) { + if (settings.getSelectedService() == ServiceType.OPENAI) { return OpenAISettingsState.getInstance().getModel(); } - if (settings.isUseAzureService()) { + if (settings.getSelectedService() == ServiceType.AZURE) { return AzureSettingsState.getInstance().getModel(); } - if (settings.isUseYouService()) { + if (settings.getSelectedService() == ServiceType.YOU) { return "YouCode"; } - if (settings.isUseLlamaService()) { + if (settings.getSelectedService() == ServiceType.LLAMA_CPP) { var llamaSettings = LlamaSettingsState.getInstance(); if (llamaSettings.isUseCustomModel()) { var filePath = llamaSettings.getCustomLlamaModelPath(); diff --git a/src/test/java/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.java b/src/test/java/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.java index e08bca5c..dbfbaddb 100644 --- a/src/test/java/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.java +++ b/src/test/java/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.java @@ -14,6 +14,7 @@ import ee.carlrobert.codegpt.conversations.ConversationService; import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; import ee.carlrobert.codegpt.settings.configuration.ConfigurationState; +import ee.carlrobert.codegpt.settings.service.ServiceType; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; import ee.carlrobert.codegpt.settings.state.SettingsState; import ee.carlrobert.llm.client.http.LocalCallbackServer; @@ -144,9 +145,7 @@ public class CompletionRequestProviderTest extends BasePlatformTestCase { public void testContextualSearch() { var conversation = ConversationService.getInstance().startConversation(); - var settings = SettingsState.getInstance(); - settings.setUseOpenAIService(true); - settings.setUseAzureService(false); + SettingsState.getInstance().setSelectedService(ServiceType.OPENAI); expectRequest("/v1/chat/completions", request -> { assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY"); diff --git a/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java b/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java index d7427079..7fc79608 100644 --- a/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java +++ b/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java @@ -18,6 +18,7 @@ import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; import ee.carlrobert.codegpt.settings.configuration.ConfigurationState; +import ee.carlrobert.codegpt.settings.service.ServiceType; import ee.carlrobert.codegpt.settings.state.AzureSettingsState; import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; @@ -59,11 +60,7 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { var conversation = ConversationService.getInstance().startConversation(); var requestHandler = new CompletionRequestHandler(); requestHandler.addRequestCompletedListener(message::setResponse); - var settings = SettingsState.getInstance(); - settings.setUseOpenAIService(true); - settings.setUseAzureService(false); - settings.setUseYouService(false); - settings.setUseLlamaService(false); + SettingsState.getInstance().setSelectedService(ServiceType.OPENAI); expectStreamRequest("/v1/chat/completions", request -> { assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY"); @@ -91,11 +88,7 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { public void testAzureChatCompletionCall() { AzureSettingsState.getInstance().setModel(OpenAIChatCompletionModel.GPT_3_5.getCode()); - var settings = SettingsState.getInstance(); - settings.setUseOpenAIService(false); - settings.setUseAzureService(true); - settings.setUseYouService(false); - settings.setUseLlamaService(false); + SettingsState.getInstance().setSelectedService(ServiceType.AZURE); var azureSettings = AzureSettingsState.getInstance(); azureSettings.setResourceName("TEST_RESOURCE_NAME"); azureSettings.setApiVersion("TEST_API_VERSION"); @@ -141,11 +134,7 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { conversation.addMessage(new Message("Ping", "Pong")); var requestHandler = new CompletionRequestHandler(); requestHandler.addRequestCompletedListener(message::setResponse); - var settings = SettingsState.getInstance(); - settings.setUseOpenAIService(false); - settings.setUseAzureService(false); - settings.setUseYouService(true); - settings.setUseLlamaService(false); + SettingsState.getInstance().setSelectedService(ServiceType.YOU); expectStreamRequest("/api/streamingSearch", request -> { assertThat(request.getMethod()).isEqualTo("GET"); assertThat(request.getUri().getPath()).isEqualTo("/api/streamingSearch"); @@ -195,11 +184,7 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { conversation.addMessage(new Message("Ping", "Pong")); var requestHandler = new CompletionRequestHandler(); requestHandler.addRequestCompletedListener(message::setResponse); - var settings = SettingsState.getInstance(); - settings.setUseOpenAIService(false); - settings.setUseAzureService(false); - settings.setUseYouService(false); - settings.setUseLlamaService(true); + SettingsState.getInstance().setSelectedService(ServiceType.LLAMA_CPP); expectStreamRequest("/completion", request -> { assertThat(request.getBody()) .extracting( diff --git a/src/test/java/ee/carlrobert/codegpt/conversations/ConversationsStateTest.java b/src/test/java/ee/carlrobert/codegpt/conversations/ConversationsStateTest.java index 3f1c1fb0..8b8717c4 100644 --- a/src/test/java/ee/carlrobert/codegpt/conversations/ConversationsStateTest.java +++ b/src/test/java/ee/carlrobert/codegpt/conversations/ConversationsStateTest.java @@ -4,6 +4,7 @@ import static org.assertj.core.api.Assertions.assertThat; import com.intellij.testFramework.fixtures.BasePlatformTestCase; import ee.carlrobert.codegpt.conversations.message.Message; +import ee.carlrobert.codegpt.settings.service.ServiceType; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; import ee.carlrobert.codegpt.settings.state.SettingsState; import ee.carlrobert.llm.client.openai.completion.chat.OpenAIChatCompletionModel; @@ -12,9 +13,7 @@ public class ConversationsStateTest extends BasePlatformTestCase { public void testStartNewDefaultConversation() { var settings = SettingsState.getInstance(); - settings.setUseOpenAIService(true); - settings.setUseAzureService(false); - settings.setUseYouService(false); + settings.setSelectedService(ServiceType.OPENAI); OpenAISettingsState.getInstance().setModel(OpenAIChatCompletionModel.GPT_3_5.getCode()); var conversation = ConversationService.getInstance().startConversation(); diff --git a/src/test/java/ee/carlrobert/codegpt/settings/state/SettingsStateTest.java b/src/test/java/ee/carlrobert/codegpt/settings/state/SettingsStateTest.java index cb2a3c90..0273bfeb 100644 --- a/src/test/java/ee/carlrobert/codegpt/settings/state/SettingsStateTest.java +++ b/src/test/java/ee/carlrobert/codegpt/settings/state/SettingsStateTest.java @@ -1,10 +1,13 @@ package ee.carlrobert.codegpt.settings.state; +import static ee.carlrobert.codegpt.completions.HuggingFaceModel.CODE_LLAMA_7B_Q3; import static org.assertj.core.api.Assertions.assertThat; import com.intellij.testFramework.fixtures.BasePlatformTestCase; +import ee.carlrobert.codegpt.completions.HuggingFaceModel; import ee.carlrobert.codegpt.conversations.Conversation; +import ee.carlrobert.codegpt.settings.service.ServiceType; public class SettingsStateTest extends BasePlatformTestCase { @@ -18,9 +21,7 @@ public class SettingsStateTest extends BasePlatformTestCase { settings.sync(conversation); - assertThat(settings) - .extracting("useOpenAIService", "useAzureService", "useYouService") - .containsExactly(true, false, false); + assertThat(settings.getSelectedService()).isEqualTo(ServiceType.OPENAI); assertThat(openAISettings.getModel()).isEqualTo("gpt-4"); } @@ -34,9 +35,7 @@ public class SettingsStateTest extends BasePlatformTestCase { settings.sync(conversation); - assertThat(settings) - .extracting("useOpenAIService", "useAzureService", "useYouService") - .containsExactly(false, true, false); + assertThat(settings.getSelectedService()).isEqualTo(ServiceType.AZURE); assertThat(azureSettings.getModel()).isEqualTo("gpt-4"); } @@ -48,8 +47,36 @@ public class SettingsStateTest extends BasePlatformTestCase { settings.sync(conversation); - assertThat(settings) - .extracting("useOpenAIService", "useAzureService", "useYouService") - .containsExactly(false, false, true); + assertThat(settings.getSelectedService()).isEqualTo(ServiceType.YOU); + } + + public void testLlamaSettingsModelPathSync() { + var settings = SettingsState.getInstance(); + var llamaSettings = LlamaSettingsState.getInstance(); + llamaSettings.setHuggingFaceModel(HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3); + var conversation = new Conversation(); + conversation.setModel("TEST_LLAMA_MODEL_PATH"); + conversation.setClientCode("llama.chat.completion"); + + settings.sync(conversation); + + assertThat(settings.getSelectedService()).isEqualTo(ServiceType.LLAMA_CPP); + assertThat(llamaSettings.getCustomLlamaModelPath()).isEqualTo("TEST_LLAMA_MODEL_PATH"); + assertThat(llamaSettings.isUseCustomModel()).isTrue(); + } + + public void testLlamaSettingsHuggingFaceModelSync() { + var settings = SettingsState.getInstance(); + var llamaSettings = LlamaSettingsState.getInstance(); + llamaSettings.setHuggingFaceModel(HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3); + var conversation = new Conversation(); + conversation.setModel("CODE_LLAMA_7B_Q3"); + conversation.setClientCode("llama.chat.completion"); + + settings.sync(conversation); + + assertThat(settings.getSelectedService()).isEqualTo(ServiceType.LLAMA_CPP); + assertThat(llamaSettings.getHuggingFaceModel()).isEqualTo(CODE_LLAMA_7B_Q3); + assertThat(llamaSettings.isUseCustomModel()).isFalse(); } }