From ec3120a5e6403ee6aeb7bbfb819ab3e0174cdd3e Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Tue, 14 Nov 2023 13:27:15 +0200 Subject: [PATCH] Add interactive total token count label, codebase refactoring --- build.gradle.kts | 3 +- .../codegpt.java-conventions.gradle.kts | 2 +- .../carlrobert/codegpt/EncodingManager.java | 17 +- .../java/ee/carlrobert/codegpt/Icons.java | 7 +- .../codegpt/actions/editor/AskAction.java | 7 +- .../CompletionRequestProvider.java | 16 +- .../configuration/ConfigurationComponent.java | 7 + .../configuration/ConfigurationState.java | 4 +- .../settings/state/LlamaSettingsState.java | 6 +- .../chat/BaseChatToolWindowTabPanel.java | 314 +++++++++--------- .../chat/ChatToolWindowScrollablePanel.java | 91 +++++ .../chat/ChatToolWindowTabPanel.java | 4 + .../codegpt/toolwindow/chat/TokenDetails.java | 49 +++ .../components/ChatMessageResponseBody.java | 18 +- .../chat/components/TotalTokensPanel.java | 105 ++++++ .../chat/components/UserPromptTextArea.java | 35 +- .../components/UserPromptTextAreaHeader.java | 77 +++++ .../StandardChatToolWindowTabPanel.java | 9 +- .../StandardChatToolWindowTabbedPane.java | 1 + .../carlrobert/codegpt/util/EditorUtils.java | 9 + .../carlrobert/codegpt/util/SwingUtils.java | 12 + src/main/resources/META-INF/plugin.xml | 1 + src/main/resources/icons/sparkle.svg | 24 ++ .../CompletionRequestProviderTest.java | 94 +++--- .../DefaultCompletionRequestHandlerTest.java | 73 +--- .../settings/state/SettingsStateTest.java | 1 - .../StandardChatToolWindowTabPanelTest.java | 78 +++++ .../StandardChatToolWindowTabbedPaneTest.java | 4 +- .../java/testsupport/IntegrationTest.java | 20 ++ .../testsupport/mixin/ShortcutsTestMixin.java | 38 +++ src/test/resources/application.properties | 0 31 files changed, 804 insertions(+), 322 deletions(-) create mode 100644 src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowScrollablePanel.java create mode 100644 src/main/java/ee/carlrobert/codegpt/toolwindow/chat/TokenDetails.java create mode 100644 src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/TotalTokensPanel.java create mode 100644 src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextAreaHeader.java create mode 100644 src/main/resources/icons/sparkle.svg create mode 100644 src/test/java/ee/carlrobert/codegpt/toolwindow/chat/StandardChatToolWindowTabPanelTest.java create mode 100644 src/test/java/testsupport/IntegrationTest.java create mode 100644 src/test/java/testsupport/mixin/ShortcutsTestMixin.java create mode 100644 src/test/resources/application.properties diff --git a/build.gradle.kts b/build.gradle.kts index 47d76461..29678a54 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -154,9 +154,10 @@ tasks { } test { + exclude("**/testsupport/*") useJUnitPlatform() testLogging { - events("passed", "skipped", "failed") + events("started", "passed", "skipped", "failed") exceptionFormat = TestExceptionFormat.FULL showStandardStreams = true } diff --git a/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts b/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts index 4a2fd5b7..cd0083c4 100644 --- a/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts +++ b/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts @@ -18,7 +18,7 @@ intellij { } dependencies { - implementation("ee.carlrobert:llm-client:0.0.9") + implementation("ee.carlrobert:llm-client:0.0.10") } tasks { diff --git a/src/main/java/ee/carlrobert/codegpt/EncodingManager.java b/src/main/java/ee/carlrobert/codegpt/EncodingManager.java index 0d21525d..c0cddfe7 100644 --- a/src/main/java/ee/carlrobert/codegpt/EncodingManager.java +++ b/src/main/java/ee/carlrobert/codegpt/EncodingManager.java @@ -6,6 +6,7 @@ import com.knuddels.jtokkit.Encodings; import com.knuddels.jtokkit.api.Encoding; import com.knuddels.jtokkit.api.EncodingRegistry; import com.knuddels.jtokkit.api.EncodingType; +import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.llm.client.openai.completion.chat.request.OpenAIChatCompletionMessage; @Service @@ -21,9 +22,23 @@ public final class EncodingManager { return ApplicationManager.getApplication().getService(EncodingManager.class); } + public int countConversationTokens(Conversation conversation) { + if (conversation != null) { + return conversation.getMessages().stream() + .mapToInt( + message -> countTokens(message.getPrompt()) + countTokens(message.getResponse())) + .sum(); + } + return 0; + } + public int countMessageTokens(OpenAIChatCompletionMessage message) { + return countMessageTokens(message.getRole(), message.getContent()); + } + + public int countMessageTokens(String role, String content) { var tokensPerMessage = 4; // every message follows <|start|>{role/name}\n{content}<|end|>\n - return encoding.countTokens(message.getRole() + message.getContent()) + tokensPerMessage; + return countTokens(role + content) + tokensPerMessage; } public int countTokens(String text) { diff --git a/src/main/java/ee/carlrobert/codegpt/Icons.java b/src/main/java/ee/carlrobert/codegpt/Icons.java index 3eef6b2a..6f8e2188 100644 --- a/src/main/java/ee/carlrobert/codegpt/Icons.java +++ b/src/main/java/ee/carlrobert/codegpt/Icons.java @@ -7,9 +7,10 @@ public final class Icons { public static final Icon DefaultIcon = IconLoader.getIcon("/icons/codegpt.svg", Icons.class); public static final Icon DefaultSmallIcon = IconLoader.getIcon("/icons/codegpt-small.svg", Icons.class); - public static final Icon SendIcon = IconLoader.getIcon("/icons/send.svg", Icons.class); - public static final Icon OpenAIIcon = IconLoader.getIcon("/icons/openai.svg", Icons.class); public static final Icon AzureIcon = IconLoader.getIcon("/icons/azure.svg", Icons.class); - public static final Icon YouIcon = IconLoader.getIcon("/icons/you.svg", Icons.class); public static final Icon LlamaIcon = IconLoader.getIcon("/icons/llama.svg", Icons.class); + public static final Icon OpenAIIcon = IconLoader.getIcon("/icons/openai.svg", Icons.class); + public static final Icon SendIcon = IconLoader.getIcon("/icons/send.svg", Icons.class); + public static final Icon SparkleIcon = IconLoader.getIcon("/icons/sparkle.svg", Icons.class); + public static final Icon YouIcon = IconLoader.getIcon("/icons/you.svg", Icons.class); } diff --git a/src/main/java/ee/carlrobert/codegpt/actions/editor/AskAction.java b/src/main/java/ee/carlrobert/codegpt/actions/editor/AskAction.java index 9adb9004..527cc1f9 100644 --- a/src/main/java/ee/carlrobert/codegpt/actions/editor/AskAction.java +++ b/src/main/java/ee/carlrobert/codegpt/actions/editor/AskAction.java @@ -1,8 +1,8 @@ package ee.carlrobert.codegpt.actions.editor; -import com.intellij.icons.AllIcons; import com.intellij.openapi.actionSystem.AnAction; import com.intellij.openapi.actionSystem.AnActionEvent; +import ee.carlrobert.codegpt.Icons; import ee.carlrobert.codegpt.conversations.ConversationsState; import ee.carlrobert.codegpt.toolwindow.chat.standard.StandardChatToolWindowContentManager; import org.jetbrains.annotations.NotNull; @@ -10,7 +10,7 @@ import org.jetbrains.annotations.NotNull; public class AskAction extends AnAction { public AskAction() { - super("New Chat", "Chat with CodeGPT", AllIcons.Actions.Find); + super("New Chat", "Chat with CodeGPT", Icons.SparkleIcon); EditorActionsUtil.registerOrReplaceAction(this); } @@ -24,7 +24,8 @@ public class AskAction extends AnAction { var project = event.getProject(); if (project != null) { ConversationsState.getInstance().setCurrentConversation(null); - var tabPanel = project.getService(StandardChatToolWindowContentManager.class).createNewTabPanel(); + var tabPanel = + project.getService(StandardChatToolWindowContentManager.class).createNewTabPanel(); if (tabPanel != null) { tabPanel.displayLandingView(); } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java index 249bb686..6c2586fa 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java @@ -17,7 +17,6 @@ import ee.carlrobert.codegpt.settings.state.SettingsState; import ee.carlrobert.codegpt.settings.state.YouSettingsState; import ee.carlrobert.codegpt.telemetry.core.configuration.TelemetryConfiguration; import ee.carlrobert.codegpt.telemetry.core.service.UserId; -import ee.carlrobert.codegpt.util.ApplicationUtils; import ee.carlrobert.embedding.EmbeddingsService; import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest; import ee.carlrobert.llm.client.openai.completion.chat.OpenAIChatCompletionModel; @@ -37,13 +36,12 @@ public class CompletionRequestProvider { private static final Logger LOG = Logger.getInstance(CompletionRequestProvider.class); public static final String COMPLETION_SYSTEM_PROMPT = "You are an AI programming assistant.\n" + - "When asked for you name, you must respond with \"CodeGPT\".\n" + "Follow the user's requirements carefully & to the letter.\n" + "Your responses should be informative and logical.\n" + "You should always adhere to technical information.\n" + "If the user asks for code or technical questions, you must provide code suggestions and " + "adhere to technical information.\n" + - "If the question is related to a developer, CodeGPT must respond with " + + "If the question is related to a developer, you must respond with " + "content related to a developer.\n" + "First think step-by-step - describe your plan for what to build in pseudocode, " + "written out in great detail.\n" + @@ -125,8 +123,7 @@ public class CompletionRequestProvider { return (OpenAIChatCompletionRequest) builder.build(); } - private List buildMessages( - @Nullable String model, + public List buildMessages( Message message, boolean isRetry, boolean useContextualSearch) { @@ -149,6 +146,15 @@ public class CompletionRequestProvider { } messages.add(new OpenAIChatCompletionMessage("user", message.getPrompt())); } + return messages; + } + + private List buildMessages( + @Nullable String model, + Message message, + boolean isRetry, + boolean useContextualSearch) { + var messages = buildMessages(message, isRetry, useContextualSearch); if (model == null || SettingsState.getInstance().getSelectedService() == ServiceType.YOU) { return messages; diff --git a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java index 7f0bc16a..f2d9da7f 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java @@ -1,6 +1,7 @@ package ee.carlrobert.codegpt.settings.configuration; import static ee.carlrobert.codegpt.actions.editor.EditorActionsUtil.DEFAULT_ACTIONS_ARRAY; +import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT; import com.intellij.icons.AllIcons; import com.intellij.icons.AllIcons.Nodes; @@ -87,6 +88,12 @@ public class ConfigurationComponent { maxTokensField.setValue(configuration.getMaxTokens()); systemPromptTextArea = new JTextArea(); + if (configuration.getSystemPrompt().isEmpty()) { + // for backward compatibility + systemPromptTextArea.setText(COMPLETION_SYSTEM_PROMPT); + } else { + systemPromptTextArea.setText(configuration.getSystemPrompt()); + } systemPromptTextArea.setLineWrap(true); systemPromptTextArea.setBorder(JBUI.Borders.empty(8, 4)); systemPromptTextArea.setColumns(60); diff --git a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationState.java b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationState.java index f5bee3ea..75e45d97 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationState.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationState.java @@ -1,5 +1,7 @@ package ee.carlrobert.codegpt.settings.configuration; +import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT; + import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.PersistentStateComponent; import com.intellij.openapi.components.State; @@ -13,7 +15,7 @@ import org.jetbrains.annotations.Nullable; @State(name = "CodeGPT_ConfigurationSettings_210", storages = @Storage("CodeGPT_ConfigurationSettings_210.xml")) public class ConfigurationState implements PersistentStateComponent { - private String systemPrompt = ""; + private String systemPrompt = COMPLETION_SYSTEM_PROMPT; private int maxTokens = 1000; private double temperature = 0.2; private boolean createNewChatOnEachAction; diff --git a/src/main/java/ee/carlrobert/codegpt/settings/state/LlamaSettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/state/LlamaSettingsState.java index 5b505329..5abdb9b2 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/state/LlamaSettingsState.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/state/LlamaSettingsState.java @@ -18,7 +18,7 @@ public class LlamaSettingsState implements PersistentStateComponent visibleMessagePanels = new HashMap<>(); private final Map> serpResultsMapping = new HashMap<>(); private final JBCheckBox gpt4CheckBox; - + protected final TotalTokensPanel totalTokensPanel; protected final Project project; protected final UserPromptTextArea userPromptTextArea; protected final ConversationService conversationService; + protected final ChatToolWindowScrollablePanel toolWindowScrollablePanel; + private final EncodingManager encodingManager; + + private boolean streaming; protected @Nullable Conversation conversation; protected abstract JComponent getLandingView(); @@ -86,22 +81,24 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan public BaseChatToolWindowTabPanel(@NotNull Project project, boolean useContextualSearch) { this.project = project; this.useContextualSearch = useContextualSearch; - this.conversationService = ConversationService.getInstance(); - this.scrollablePanel = new ScrollablePanel(new VerticalStackLayout()); - this.userPromptTextArea = new UserPromptTextArea(this::handleSubmit); - this.gpt4CheckBox = new YouProCheckbox(project); - this.settings = SettingsState.getInstance(); - this.youUserManager = YouUserManager.getInstance(); - this.rootPanel = createRootPanel(); + conversationService = ConversationService.getInstance(); + encodingManager = EncodingManager.getInstance(); + settings = SettingsState.getInstance(); + toolWindowScrollablePanel = new ChatToolWindowScrollablePanel(settings); + gpt4CheckBox = new YouProCheckbox(project); + userPromptTextArea = new UserPromptTextArea(this::handleSubmit, getUserPromptDocumentAdapter()); + totalTokensPanel = new TotalTokensPanel( + null, + userPromptTextArea.getText(), + EditorUtils.getSelectedEditorSelectedText(project)); + rootPanel = createRootPanel(); + + addSelectionListeners(); userPromptTextArea.requestFocusInWindow(); userPromptTextArea.requestFocus(); } - public void requestFocusForTextArea() { - userPromptTextArea.focus(); - } - @Override public JPanel getContent() { return rootPanel; @@ -119,13 +116,7 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan @Override public void displayLandingView() { - scrollablePanel.removeAll(); - scrollablePanel.add(getLandingView()); - if (settings.getSelectedService() == ServiceType.YOU && - (!youUserManager.isAuthenticated() || !youUserManager.isSubscribed())) { - scrollablePanel.add(new ResponsePanel().addContent(createYouCouponTextPane())); - } - revalidateScrollablePanel(); + toolWindowScrollablePanel.displayLandingView(getLandingView()); } @Override @@ -136,36 +127,53 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan @Override public void sendMessage(Message message) { + streaming = true; if (conversation == null) { conversation = conversationService.startConversation(); } - var messageWrapper = createNewMessageWrapper(message.getId()); - messageWrapper.add(new UserMessagePanel(project, message, this)); + var messagePanel = toolWindowScrollablePanel.addMessage(message.getId()); + messagePanel.add(new UserMessagePanel(project, message, this)); var responsePanel = new ResponsePanel() .withReloadAction(() -> reloadMessage(message, conversation)) - .withDeleteAction(() -> removeMessage(message.getId(), messageWrapper, conversation)) + .withDeleteAction(() -> removeMessage(message.getId(), conversation)) .addContent(new ChatMessageResponseBody(project, true, this)); - messageWrapper.add(responsePanel); + messagePanel.add(responsePanel); + + totalTokensPanel.updateUserPromptTokens(message.getPrompt()); call(conversation, message, responsePanel, false); } + @Override + public TokenDetails getTokenDetails() { + return totalTokensPanel.getTokenDetails(); + } + @Override public void dispose() { } + public void requestFocusForTextArea() { + userPromptTextArea.focus(); + } + + public void updateConversationTokens() { + totalTokensPanel.updateConversationTokens(conversation); + } + + public boolean isStreaming() { + return streaming; + } + protected void reloadMessage(Message message, Conversation conversation) { ResponsePanel responsePanel = null; try { - responsePanel = (ResponsePanel) Arrays.stream( - visibleMessagePanels.get(message.getId()).getComponents()) - .filter(component -> component instanceof ResponsePanel) - .findFirst().orElseThrow(); + responsePanel = toolWindowScrollablePanel.getMessageResponsePanel(message.getId()); ((ChatMessageResponseBody) responsePanel.getContent()).clear(); - revalidateScrollablePanel(); + toolWindowScrollablePanel.update(); } catch (Exception e) { - throw new RuntimeException("Couldn't delete the existing message component", e); + throw new RuntimeException("Could not delete the existing message component", e); } finally { LOG.debug("Reloading message: " + message.getId()); @@ -175,22 +183,16 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan call(conversation, message, responsePanel, true); } + totalTokensPanel.updateConversationTokens(conversation); + TelemetryAction.IDE_ACTION.createActionMessage() .property("action", ActionType.RELOAD_MESSAGE.name()) .send(); } } - private void revalidateScrollablePanel() { - scrollablePanel.repaint(); - scrollablePanel.revalidate(); - } - - protected void removeMessage(UUID messageId, JPanel messageWrapper, Conversation conversation) { - scrollablePanel.remove(messageWrapper); - revalidateScrollablePanel(); - - visibleMessagePanels.remove(messageId); + protected void removeMessage(UUID messageId, Conversation conversation) { + toolWindowScrollablePanel.removeMessage(messageId); conversation.removeMessage(messageId); conversationService.saveConversation(conversation); @@ -201,18 +203,10 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan } } - protected JPanel createNewMessageWrapper(UUID messageId) { - var messageWrapper = new JPanel(); - messageWrapper.setLayout(new BoxLayout(messageWrapper, BoxLayout.PAGE_AXIS)); - scrollablePanel.add(messageWrapper); - revalidateScrollablePanel(); - visibleMessagePanels.put(messageId, messageWrapper); - return messageWrapper; - } - protected void clearWindow() { - scrollablePanel.removeAll(); - revalidateScrollablePanel(); + toolWindowScrollablePanel.clearAll(); + totalTokensPanel.updateConversationTokens(conversation); + updateConversationTokens(); } private void call( @@ -257,11 +251,17 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan } } - private static JScrollPane createScrollPane(ScrollablePanel scrollablePanel) { - var scrollPane = ScrollPaneFactory.createScrollPane(scrollablePanel, true); - scrollPane.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER); - new SmartScroller(scrollPane); - return scrollPane; + private JPanel createUserPromptPanel() { + var panel = new JPanel(new BorderLayout()); + panel.setBorder(JBUI.Borders.compound( + JBUI.Borders.customLine(JBColor.border(), 1, 0, 0, 0), + JBUI.Borders.empty(8))); + panel.setBackground(getPanelBackgroundColor()); + panel.add( + new UserPromptTextAreaHeader(project, settings, totalTokensPanel, gpt4CheckBox), + BorderLayout.NORTH); + panel.add(userPromptTextArea, BorderLayout.SOUTH); + return panel; } private JPanel createRootPanel() { @@ -272,88 +272,20 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan gbc.weightx = 1; gbc.gridx = 0; gbc.gridy = 0; - rootPanel.add(createScrollPane(scrollablePanel), gbc); - - var wrapper = new JPanel(new BorderLayout()); - wrapper.setBorder(JBUI.Borders.compound( - JBUI.Borders.customLine(JBColor.border(), 1, 0, 0, 0), - JBUI.Borders.empty(8))); - wrapper.setBackground(getPanelBackgroundColor()); - wrapper.add(createPromptTextAreaHeader(), BorderLayout.NORTH); - wrapper.add(userPromptTextArea, BorderLayout.SOUTH); + rootPanel.add(createScrollPaneWithSmartScroller(toolWindowScrollablePanel), gbc); gbc.weighty = 0; gbc.fill = GridBagConstraints.HORIZONTAL; gbc.gridy = 1; - rootPanel.add(wrapper, gbc); - + rootPanel.add(createUserPromptPanel(), gbc); return rootPanel; } - private JPanel createPromptTextAreaHeader() { - var header = new JPanel(new BorderLayout()); - header.setBackground(getPanelBackgroundColor()); - header.setBorder(JBUI.Borders.emptyBottom(8)); - var model = settings.getModel(); - if ("YouCode".equals(model)) { - var messageBusConnection = ApplicationManager.getApplication().getMessageBus().connect(); - subscribeToYouModelChangeTopic(); - subscribeToYouSubscriptionTopic(messageBusConnection); - subscribeToSignedOutTopic(messageBusConnection); - header.add(gpt4CheckBox, BorderLayout.LINE_START); - } - header.add(JBUI.Panels - .simplePanel( - new ModelIconLabel(settings.getSelectedService().getCompletionCode(), - model)) - .withBorder(Borders.emptyRight(4)) - .withBackground(getPanelBackgroundColor()), BorderLayout.LINE_END); - return header; - } - - private void subscribeToYouModelChangeTopic() { - project.getMessageBus() - .connect() - .subscribe( - YouModelChangeNotifier.YOU_MODEL_CHANGE_NOTIFIER_TOPIC, - (YouModelChangeNotifier) gpt4CheckBox::setSelected); - } - - private void subscribeToSignedOutTopic(MessageBusConnection messageBusConnection) { - messageBusConnection.subscribe( - SignedOutNotifier.SIGNED_OUT_TOPIC, - (SignedOutNotifier) () -> gpt4CheckBox.setEnabled(false)); - } - - private void subscribeToYouSubscriptionTopic(MessageBusConnection messageBusConnection) { - messageBusConnection.subscribe( - YouSubscriptionNotifier.SUBSCRIPTION_TOPIC, - (YouSubscriptionNotifier) () -> { - displayLandingView(); - gpt4CheckBox.setEnabled(true); - }); - } - - private JTextPane createYouCouponTextPane() { - var textPane = SwingUtils.createTextPane( - "\n" - + "\n" - + "

Use CodeGPT coupon for free month of GPT-4.

\n" - + "

\n" - + " Sign up here\n" - + "

\n" - + "\n" - + "" - ); - textPane.setBackground(getPanelBackgroundColor()); - textPane.setFocusable(false); - return textPane; - } - private class ChatToolWindowCompletionEventListener implements ToolWindowCompletionEventListener { private final Logger LOG = Logger.getInstance(ChatToolWindowCompletionEventListener.class); + private final StringBuilder messageBuilder = new StringBuilder(); private final ResponsePanel responsePanel; private final ChatMessageResponseBody responseContainer; @@ -363,11 +295,18 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan } @Override - public void handleMessage(String message) { + public void handleMessage(String partialMessage) { try { - LOG.debug(message); + LOG.debug(partialMessage); ApplicationManager.getApplication() - .invokeLater(() -> responseContainer.update(message)); + .invokeLater(() -> { + responseContainer.update(partialMessage); + messageBuilder.append(partialMessage); + + var ongoingTokens = encodingManager.countTokens(messageBuilder.toString()); + totalTokensPanel.update( + totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens); + }); } catch (Exception e) { responseContainer.displayDefaultError(); throw new RuntimeException("Error while updating the content", e); @@ -414,18 +353,24 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan Message message, Conversation conversation, boolean isRetry) { - responsePanel.enableActions(); - conversationService.saveMessage(fullMessage, message, conversation, isRetry); - stopStreaming(responseContainer); + try { + responsePanel.enableActions(); + conversationService.saveMessage(fullMessage, message, conversation, isRetry); - var serpResults = serpResultsMapping.get(message.getId()); - var containsResults = serpResults != null && !serpResults.isEmpty(); - if (YouSettingsState.getInstance().isDisplayWebSearchResults() && containsResults) { - responseContainer.displaySerpResults(serpResults); - } + var serpResults = serpResultsMapping.get(message.getId()); + var containsResults = serpResults != null && !serpResults.isEmpty(); + if (YouSettingsState.getInstance().isDisplayWebSearchResults() && containsResults) { + responseContainer.displaySerpResults(serpResults); + } - if (containsResults) { - message.setSerpResults(serpResults); + if (containsResults) { + message.setSerpResults(serpResults); + } + + totalTokensPanel.updateUserPromptTokens(userPromptTextArea.getText()); + totalTokensPanel.updateConversationTokens(conversation); + } finally { + stopStreaming(responseContainer); } } @@ -435,10 +380,51 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan } private void stopStreaming(ChatMessageResponseBody responseContainer) { - SwingUtilities.invokeLater(() -> { - userPromptTextArea.setSubmitEnabled(true); - responseContainer.hideCarets(); - }); + streaming = false; + userPromptTextArea.setSubmitEnabled(true); + responseContainer.hideCarets(); } } + + private void addSelectionListeners() { + var editorFactory = EditorFactory.getInstance(); + for (var editor : editorFactory.getAllEditors()) { + editor.getSelectionModel().addSelectionListener(getSelectionListener()); + } + editorFactory.addEditorFactoryListener(new EditorFactoryListener() { + @Override + public void editorCreated(@NotNull EditorFactoryEvent event) { + event.getEditor().getSelectionModel().addSelectionListener(getSelectionListener()); + } + }, this); + } + + private SelectionListener getSelectionListener() { + return new SelectionListener() { + @Override + public void selectionChanged(@NotNull SelectionEvent e) { + var selectedText = e.getEditor().getDocument().getText(e.getNewRange()); + totalTokensPanel.updateHighlightedTokens(selectedText); + } + }; + } + + private DocumentAdapter getUserPromptDocumentAdapter() { + return new DocumentAdapter() { + @Override + protected void textChanged(@NotNull DocumentEvent event) { + try { + if (!streaming) { + var document = event.getDocument(); + var text = document.getText( + document.getStartPosition().getOffset(), + document.getEndPosition().getOffset() - 1); + totalTokensPanel.updateUserPromptTokens(text); + } + } catch (BadLocationException ex) { + LOG.error("Something went wrong while processing user input tokens", ex); + } + } + }; + } } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowScrollablePanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowScrollablePanel.java new file mode 100644 index 00000000..41c04fd8 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowScrollablePanel.java @@ -0,0 +1,91 @@ +package ee.carlrobert.codegpt.toolwindow.chat; + +import static ee.carlrobert.codegpt.util.ThemeUtils.getPanelBackgroundColor; + +import com.intellij.openapi.roots.ui.componentsList.components.ScrollablePanel; +import com.intellij.openapi.roots.ui.componentsList.layout.VerticalStackLayout; +import ee.carlrobert.codegpt.completions.you.YouUserManager; +import ee.carlrobert.codegpt.settings.service.ServiceType; +import ee.carlrobert.codegpt.settings.state.SettingsState; +import ee.carlrobert.codegpt.toolwindow.chat.components.ResponsePanel; +import ee.carlrobert.codegpt.util.SwingUtils; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; +import javax.swing.BoxLayout; +import javax.swing.JComponent; +import javax.swing.JPanel; +import javax.swing.JTextPane; + +public class ChatToolWindowScrollablePanel extends ScrollablePanel { + + private final SettingsState settings; + private final YouUserManager youUserManager; + private final Map visibleMessagePanels; + + ChatToolWindowScrollablePanel(SettingsState settings) { + super(new VerticalStackLayout()); + this.settings = settings; + this.youUserManager = YouUserManager.getInstance(); + this.visibleMessagePanels = new HashMap<>(); + } + + public void displayLandingView(JComponent landingView) { + removeAll(); + add(landingView); + if (settings.getSelectedService() == ServiceType.YOU && + (!youUserManager.isAuthenticated() || !youUserManager.isSubscribed())) { + add(new ResponsePanel().addContent(createYouCouponTextPane())); + } + update(); + } + + public ResponsePanel getMessageResponsePanel(UUID messageId) { + return (ResponsePanel) Arrays.stream(visibleMessagePanels.get(messageId).getComponents()) + .filter(component -> component instanceof ResponsePanel) + .findFirst().orElseThrow(); + } + + public JPanel addMessage(UUID messageId) { + var messageWrapper = new JPanel(); + messageWrapper.setLayout(new BoxLayout(messageWrapper, BoxLayout.PAGE_AXIS)); + add(messageWrapper); + visibleMessagePanels.put(messageId, messageWrapper); + return messageWrapper; + } + + public void removeMessage(UUID messageId) { + remove(visibleMessagePanels.get(messageId)); + update(); + visibleMessagePanels.remove(messageId); + } + + public void clearAll() { + visibleMessagePanels.clear(); + removeAll(); + update(); + } + + public void update() { + repaint(); + revalidate(); + } + + // TODO: Move + private JTextPane createYouCouponTextPane() { + var textPane = SwingUtils.createTextPane( + "\n" + + "\n" + + "

Use CodeGPT coupon for free month of GPT-4.

\n" + + "

\n" + + " Sign up here\n" + + "

\n" + + "\n" + + "" + ); + textPane.setBackground(getPanelBackgroundColor()); + textPane.setFocusable(false); + return textPane; + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java index a7b1d794..e9a21eff 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java @@ -12,6 +12,10 @@ public interface ChatToolWindowTabPanel extends Disposable { @Nullable Conversation getConversation(); + TokenDetails getTokenDetails(); + + boolean isStreaming(); + void setConversation(@Nullable Conversation conversation); void displayLandingView(); diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/TokenDetails.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/TokenDetails.java new file mode 100644 index 00000000..1d2ad9ad --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/TokenDetails.java @@ -0,0 +1,49 @@ +package ee.carlrobert.codegpt.toolwindow.chat; + +import ee.carlrobert.codegpt.EncodingManager; +import ee.carlrobert.codegpt.settings.configuration.ConfigurationState; + +public class TokenDetails { + + private final int systemPromptTokens; + private int conversationTokens; + private int userPromptTokens; + private int highlightedTokens; + + public TokenDetails(EncodingManager encodingManager) { + systemPromptTokens = encodingManager.countTokens( + ConfigurationState.getInstance().getSystemPrompt()); + } + + public int getSystemPromptTokens() { + return systemPromptTokens; + } + + public void setConversationTokens(int conversationTokens) { + this.conversationTokens = conversationTokens; + } + + public int getConversationTokens() { + return conversationTokens; + } + + public void setUserPromptTokens(int userPromptTokens) { + this.userPromptTokens = userPromptTokens; + } + + public int getUserPromptTokens() { + return userPromptTokens; + } + + public void setHighlightedTokens(int highlightedTokens) { + this.highlightedTokens = highlightedTokens; + } + + public int getHighlightedTokens() { + return highlightedTokens; + } + + public int getTotal() { + return systemPromptTokens + conversationTokens + userPromptTokens + highlightedTokens; + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ChatMessageResponseBody.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ChatMessageResponseBody.java index 30e347d0..b7391c94 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ChatMessageResponseBody.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ChatMessageResponseBody.java @@ -142,13 +142,17 @@ public class ChatMessageResponseBody extends JPanel { } public void hideCarets() { - if (currentlyProcessedEditor != null) { - ((EditorEx) currentlyProcessedEditor.getEditor()).setCaretVisible(false); - ((EditorEx) currentlyProcessedEditor.getEditor()).setCaretEnabled(false); - } - if (currentlyProcessedTextPane != null && currentlyProcessedTextPane.getCaret().isVisible()) { - currentlyProcessedTextPane.getCaret().setVisible(false); - } + ApplicationManager.getApplication().invokeLater(() -> + ApplicationManager.getApplication().runWriteAction(() -> { + if (currentlyProcessedEditor != null) { + ((EditorEx) currentlyProcessedEditor.getEditor()).setCaretVisible(false); + ((EditorEx) currentlyProcessedEditor.getEditor()).setCaretEnabled(false); + } + if (currentlyProcessedTextPane != null && currentlyProcessedTextPane.getCaret() + .isVisible()) { + currentlyProcessedTextPane.getCaret().setVisible(false); + } + })); } public void displayError(String message) { diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/TotalTokensPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/TotalTokensPanel.java new file mode 100644 index 00000000..44952d3f --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/TotalTokensPanel.java @@ -0,0 +1,105 @@ +package ee.carlrobert.codegpt.toolwindow.chat.components; + +import static java.lang.String.format; + +import com.intellij.icons.AllIcons.General; +import com.intellij.ui.components.JBLabel; +import ee.carlrobert.codegpt.EncodingManager; +import ee.carlrobert.codegpt.conversations.Conversation; +import ee.carlrobert.codegpt.toolwindow.chat.TokenDetails; +import java.awt.FlowLayout; +import java.awt.event.MouseAdapter; +import java.awt.event.MouseEvent; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.stream.Collectors; +import javax.swing.Box; +import javax.swing.JPanel; +import javax.swing.SwingUtilities; + +public class TotalTokensPanel extends JPanel { + + private final EncodingManager encodingManager; + private final TokenDetails tokenDetails; + private final JBLabel label; + + public TotalTokensPanel(Conversation conversation, String userPrompt, String highlightedText) { + super(new FlowLayout(FlowLayout.LEADING, 0, 0)); + this.encodingManager = EncodingManager.getInstance(); + this.tokenDetails = createTokenDetails(conversation, userPrompt, highlightedText); + this.label = getLabel(tokenDetails); + + setOpaque(false); + add(getContextHelpIcon(tokenDetails)); + add(Box.createHorizontalStrut(4)); + add(label); + } + + public TokenDetails getTokenDetails() { + return tokenDetails; + } + + public void update() { + update(tokenDetails.getTotal()); + } + + public void update(int total) { + label.setText(getLabelHtml(total)); + } + + public void updateConversationTokens(Conversation conversation) { + tokenDetails.setConversationTokens(encodingManager.countConversationTokens(conversation)); + update(); + } + + public void updateUserPromptTokens(String userPrompt) { + tokenDetails.setUserPromptTokens(encodingManager.countTokens(userPrompt)); + update(); + } + + public void updateHighlightedTokens(String highlightedText) { + tokenDetails.setHighlightedTokens(encodingManager.countTokens(highlightedText)); + update(); + } + + private TokenDetails createTokenDetails( + Conversation conversation, + String userPrompt, + String highlightedText) { + var tokenDetails = new TokenDetails(encodingManager); + tokenDetails.setConversationTokens(encodingManager.countConversationTokens(conversation)); + tokenDetails.setUserPromptTokens(encodingManager.countTokens(userPrompt)); + tokenDetails.setHighlightedTokens(encodingManager.countTokens(highlightedText)); + return tokenDetails; + } + + private JBLabel getContextHelpIcon(TokenDetails tokenDetails) { + var iconLabel = new JBLabel(General.ContextHelp); + iconLabel.addMouseListener(new MouseAdapter() { + @Override + public void mouseEntered(MouseEvent e) { + var html = new LinkedHashMap<>(Map.of( + "System Prompt", tokenDetails.getSystemPromptTokens(), + "Conversation Tokens", tokenDetails.getConversationTokens(), + "Input Tokens", tokenDetails.getUserPromptTokens(), + "Highlighted Tokens", tokenDetails.getHighlightedTokens())) + .entrySet().stream() + .map(entry -> format( + "

%s: %d

", + entry.getKey(), + entry.getValue())) + .collect(Collectors.joining()); + iconLabel.setToolTipText("" + html + ""); + } + }); + return iconLabel; + } + + private String getLabelHtml(int total) { + return format("Total Tokens: %d", total); + } + + private JBLabel getLabel(TokenDetails tokenDetails) { + return new JBLabel(getLabelHtml(tokenDetails.getTotal())); + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextArea.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextArea.java index bfa71aa4..126cc2d0 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextArea.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextArea.java @@ -3,6 +3,7 @@ package ee.carlrobert.codegpt.toolwindow.chat.components; import com.intellij.icons.AllIcons; import com.intellij.openapi.editor.ex.util.EditorUtil; import com.intellij.openapi.util.registry.Registry; +import com.intellij.ui.DocumentAdapter; import com.intellij.ui.JBColor; import com.intellij.ui.components.JBTextArea; import com.intellij.util.ui.JBUI; @@ -30,7 +31,6 @@ import javax.swing.JPanel; import javax.swing.KeyStroke; import javax.swing.UIManager; import javax.swing.event.DocumentEvent; -import javax.swing.event.DocumentListener; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -38,7 +38,8 @@ public class UserPromptTextArea extends JPanel { private static final String TEXT_SUBMIT = "text-submit"; private static final String INSERT_BREAK = "insert-break"; - private static final JBColor BACKGROUND_COLOR = JBColor.namedColor("Editor.SearchField.background", UIUtil.getTextFieldBackground()); + private static final JBColor BACKGROUND_COLOR = JBColor.namedColor( + "Editor.SearchField.background", UIUtil.getTextFieldBackground()); private final JBTextArea textArea; @@ -48,10 +49,11 @@ public class UserPromptTextArea extends JPanel { private JPanel iconsPanel; private boolean submitEnabled = true; - public UserPromptTextArea(Consumer onSubmit) { + public UserPromptTextArea(Consumer onSubmit, DocumentAdapter documentAdapter) { this.onSubmit = onSubmit; textArea = new JBTextArea(); + textArea.getDocument().addDocumentListener(documentAdapter); textArea.setOpaque(false); textArea.setBackground(BACKGROUND_COLOR); textArea.setLineWrap(true); @@ -78,29 +80,20 @@ public class UserPromptTextArea extends JPanel { UserPromptTextArea.super.paintBorder(UserPromptTextArea.super.getGraphics()); } }); - textArea.getDocument().addDocumentListener(new DocumentListener() { + textArea.getDocument().addDocumentListener(new DocumentAdapter() { @Override - public void removeUpdate(DocumentEvent e) { - if (e.getDocument().getLength() == 0) { - iconsPanel.getComponents()[0].setEnabled(false); - } - } - - @Override - public void insertUpdate(DocumentEvent e) { - if (e.getDocument().getLength() == 1) { - iconsPanel.getComponents()[0].setEnabled(true); - } - } - - @Override - public void changedUpdate(DocumentEvent e) { + protected void textChanged(@NotNull DocumentEvent e) { + iconsPanel.getComponents()[0].setEnabled(e.getDocument().getLength() > 0); } }); updateFont(); init(); } + public String getText() { + return textArea.getText().trim(); + } + public void focus() { textArea.requestFocus(); textArea.requestFocusInWindow(); @@ -136,10 +129,6 @@ public class UserPromptTextArea extends JPanel { stopButton.setEnabled(!submitEnabled); } - public void setTextAreaEnabled(boolean textAreaEnabled) { - textArea.setEnabled(textAreaEnabled); - } - private void handleSubmit() { if (submitEnabled && !textArea.getText().isEmpty()) { // Replacing each newline with two newlines to ensure proper Markdown formatting diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextAreaHeader.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextAreaHeader.java new file mode 100644 index 00000000..60d5c1f6 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextAreaHeader.java @@ -0,0 +1,77 @@ +package ee.carlrobert.codegpt.toolwindow.chat.components; + +import static ee.carlrobert.codegpt.util.ThemeUtils.getPanelBackgroundColor; + +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.project.Project; +import com.intellij.ui.components.JBCheckBox; +import com.intellij.util.messages.MessageBusConnection; +import com.intellij.util.ui.JBUI; +import com.intellij.util.ui.JBUI.Borders; +import ee.carlrobert.codegpt.completions.you.YouSubscriptionNotifier; +import ee.carlrobert.codegpt.completions.you.auth.SignedOutNotifier; +import ee.carlrobert.codegpt.settings.state.SettingsState; +import ee.carlrobert.codegpt.toolwindow.ModelIconLabel; +import ee.carlrobert.codegpt.toolwindow.chat.YouModelChangeNotifier; +import java.awt.BorderLayout; +import javax.swing.JPanel; + +public class UserPromptTextAreaHeader extends JPanel { + + public UserPromptTextAreaHeader( + Project project, + SettingsState settings, + TotalTokensPanel totalTokensPanel, + JBCheckBox gpt4CheckBox) { + super(new BorderLayout()); + setBackground(getPanelBackgroundColor()); + setBorder(JBUI.Borders.emptyBottom(8)); + switch (settings.getSelectedService()) { + case OPENAI: + case AZURE: + add(totalTokensPanel, BorderLayout.LINE_START); + break; + case YOU: + subscribeToYouTopics(project, gpt4CheckBox); + add(gpt4CheckBox, BorderLayout.LINE_START); + break; + } + add(JBUI.Panels + .simplePanel(new ModelIconLabel( + settings.getSelectedService().getCompletionCode(), + settings.getModel())) + .withBorder(Borders.emptyRight(4)) + .withBackground(getPanelBackgroundColor()), BorderLayout.LINE_END); + } + + private void subscribeToYouTopics(Project project, JBCheckBox gpt4CheckBox) { + var messageBusConnection = ApplicationManager.getApplication().getMessageBus().connect(); + subscribeToYouModelChangeTopic(project, gpt4CheckBox); + subscribeToYouSubscriptionTopic(messageBusConnection, gpt4CheckBox); + subscribeToSignedOutTopic(messageBusConnection, gpt4CheckBox); + } + + private void subscribeToYouModelChangeTopic(Project project, JBCheckBox gpt4CheckBox) { + project.getMessageBus() + .connect() + .subscribe( + YouModelChangeNotifier.YOU_MODEL_CHANGE_NOTIFIER_TOPIC, + (YouModelChangeNotifier) gpt4CheckBox::setSelected); + } + + private void subscribeToSignedOutTopic( + MessageBusConnection messageBusConnection, + JBCheckBox gpt4CheckBox) { + messageBusConnection.subscribe( + SignedOutNotifier.SIGNED_OUT_TOPIC, + (SignedOutNotifier) () -> gpt4CheckBox.setEnabled(false)); + } + + private void subscribeToYouSubscriptionTopic( + MessageBusConnection messageBusConnection, + JBCheckBox gpt4CheckBox) { + messageBusConnection.subscribe( + YouSubscriptionNotifier.SUBSCRIPTION_TOPIC, + (YouSubscriptionNotifier) () -> gpt4CheckBox.setEnabled(true)); + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowTabPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowTabPanel.java index 28aa9158..6b6dcbfb 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowTabPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowTabPanel.java @@ -32,6 +32,7 @@ public class StandardChatToolWindowTabPanel extends BaseChatToolWindowTabPanel { displayLandingView(); } else { displayConversation(conversation); + totalTokensPanel.updateConversationTokens(conversation); } } @@ -75,11 +76,11 @@ public class StandardChatToolWindowTabPanel extends BaseChatToolWindowTabPanel { messageResponseBody.displaySerpResults(serpResults); } - var messageWrapper = createNewMessageWrapper(message.getId()); - messageWrapper.add(new UserMessagePanel(project, message, this)); - messageWrapper.add(new ResponsePanel() + var messagePanel = toolWindowScrollablePanel.addMessage(message.getId()); + messagePanel.add(new UserMessagePanel(project, message, this)); + messagePanel.add(new ResponsePanel() .withReloadAction(() -> reloadMessage(message, conversation)) - .withDeleteAction(() -> removeMessage(message.getId(), messageWrapper, conversation)) + .withDeleteAction(() -> removeMessage(message.getId(), conversation)) .addContent(messageResponseBody)); }); setConversation(conversation); diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowTabbedPane.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowTabbedPane.java index 8e7dafa1..be12e77f 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowTabbedPane.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowTabbedPane.java @@ -112,6 +112,7 @@ public class StandardChatToolWindowTabbedPane extends JBTabbedPane { tryFindActiveTabPanel().ifPresent(tabPanel -> { tabPanel.displayLandingView(); tabPanel.setConversation(null); + tabPanel.updateConversationTokens(); }); ConversationsState.getInstance().setCurrentConversation(null); } diff --git a/src/main/java/ee/carlrobert/codegpt/util/EditorUtils.java b/src/main/java/ee/carlrobert/codegpt/util/EditorUtils.java index b1e4e770..f3bfa1c4 100644 --- a/src/main/java/ee/carlrobert/codegpt/util/EditorUtils.java +++ b/src/main/java/ee/carlrobert/codegpt/util/EditorUtils.java @@ -52,6 +52,15 @@ public final class EditorUtils { return editorManager != null ? editorManager.getSelectedTextEditor() : null; } + public static @NotNull String getSelectedEditorSelectedText(@NotNull Project project) { + var selectedEditor = EditorUtils.getSelectedEditor(project); + var selectedText = ""; + if (selectedEditor != null) { + selectedText = selectedEditor.getSelectionModel().getSelectedText(); + } + return selectedText == null ? "" : selectedText; + } + public static boolean isMainEditorTextSelected(@NotNull Project project) { return hasSelection(getSelectedEditor(project)); } diff --git a/src/main/java/ee/carlrobert/codegpt/util/SwingUtils.java b/src/main/java/ee/carlrobert/codegpt/util/SwingUtils.java index 0b4f1f83..128306ab 100644 --- a/src/main/java/ee/carlrobert/codegpt/util/SwingUtils.java +++ b/src/main/java/ee/carlrobert/codegpt/util/SwingUtils.java @@ -3,7 +3,10 @@ package ee.carlrobert.codegpt.util; import static javax.swing.event.HyperlinkEvent.EventType.ACTIVATED; import com.intellij.ide.BrowserUtil; +import com.intellij.openapi.roots.ui.componentsList.components.ScrollablePanel; +import com.intellij.ui.ScrollPaneFactory; import com.intellij.util.ui.UI; +import ee.carlrobert.codegpt.toolwindow.chat.components.SmartScroller; import java.awt.Dimension; import java.awt.event.ActionEvent; import java.net.URISyntaxException; @@ -14,9 +17,11 @@ import javax.swing.JButton; import javax.swing.JComponent; import javax.swing.JLabel; import javax.swing.JPanel; +import javax.swing.JScrollPane; import javax.swing.JTextArea; import javax.swing.JTextPane; import javax.swing.KeyStroke; +import javax.swing.ScrollPaneConstants; import javax.swing.event.HyperlinkEvent; import javax.swing.event.HyperlinkListener; @@ -44,6 +49,13 @@ public class SwingUtils { return button; } + public static JScrollPane createScrollPaneWithSmartScroller(ScrollablePanel scrollablePanel) { + var scrollPane = ScrollPaneFactory.createScrollPane(scrollablePanel, true); + scrollPane.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER); + new SmartScroller(scrollPane); + return scrollPane; + } + public static void setEqualLabelWidths(JPanel firstPanel, JPanel secondPanel) { var firstLabel = firstPanel.getComponents()[0]; var secondLabel = secondPanel.getComponents()[0]; diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index 889141e7..ce17b4f7 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -43,6 +43,7 @@ icon="ee.carlrobert.codegpt.Icons.DefaultSmallIcon"> + diff --git a/src/main/resources/icons/sparkle.svg b/src/main/resources/icons/sparkle.svg new file mode 100644 index 00000000..1c002722 --- /dev/null +++ b/src/main/resources/icons/sparkle.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/src/test/java/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.java b/src/test/java/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.java index dbfbaddb..0739bd30 100644 --- a/src/test/java/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.java +++ b/src/test/java/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.java @@ -9,7 +9,6 @@ import static org.apache.http.HttpHeaders.AUTHORIZATION; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.groups.Tuple.tuple; -import com.intellij.testFramework.fixtures.BasePlatformTestCase; import ee.carlrobert.codegpt.conversations.ConversationService; import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; @@ -17,34 +16,18 @@ import ee.carlrobert.codegpt.settings.configuration.ConfigurationState; import ee.carlrobert.codegpt.settings.service.ServiceType; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; import ee.carlrobert.codegpt.settings.state.SettingsState; -import ee.carlrobert.llm.client.http.LocalCallbackServer; import ee.carlrobert.llm.client.http.ResponseEntity; import ee.carlrobert.llm.client.http.exchange.BasicHttpExchange; -import ee.carlrobert.llm.client.http.expectation.BasicExpectation; import ee.carlrobert.llm.client.openai.completion.chat.OpenAIChatCompletionModel; import java.util.List; import java.util.Map; +import testsupport.IntegrationTest; -public class CompletionRequestProviderTest extends BasePlatformTestCase { - - private LocalCallbackServer server; - - @Override - protected void setUp() throws Exception { - super.setUp(); - server = new LocalCallbackServer(8000); - OpenAISettingsState.getInstance().setBaseHost("http://127.0.0.1:8000"); - OpenAICredentialsManager.getInstance().setApiKey("TEST_API_KEY"); - ConfigurationState.getInstance().setSystemPrompt(""); - } - - @Override - protected void tearDown() throws Exception { - server.stop(); - super.tearDown(); - } +public class CompletionRequestProviderTest extends IntegrationTest { public void testChatCompletionRequestWithSystemPromptOverride() { + OpenAICredentialsManager.getInstance().setApiKey("TEST_API_KEY"); + OpenAISettingsState.getInstance().setBaseHost(null); ConfigurationState.getInstance().setSystemPrompt("TEST_SYSTEM_PROMPT"); var conversation = ConversationService.getInstance().startConversation(); var firstMessage = createDummyMessage(500); @@ -78,7 +61,8 @@ public class CompletionRequestProviderTest extends BasePlatformTestCase { conversation.addMessage(secondMessage); var request = new CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest(OpenAIChatCompletionModel.GPT_3_5.getCode(), new Message("TEST_CHAT_COMPLETION_PROMPT"), false); + .buildOpenAIChatCompletionRequest(OpenAIChatCompletionModel.GPT_3_5.getCode(), + new Message("TEST_CHAT_COMPLETION_PROMPT"), false); assertThat(request.getMessages()) .extracting("role", "content") @@ -92,6 +76,7 @@ public class CompletionRequestProviderTest extends BasePlatformTestCase { } public void testChatCompletionRequestRetry() { + ConfigurationState.getInstance().setSystemPrompt(COMPLETION_SYSTEM_PROMPT); var conversation = ConversationService.getInstance().startConversation(); var firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500); var secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250); @@ -99,7 +84,8 @@ public class CompletionRequestProviderTest extends BasePlatformTestCase { conversation.addMessage(secondMessage); var request = new CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest(OpenAIChatCompletionModel.GPT_3_5.getCode(), secondMessage, true); + .buildOpenAIChatCompletionRequest(OpenAIChatCompletionModel.GPT_3_5.getCode(), + secondMessage, true); assertThat(request.getMessages()) .extracting("role", "content") @@ -121,7 +107,8 @@ public class CompletionRequestProviderTest extends BasePlatformTestCase { conversation.discardTokenLimits(); var request = new CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest(OpenAIChatCompletionModel.GPT_3_5.getCode(), new Message("TEST_CHAT_COMPLETION_PROMPT"), false); + .buildOpenAIChatCompletionRequest(OpenAIChatCompletionModel.GPT_3_5.getCode(), + new Message("TEST_CHAT_COMPLETION_PROMPT"), false); assertThat(request.getMessages()) .extracting("role", "content") @@ -140,13 +127,15 @@ public class CompletionRequestProviderTest extends BasePlatformTestCase { assertThrows(TotalUsageExceededException.class, () -> new CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest(OpenAIChatCompletionModel.GPT_3_5.getCode(), createDummyMessage(100), false)); + .buildOpenAIChatCompletionRequest(OpenAIChatCompletionModel.GPT_3_5.getCode(), + createDummyMessage(100), false)); } public void testContextualSearch() { + useOpenAIService(); var conversation = ConversationService.getInstance().startConversation(); - SettingsState.getInstance().setSelectedService(ServiceType.OPENAI); - expectRequest("/v1/chat/completions", request -> { + expectOpenAI((BasicHttpExchange) request -> { + assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions"); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY"); assertThat(request.getBody()) @@ -154,46 +143,57 @@ public class CompletionRequestProviderTest extends BasePlatformTestCase { .containsExactly("gpt-4", List.of(Map.of( "role", "user", - "content", "You are Text Generator, a helpful expert of generating natural language into semantically comparable search query.\n" + + "content", + "You are Text Generator, a helpful expert of generating natural language into semantically comparable search query.\n" + + "\n" + "Text: List all the dependencies that the project uses\n" + - "AI: project dependencies, development dependencies, versions, libraries, frameworks, packages\n" + + "AI: project dependencies, development dependencies, versions, libraries, frameworks, packages\n" + + "\n" + - "Text: Are there any scheduled tasks or background jobs running in our codebase, and if so, what are they responsible for?\n" + - "AI: scheduled tasks, background jobs, cron jobs, task schedules, codebase tasks\n" + + "Text: Are there any scheduled tasks or background jobs running in our codebase, and if so, what are they responsible for?\n" + + + "AI: scheduled tasks, background jobs, cron jobs, task schedules, codebase tasks\n" + + "\n" + "Text: TEST_CHAT_COMPLETION_PROMPT\n" + "AI:"))); return new ResponseEntity(200, - jsonMapResponse("choices", jsonArray(jsonMap("message", jsonMap(e("role", "assistant"), e("content", "TEST_CHAT_COMPLETION_RESPONSE")))))); + jsonMapResponse("choices", jsonArray(jsonMap("message", + jsonMap(e("role", "assistant"), e("content", "TEST_CHAT_COMPLETION_RESPONSE")))))); }); - expectRequest("/v1/embeddings", request -> { + expectOpenAI((BasicHttpExchange) request -> { + assertThat(request.getUri().getPath()).isEqualTo("/v1/embeddings"); var headers = request.getHeaders(); assertThat(headers.get("Authorization").get(0)).isEqualTo("Bearer TEST_API_KEY"); assertThat(request.getBody()) .extracting("model", "input") .containsExactly("text-embedding-ada-002", List.of("TEST_CHAT_COMPLETION_RESPONSE")); - return new ResponseEntity(200, jsonMapResponse("data", jsonArray(jsonMap("embedding", List.of(-0.00692, -0.0053, -4.5471, -0.0240))))); + return new ResponseEntity(200, jsonMapResponse("data", + jsonArray(jsonMap("embedding", List.of(-0.00692, -0.0053, -4.5471, -0.0240))))); }); var request = new CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest(OpenAIChatCompletionModel.GPT_3_5.getCode(), new Message("TEST_CHAT_COMPLETION_PROMPT"), false, true, null); + .buildOpenAIChatCompletionRequest(OpenAIChatCompletionModel.GPT_3_5.getCode(), + new Message("TEST_CHAT_COMPLETION_PROMPT"), false, true, null); assertThat(request.getModel()).isEqualTo("gpt-3.5-turbo"); assertThat(request.getMessages().size()).isEqualTo(1); assertThat(request.getMessages().get(0)) .extracting("role", "content") - .containsExactly("user", "Use the following pieces of context to answer the question at the end.\n" + - "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n" + - "\n" + - "Context:\n" + - "\n" + - "TEST_CONTEXT\n" + - "\n" + - "Question: TEST_CHAT_COMPLETION_PROMPT\n" + - "\n" + - "Helpful answer in Markdown format:"); + .containsExactly("user", + "Use the following pieces of context to answer the question at the end.\n" + + "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n" + + + "\n" + + "Context:\n" + + "\n" + + "TEST_CONTEXT\n" + + "\n" + + "Question: TEST_CHAT_COMPLETION_PROMPT\n" + + "\n" + + "Helpful answer in Markdown format:"); } private Message createDummyMessage(int tokenSize) { @@ -206,8 +206,4 @@ public class CompletionRequestProviderTest extends BasePlatformTestCase { message.setResponse("zz".repeat((tokenSize) - 6 - 7)); return message; } - - private void expectRequest(String path, BasicHttpExchange exchange) { - server.addExpectation(new BasicExpectation(path, exchange)); - } } diff --git a/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java b/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java index d0edd194..8ba13fb5 100644 --- a/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java +++ b/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java @@ -11,57 +11,24 @@ import static org.apache.http.HttpHeaders.AUTHORIZATION; import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; -import com.intellij.testFramework.fixtures.BasePlatformTestCase; import ee.carlrobert.codegpt.CodeGPTPlugin; import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.ConversationService; import ee.carlrobert.codegpt.conversations.message.Message; -import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; -import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; -import ee.carlrobert.codegpt.settings.configuration.ConfigurationState; -import ee.carlrobert.codegpt.settings.service.ServiceType; -import ee.carlrobert.codegpt.settings.state.AzureSettingsState; -import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; -import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; -import ee.carlrobert.codegpt.settings.state.SettingsState; -import ee.carlrobert.codegpt.settings.state.YouSettingsState; -import ee.carlrobert.llm.client.http.LocalCallbackServer; import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange; -import ee.carlrobert.llm.client.http.expectation.StreamExpectation; -import ee.carlrobert.llm.client.openai.completion.chat.OpenAIChatCompletionModel; import java.util.List; import java.util.Map; +import testsupport.IntegrationTest; -public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { - - private LocalCallbackServer server; - - @Override - protected void setUp() throws Exception { - super.setUp(); - AzureCredentialsManager.getInstance().setApiKey("TEST_API_KEY"); - OpenAICredentialsManager.getInstance().setApiKey("TEST_API_KEY"); - // FIXME - OpenAISettingsState.getInstance().setBaseHost("http://127.0.0.1:8000"); - AzureSettingsState.getInstance().setBaseHost("http://127.0.0.1:8000"); - YouSettingsState.getInstance().setBaseHost("http://127.0.0.1:8000"); - LlamaSettingsState.getInstance().setServerPort(8000); - ConfigurationState.getInstance().setSystemPrompt(""); - server = new LocalCallbackServer(8000); - } - - @Override - protected void tearDown() throws Exception { - server.stop(); - super.tearDown(); - } +public class DefaultCompletionRequestHandlerTest extends IntegrationTest { public void testOpenAIChatCompletionCall() { + useOpenAIService(); var message = new Message("TEST_PROMPT"); var conversation = ConversationService.getInstance().startConversation(); var requestHandler = new CompletionRequestHandler(false, getRequestEventListener(message)); - SettingsState.getInstance().setSelectedService(ServiceType.OPENAI); - expectStreamRequest("/v1/chat/completions", request -> { + expectOpenAI((StreamHttpExchange) request -> { + assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions"); assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY"); assertThat(request.getBody()) @@ -69,7 +36,7 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { "model", "messages") .containsExactly( - "gpt-3.5-turbo", + "gpt-4", List.of( Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT), Map.of("role", "user", "content", "TEST_PROMPT"))); @@ -87,11 +54,7 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { } public void testAzureChatCompletionCall() { - SettingsState.getInstance().setSelectedService(ServiceType.AZURE); - var azureSettings = AzureSettingsState.getInstance(); - azureSettings.setResourceName("TEST_RESOURCE_NAME"); - azureSettings.setApiVersion("TEST_API_VERSION"); - azureSettings.setDeploymentId("TEST_DEPLOYMENT_ID"); + useAzureService(); var conversationService = ConversationService.getInstance(); var message = new Message("TEST_PROMPT"); var requestHandler = new CompletionRequestHandler(false, getRequestEventListener(message)); @@ -100,7 +63,9 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { var conversation = conversationService.startConversation(); conversation.addMessage(prevMessage); conversationService.saveConversation(conversation); - expectStreamRequest("/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions", request -> { + expectAzure((StreamHttpExchange) request -> { + assertThat(request.getUri().getPath()).isEqualTo( + "/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions"); assertThat(request.getUri().getQuery()).isEqualTo("api-version=TEST_API_VERSION"); assertThat(request.getHeaders().get("Api-key").get(0)).isEqualTo("TEST_API_KEY"); assertThat(request.getBody()) @@ -124,12 +89,13 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { } public void testYouChatCompletionCall() { + useYouService(); var message = new Message("TEST_PROMPT"); var conversation = ConversationService.getInstance().startConversation(); conversation.addMessage(new Message("Ping", "Pong")); var requestHandler = new CompletionRequestHandler(false, getRequestEventListener(message)); - SettingsState.getInstance().setSelectedService(ServiceType.YOU); - expectStreamRequest("/api/streamingSearch", request -> { + expectYou((StreamHttpExchange) request -> { + assertThat(request.getUri().getPath()).isEqualTo("/api/streamingSearch"); assertThat(request.getMethod()).isEqualTo("GET"); assertThat(request.getUri().getPath()).isEqualTo("/api/streamingSearch"); assertThat(request.getUri().getQuery()).isEqualTo( @@ -145,8 +111,8 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { "utm_campaign=" + CodeGPTPlugin.getVersion() + "&" + "utm_content=CodeGPT"); assertThat(request.getHeaders()) - .flatExtracting("Host", "Accept", "Connection", "User-agent", "Cookie") - .containsExactly("127.0.0.1:8000", + .flatExtracting("Accept", "Connection", "User-agent", "Cookie") + .containsExactly( "text/event-stream", "Keep-Alive", "youide CodeGPT", @@ -173,12 +139,13 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { } public void testLlamaChatCompletionCall() { + useLlamaService(); var message = new Message("TEST_PROMPT"); var conversation = ConversationService.getInstance().startConversation(); conversation.addMessage(new Message("Ping", "Pong")); var requestHandler = new CompletionRequestHandler(false, getRequestEventListener(message)); - SettingsState.getInstance().setSelectedService(ServiceType.LLAMA_CPP); - expectStreamRequest("/completion", request -> { + expectLlama((StreamHttpExchange) request -> { + assertThat(request.getUri().getPath()).isEqualTo("/completion"); assertThat(request.getBody()) .extracting( "prompt", @@ -204,10 +171,6 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { await().atMost(5, SECONDS).until(() -> "Hello!".equals(message.getResponse())); } - private void expectStreamRequest(String path, StreamHttpExchange exchange) { - server.addExpectation(new StreamExpectation(path, exchange)); - } - private ToolWindowCompletionEventListener getRequestEventListener(Message message) { return new ToolWindowCompletionEventListener() { @Override diff --git a/src/test/java/ee/carlrobert/codegpt/settings/state/SettingsStateTest.java b/src/test/java/ee/carlrobert/codegpt/settings/state/SettingsStateTest.java index f032bd7f..ae9bffcc 100644 --- a/src/test/java/ee/carlrobert/codegpt/settings/state/SettingsStateTest.java +++ b/src/test/java/ee/carlrobert/codegpt/settings/state/SettingsStateTest.java @@ -1,6 +1,5 @@ package ee.carlrobert.codegpt.settings.state; - import static ee.carlrobert.codegpt.completions.HuggingFaceModel.CODE_LLAMA_7B_Q3; import static org.assertj.core.api.Assertions.assertThat; diff --git a/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/StandardChatToolWindowTabPanelTest.java b/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/StandardChatToolWindowTabPanelTest.java new file mode 100644 index 00000000..be9ff113 --- /dev/null +++ b/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/StandardChatToolWindowTabPanelTest.java @@ -0,0 +1,78 @@ +package ee.carlrobert.codegpt.toolwindow.chat; + +import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT; +import static ee.carlrobert.llm.client.util.JSONUtil.jsonArray; +import static ee.carlrobert.llm.client.util.JSONUtil.jsonMap; +import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.apache.http.HttpHeaders.AUTHORIZATION; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import ee.carlrobert.codegpt.EncodingManager; +import ee.carlrobert.codegpt.conversations.ConversationService; +import ee.carlrobert.codegpt.conversations.message.Message; +import ee.carlrobert.codegpt.settings.configuration.ConfigurationState; +import ee.carlrobert.codegpt.toolwindow.chat.standard.StandardChatToolWindowTabPanel; +import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange; +import java.util.List; +import java.util.Map; +import testsupport.IntegrationTest; + +public class StandardChatToolWindowTabPanelTest extends IntegrationTest { + + public void testSendingOpenAIMessage() { + useOpenAIService(); + ConfigurationState.getInstance().setSystemPrompt(COMPLETION_SYSTEM_PROMPT); + var message = new Message("Hello!"); + var conversation = ConversationService.getInstance().startConversation(); + var panel = new StandardChatToolWindowTabPanel(getProject(), conversation); + expectOpenAI((StreamHttpExchange) request -> { + assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions"); + assertThat(request.getMethod()).isEqualTo("POST"); + assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY"); + assertThat(request.getBody()) + .extracting( + "model", + "messages") + .containsExactly( + "gpt-4", + List.of( + Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT), + Map.of("role", "user", "content", "Hello!"))); + return List.of( + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))); + }); + + panel.sendMessage(message); + + await().atMost(5, SECONDS).until(() -> !panel.isStreaming()); + var encodingManager = EncodingManager.getInstance(); + assertThat(panel.getTokenDetails()).extracting( + "systemPromptTokens", + "conversationTokens", + "userPromptTokens", + "highlightedTokens") + .containsExactly( + encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT), + encodingManager.countConversationTokens(panel.getConversation()), + 0, + 0); + assertThat(panel.getConversation()) + .isNotNull() + .extracting("id", "model", "clientCode", "discardTokenLimit") + .containsExactly( + conversation.getId(), + conversation.getModel(), + conversation.getClientCode(), + false); + var messages = panel.getConversation().getMessages(); + assertThat(messages.size()).isOne(); + assertThat(messages.get(0)) + .extracting("id", "prompt", "response") + .containsExactly(message.getId(), message.getPrompt(), message.getResponse()); + } +} diff --git a/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/StandardChatToolWindowTabbedPaneTest.java b/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/StandardChatToolWindowTabbedPaneTest.java index 543fec3c..02416af9 100644 --- a/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/StandardChatToolWindowTabbedPaneTest.java +++ b/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/StandardChatToolWindowTabbedPaneTest.java @@ -8,6 +8,7 @@ import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.toolwindow.chat.standard.StandardChatToolWindowTabPanel; import ee.carlrobert.codegpt.toolwindow.chat.standard.StandardChatToolWindowTabbedPane; +import testsupport.IntegrationTest; public class StandardChatToolWindowTabbedPaneTest extends BasePlatformTestCase { @@ -29,7 +30,8 @@ public class StandardChatToolWindowTabbedPaneTest extends BasePlatformTestCase { var tabMapping = tabbedPane.getActiveTabMapping(); assertThat(tabMapping.keySet()).containsExactly("Chat 1", "Chat 2", "Chat 3"); - assertThat(tabMapping.values().stream().allMatch(item -> item.getConversation() == null)).isTrue(); + assertThat( + tabMapping.values().stream().allMatch(item -> item.getConversation() == null)).isTrue(); } public void testResetCurrentlyActiveTabPanel() { diff --git a/src/test/java/testsupport/IntegrationTest.java b/src/test/java/testsupport/IntegrationTest.java new file mode 100644 index 00000000..19f266a1 --- /dev/null +++ b/src/test/java/testsupport/IntegrationTest.java @@ -0,0 +1,20 @@ +package testsupport; + +import com.intellij.testFramework.fixtures.BasePlatformTestCase; +import ee.carlrobert.llm.client.mixin.ExternalServiceTestMixin; +import org.junit.jupiter.api.AfterEach; +import testsupport.mixin.ShortcutsTestMixin; + +public class IntegrationTest extends BasePlatformTestCase implements + ExternalServiceTestMixin, + ShortcutsTestMixin { + + static { + ExternalServiceTestMixin.init(); + } + + @AfterEach + public void cleanUpEach() { + ExternalServiceTestMixin.clearAll(); + } +} diff --git a/src/test/java/testsupport/mixin/ShortcutsTestMixin.java b/src/test/java/testsupport/mixin/ShortcutsTestMixin.java new file mode 100644 index 00000000..5bc1b76e --- /dev/null +++ b/src/test/java/testsupport/mixin/ShortcutsTestMixin.java @@ -0,0 +1,38 @@ +package testsupport.mixin; + +import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; +import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; +import ee.carlrobert.codegpt.settings.service.ServiceType; +import ee.carlrobert.codegpt.settings.state.AzureSettingsState; +import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; +import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; +import ee.carlrobert.codegpt.settings.state.SettingsState; + +public interface ShortcutsTestMixin { + + default void useOpenAIService() { + SettingsState.getInstance().setSelectedService(ServiceType.OPENAI); + OpenAICredentialsManager.getInstance().setApiKey("TEST_API_KEY"); + OpenAISettingsState.getInstance().setModel("gpt-4"); + OpenAISettingsState.getInstance().setBaseHost(null); + } + + default void useAzureService() { + SettingsState.getInstance().setSelectedService(ServiceType.AZURE); + AzureCredentialsManager.getInstance().setApiKey("TEST_API_KEY"); + var azureSettings = AzureSettingsState.getInstance(); + azureSettings.setBaseHost(null); + azureSettings.setResourceName("TEST_RESOURCE_NAME"); + azureSettings.setApiVersion("TEST_API_VERSION"); + azureSettings.setDeploymentId("TEST_DEPLOYMENT_ID"); + } + + default void useYouService() { + SettingsState.getInstance().setSelectedService(ServiceType.YOU); + } + + default void useLlamaService() { + SettingsState.getInstance().setSelectedService(ServiceType.LLAMA_CPP); + LlamaSettingsState.getInstance().setServerPort(null); + } +} diff --git a/src/test/resources/application.properties b/src/test/resources/application.properties new file mode 100644 index 00000000..e69de29b