From 36ba3616ded862f2f25747784e598dec6dbfb532 Mon Sep 17 00:00:00 2001 From: GenericMale Date: Wed, 11 Sep 2024 12:33:15 +0000 Subject: [PATCH] feat(ui): improve model combobox menu UI/UX by extracting the providers into standalone subgroups (#681) * feat: move models into sub-menus * feat: improve chat model combobox menu UI --------- Co-authored-by: GenericMale Co-authored-by: Carl-Robert Linnupuu --- .../chat/ui/textarea/ModelComboBoxAction.java | 158 +++++++----------- .../service/codegpt/CodeGPTAvailableModels.kt | 46 ++--- .../service/codegpt/CodeGPTServiceForm.kt | 6 +- 3 files changed, 84 insertions(+), 126 deletions(-) diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java index da1a093c..1181c5b6 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java @@ -28,9 +28,11 @@ import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTAvailableModels; import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTModel; import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings; import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings; +import ee.carlrobert.codegpt.settings.service.google.GoogleSettings; import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings; import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; +import ee.carlrobert.llm.client.google.models.GoogleModel; import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel; import java.util.Arrays; import java.util.List; @@ -93,21 +95,24 @@ public class ModelComboBoxAction extends ComboBoxAction { if (availableProviders.contains(CODEGPT)) { actionGroup.addSeparator("CodeGPT"); actionGroup.addAll(getCodeGPTModelActions(project, presentation)); + actionGroup.addSeparator(); } + actionGroup.addSeparator("Cloud Providers"); if (availableProviders.contains(OPENAI)) { - actionGroup.addSeparator("OpenAI"); + var openaiGroup = DefaultActionGroup.createPopupGroup(() -> "OpenAI"); + openaiGroup.getTemplatePresentation().setIcon(Icons.OpenAI); List.of( OpenAIChatCompletionModel.GPT_4_O, OpenAIChatCompletionModel.GPT_4_O_MINI, OpenAIChatCompletionModel.GPT_4_VISION_PREVIEW, OpenAIChatCompletionModel.GPT_4_0125_128k) - .forEach(model -> actionGroup.add(createOpenAIModelAction(model, presentation))); + .forEach(model -> openaiGroup.add(createOpenAIModelAction(model, presentation))); + actionGroup.add(openaiGroup); } if (availableProviders.contains(CUSTOM_OPENAI)) { - actionGroup.addSeparator("Custom OpenAI"); actionGroup.add(createModelAction( CUSTOM_OPENAI, - ApplicationManager.getApplication().getService(CustomServiceSettings.class) + "Custom: " + ApplicationManager.getApplication().getService(CustomServiceSettings.class) .getState() .getTemplate() .getProviderName(), @@ -115,7 +120,6 @@ public class ModelComboBoxAction extends ComboBoxAction { presentation)); } if (availableProviders.contains(ANTHROPIC)) { - actionGroup.addSeparator("Anthropic"); actionGroup.add(createModelAction( ANTHROPIC, "Anthropic (Claude)", @@ -123,20 +127,19 @@ public class ModelComboBoxAction extends ComboBoxAction { presentation)); } if (availableProviders.contains(AZURE)) { - actionGroup.addSeparator("Azure"); actionGroup.add( createModelAction(AZURE, "Azure OpenAI", Icons.Azure, presentation)); } if (availableProviders.contains(GOOGLE)) { - actionGroup.addSeparator("Google"); - actionGroup.add(createModelAction( - GOOGLE, - "Google (Gemini)", - Icons.Google, - presentation)); + var googleGroup = DefaultActionGroup.createPopupGroup(() -> "Google (Gemini)"); + googleGroup.getTemplatePresentation().setIcon(Icons.Google); + Arrays.stream(GoogleModel.values()) + .forEach(model -> + googleGroup.add(createGoogleModelAction(model, presentation))); + actionGroup.add(googleGroup); } if (availableProviders.contains(LLAMA_CPP)) { - actionGroup.addSeparator("LLaMA C/C++"); + actionGroup.addSeparator("Local Providers"); actionGroup.add(createModelAction( LLAMA_CPP, getLlamaCppPresentationText(), @@ -144,14 +147,15 @@ public class ModelComboBoxAction extends ComboBoxAction { presentation)); } if (availableProviders.contains(OLLAMA)) { - actionGroup.addSeparator("Ollama"); - createOllamaModelAction("Default model", presentation); + var ollamaGroup = DefaultActionGroup.createPopupGroup(() -> "Ollama"); + ollamaGroup.getTemplatePresentation().setIcon(Icons.Ollama); ApplicationManager.getApplication() .getService(OllamaSettings.class) .getState() .getAvailableModels() .forEach(model -> - actionGroup.add(createOllamaModelAction(model, presentation))); + ollamaGroup.add(createOllamaModelAction(model, presentation))); + actionGroup.add(ollamaGroup); } return actionGroup; @@ -235,11 +239,20 @@ public class ModelComboBoxAction extends ComboBoxAction { huggingFaceModel.getQuantization()); } + private AnAction createModelAction( + ServiceType serviceType, + String label, + Icon icon, + Presentation comboBoxPresentation) { + return createModelAction(serviceType, label, icon, comboBoxPresentation, null); + } + private AnAction createModelAction( ServiceType serviceType, String label, Icon icon, - Presentation comboBoxPresentation) { + Presentation comboBoxPresentation, + Runnable onModelChanged) { return new DumbAwareAction(label, "", icon) { @Override public void update(@NotNull AnActionEvent event) { @@ -249,7 +262,10 @@ public class ModelComboBoxAction extends ComboBoxAction { @Override public void actionPerformed(@NotNull AnActionEvent e) { - handleModelChange(serviceType, label, label, icon, comboBoxPresentation); + if (onModelChanged != null) { + onModelChanged.run(); + } + handleModelChange(serviceType, label, icon, comboBoxPresentation); } @Override @@ -262,7 +278,6 @@ public class ModelComboBoxAction extends ComboBoxAction { private void handleModelChange( ServiceType serviceType, String label, - String code, Icon icon, Presentation comboBoxPresentation) { GeneralSettings.getCurrentState().setSelectedService(serviceType); @@ -272,93 +287,34 @@ public class ModelComboBoxAction extends ComboBoxAction { } private AnAction createCodeGPTModelAction(CodeGPTModel model, Presentation comboBoxPresentation) { - return new DumbAwareAction(model.getName(), "", model.getIcon()) { - @Override - public void update(@NotNull AnActionEvent event) { - var presentation = event.getPresentation(); - presentation.setEnabled(!presentation.getText().equals(comboBoxPresentation.getText())); - } - - @Override - public void actionPerformed(@NotNull AnActionEvent e) { - ApplicationManager.getApplication().getService(CodeGPTServiceSettings.class) - .getState() - .getChatCompletionSettings() - .setModel(model.getCode()); - handleModelChange( - CODEGPT, - model.getName(), - model.getCode(), - model.getIcon(), - comboBoxPresentation); - } - - @Override - public @NotNull ActionUpdateThread getActionUpdateThread() { - return ActionUpdateThread.BGT; - } - }; + return createModelAction(CODEGPT, model.getName(), model.getIcon(), comboBoxPresentation, + () -> ApplicationManager.getApplication() + .getService(CodeGPTServiceSettings.class) + .getState() + .getChatCompletionSettings() + .setModel(model.getCode())); } - private AnAction createOllamaModelAction( - String model, - Presentation comboBoxPresentation - ) { - return new DumbAwareAction(model, "", Icons.Ollama) { - @Override - public void update(@NotNull AnActionEvent event) { - var presentation = event.getPresentation(); - presentation.setEnabled(!presentation.getText().equals(comboBoxPresentation.getText())); - } - - @Override - public void actionPerformed(@NotNull AnActionEvent e) { - ApplicationManager.getApplication() - .getService(OllamaSettings.class) - .getState() - .setModel(model); - handleModelChange( - OLLAMA, - model, - model, - Icons.Ollama, - comboBoxPresentation); - } - - @Override - public @NotNull ActionUpdateThread getActionUpdateThread() { - return ActionUpdateThread.BGT; - } - }; + private AnAction createOllamaModelAction(String model, Presentation comboBoxPresentation) { + return createModelAction(OLLAMA, model, Icons.Ollama, comboBoxPresentation, + () -> ApplicationManager.getApplication() + .getService(OllamaSettings.class) + .getState() + .setModel(model)); } private AnAction createOpenAIModelAction( - OpenAIChatCompletionModel model, - Presentation comboBoxPresentation) { - createModelAction(OPENAI, model.getDescription(), Icons.OpenAI, - comboBoxPresentation); - return new DumbAwareAction(model.getDescription(), "", Icons.OpenAI) { - @Override - public void update(@NotNull AnActionEvent event) { - var presentation = event.getPresentation(); - presentation.setEnabled(!presentation.getText().equals(comboBoxPresentation.getText())); - } + OpenAIChatCompletionModel model, + Presentation comboBoxPresentation) { + return createModelAction(OPENAI, model.getDescription(), Icons.OpenAI, comboBoxPresentation, + () -> OpenAISettings.getCurrentState().setModel(model.getCode())); + } - @Override - public void actionPerformed(@NotNull AnActionEvent e) { - OpenAISettings.getCurrentState().setModel(model.getCode()); - handleModelChange( - OPENAI, - model.getDescription(), - model.getCode(), - Icons.OpenAI, - comboBoxPresentation); - } - - @Override - public @NotNull ActionUpdateThread getActionUpdateThread() { - return ActionUpdateThread.BGT; - } - }; + private AnAction createGoogleModelAction(GoogleModel model, Presentation comboBoxPresentation) { + return createModelAction(GOOGLE, model.getDescription(), Icons.Google, comboBoxPresentation, + () -> ApplicationManager.getApplication() + .getService(GoogleSettings.class) + .getState() + .setModel(model.getCode())); } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt index afab7bbe..bcbc9310 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt @@ -7,21 +7,20 @@ import javax.swing.Icon object CodeGPTAvailableModels { + val DEFAULT_CHAT_MODEL = CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL) + val DEFAULT_CODE_MODEL = CodeGPTModel("GPT-3.5 Turbo Instruct", "gpt-3.5-turbo-instruct", Icons.OpenAI, INDIVIDUAL) + @JvmStatic fun getToolWindowModels(pricingPlan: PricingPlan?): List { - val anonymousModels = listOf( - CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), - CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL), - CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL), - CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL), - CodeGPTModel("GPT-4o mini - FREE", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS), - CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS) - ) - if (pricingPlan == null) { - return anonymousModels - } return when (pricingPlan) { - ANONYMOUS -> anonymousModels + null, ANONYMOUS -> listOf( + CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), + CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL), + CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL), + CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL), + CodeGPTModel("GPT-4o mini - FREE", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS), + CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS) + ) FREE -> listOf( CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), @@ -32,12 +31,19 @@ object CodeGPTAvailableModels { CodeGPTModel("Code Llama (70B)", "codellama:chat", Icons.Meta, FREE), ) - else -> BASE_CHAT_MODELS + INDIVIDUAL -> listOf( + CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), + CodeGPTModel("Claude 3 Opus", "claude-3-opus", Icons.Anthropic, INDIVIDUAL), + CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL), + CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL), + CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL), + CodeGPTModel("DBRX", "dbrx", Icons.Databricks, INDIVIDUAL), + ) } } @JvmStatic - val BASE_CHAT_MODELS: List = listOf( + val ALL_CHAT_MODELS: List = listOf( CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), CodeGPTModel("GPT-4o mini", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS), CodeGPTModel("Claude 3 Opus", "claude-3-opus", Icons.Anthropic, INDIVIDUAL), @@ -46,10 +52,6 @@ object CodeGPTAvailableModels { CodeGPTModel("Llama 3 (70B)", "llama-3-70b", Icons.Meta, FREE), CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL), CodeGPTModel("DBRX", "dbrx", Icons.Databricks, INDIVIDUAL), - ) - - @JvmStatic - val ALL_CHAT_MODELS: List = BASE_CHAT_MODELS + listOf( CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS), CodeGPTModel("Code Llama (70B)", "codellama:chat", Icons.Meta, FREE), CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.CodeGPTModel, FREE), @@ -58,8 +60,8 @@ object CodeGPTAvailableModels { ) @JvmStatic - val CODE_MODELS: List = listOf( - CodeGPTModel("GPT-3.5 Turbo Instruct", "gpt-3.5-turbo-instruct", Icons.OpenAI, INDIVIDUAL), + val ALL_CODE_MODELS: List = listOf( + DEFAULT_CODE_MODEL, CodeGPTModel("StarCoder (16B)", "starcoder-16b", Icons.CodeGPTModel, FREE), CodeGPTModel("StarCoder (7B) - FREE", "starcoder-7b", Icons.CodeGPTModel, FREE), CodeGPTModel("WizardCoder Python (34B)", "wizardcoder-python", Icons.CodeGPTModel, FREE), @@ -68,7 +70,7 @@ object CodeGPTAvailableModels { @JvmStatic fun findByCode(code: String?): CodeGPTModel? { - return ALL_CHAT_MODELS.union(CODE_MODELS).firstOrNull { it.code == code } + return ALL_CHAT_MODELS.union(ALL_CODE_MODELS).firstOrNull { it.code == code } } } @@ -77,4 +79,4 @@ data class CodeGPTModel( val code: String, val icon: Icon, val individual: PricingPlan -) \ No newline at end of file +) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTServiceForm.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTServiceForm.kt index 07ded254..93d96b59 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTServiceForm.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTServiceForm.kt @@ -28,7 +28,7 @@ class CodeGPTServiceForm { ComboBox(ListComboBoxModel(CodeGPTAvailableModels.ALL_CHAT_MODELS)).apply { val chatModel = service().state.chatCompletionSettings.model selectedItem = CodeGPTAvailableModels.findByCode(chatModel) - ?: CodeGPTAvailableModels.BASE_CHAT_MODELS[0] + ?: CodeGPTAvailableModels.DEFAULT_CHAT_MODEL renderer = CustomComboBoxRenderer() } @@ -38,10 +38,10 @@ class CodeGPTServiceForm { ) private val codeCompletionModelComboBox = - ComboBox(ListComboBoxModel(CodeGPTAvailableModels.CODE_MODELS)).apply { + ComboBox(ListComboBoxModel(CodeGPTAvailableModels.ALL_CODE_MODELS)).apply { val codeModel = service().state.codeCompletionSettings.model selectedItem = CodeGPTAvailableModels.findByCode(codeModel) - ?: CodeGPTAvailableModels.CODE_MODELS[0] + ?: CodeGPTAvailableModels.DEFAULT_CODE_MODEL renderer = CustomComboBoxRenderer() }