feat: improve chat UI performance

This commit is contained in:
Carl-Robert Linnupuu 2024-08-02 03:12:43 +03:00
parent c211423b9d
commit 0584c31530
6 changed files with 142 additions and 153 deletions

View file

@ -4,19 +4,15 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import ee.carlrobert.codegpt.events.CodeGPTEvent;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.GeneralSettingsState;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.completion.CompletionEventListener;
import java.util.List;
import javax.swing.SwingWorker;
import okhttp3.sse.EventSource;
public class CompletionRequestHandler {
private final StringBuilder messageBuilder = new StringBuilder();
private final CompletionResponseEventListener completionResponseEventListener;
private SwingWorker<Void, String> swingWorker;
private EventSource eventSource;
public CompletionRequestHandler(CompletionResponseEventListener completionResponseEventListener) {
@ -24,15 +20,21 @@ public class CompletionRequestHandler {
}
public void call(CallParameters callParameters) {
swingWorker = new CompletionRequestWorker(callParameters);
swingWorker.execute();
try {
eventSource = startCall(callParameters, new RequestCompletionEventListener(callParameters));
} catch (TotalUsageExceededException e) {
completionResponseEventListener.handleTokensExceeded(
callParameters.getConversation(),
callParameters.getMessage());
} finally {
sendInfo(callParameters);
}
}
public void cancel() {
if (eventSource != null) {
eventSource.cancel();
}
swingWorker.cancel(true);
}
private EventSource startCall(
@ -57,79 +59,48 @@ public class CompletionRequestHandler {
completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex);
}
private class CompletionRequestWorker extends SwingWorker<Void, String> {
class RequestCompletionEventListener implements CompletionEventListener<String> {
private final CallParameters callParameters;
public CompletionRequestWorker(CallParameters callParameters) {
public RequestCompletionEventListener(CallParameters callParameters) {
this.callParameters = callParameters;
}
protected Void doInBackground() {
var settings = GeneralSettings.getCurrentState();
@Override
public void onEvent(String data) {
try {
eventSource = startCall(callParameters, new RequestCompletionEventListener());
} catch (TotalUsageExceededException e) {
completionResponseEventListener.handleTokensExceeded(
callParameters.getConversation(),
callParameters.getMessage());
} finally {
sendInfo(settings);
var event = new ObjectMapper().readValue(data, CodeGPTEvent.class);
completionResponseEventListener.handleCodeGPTEvent(event);
} catch (JsonProcessingException e) {
// ignore
}
return null;
}
protected void process(List<String> chunks) {
@Override
public void onMessage(String message, EventSource eventSource) {
messageBuilder.append(message);
callParameters.getMessage().setResponse(messageBuilder.toString());
for (String text : chunks) {
messageBuilder.append(text);
completionResponseEventListener.handleMessage(text);
}
completionResponseEventListener.handleMessage(message);
}
class RequestCompletionEventListener implements CompletionEventListener<String> {
@Override
public void onEvent(String data) {
try {
var event = new ObjectMapper().readValue(data, CodeGPTEvent.class);
completionResponseEventListener.handleCodeGPTEvent(event);
} catch (JsonProcessingException e) {
// ignore
}
}
@Override
public void onMessage(String message, EventSource eventSource) {
publish(message);
}
@Override
public void onComplete(StringBuilder messageBuilder) {
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
}
@Override
public void onCancelled(StringBuilder messageBuilder) {
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
}
@Override
public void onError(ErrorDetails error, Throwable ex) {
try {
completionResponseEventListener.handleError(error, ex);
} finally {
sendError(error, ex);
}
}
@Override
public void onComplete(StringBuilder messageBuilder) {
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
}
private void sendInfo(GeneralSettingsState settings) {
TelemetryAction.COMPLETION.createActionMessage()
.property("conversationId", callParameters.getConversation().getId().toString())
.property("model", callParameters.getConversation().getModel())
.property("service", settings.getSelectedService().getCode().toLowerCase())
.send();
@Override
public void onCancelled(StringBuilder messageBuilder) {
completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters);
}
@Override
public void onError(ErrorDetails error, Throwable ex) {
try {
completionResponseEventListener.handleError(error, ex);
} finally {
sendError(error, ex);
}
}
private void sendError(ErrorDetails error, Throwable ex) {
@ -147,4 +118,12 @@ public class CompletionRequestHandler {
telemetryMessage.send();
}
}
private void sendInfo(CallParameters callParameters) {
TelemetryAction.COMPLETION.createActionMessage()
.property("conversationId", callParameters.getConversation().getId().toString())
.property("model", callParameters.getConversation().getModel())
.property("service", GeneralSettings.getSelectedService().getCode().toLowerCase())
.send();
}
}

View file

@ -5,12 +5,12 @@ import static ee.carlrobert.codegpt.ui.UIUtil.createScrollPaneWithSmartScroller;
import static java.lang.String.format;
import com.intellij.openapi.Disposable;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.project.Project;
import com.intellij.ui.JBColor;
import com.intellij.util.ui.JBUI;
import ee.carlrobert.codegpt.CodeGPTKeys;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.ReferencedFile;
import ee.carlrobert.codegpt.actions.ActionType;
import ee.carlrobert.codegpt.completions.CallParameters;
@ -41,7 +41,6 @@ import java.nio.file.Path;
import java.util.UUID;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;
import kotlin.Unit;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@ -111,7 +110,7 @@ public class ChatToolWindowTabPanel implements Disposable {
}
public void sendMessage(Message message, ConversationType conversationType) {
SwingUtilities.invokeLater(() -> {
ApplicationManager.getApplication().invokeLater(() -> {
var referencedFiles = project.getUserData(CodeGPTKeys.SELECTED_FILES);
var chatToolWindowPanel = project.getService(ChatToolWindowContentManager.class)
.tryFindChatToolWindowPanel();
@ -127,6 +126,7 @@ public class ChatToolWindowTabPanel implements Disposable {
chatToolWindowPanel.ifPresent(panel -> panel.clearNotifications(project));
}
totalTokensPanel.updateConversationTokens(conversation);
var userMessagePanel = new UserMessagePanel(project, message, this);
var attachedFilePath = CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH.get(project);
@ -142,7 +142,6 @@ public class ChatToolWindowTabPanel implements Disposable {
var responsePanel = createResponsePanel(message, conversationType);
messagePanel.add(responsePanel);
updateTotalTokens(message);
call(callParameters, responsePanel);
});
}
@ -163,12 +162,6 @@ public class ChatToolWindowTabPanel implements Disposable {
return callParameters;
}
private void updateTotalTokens(Message message) {
int userPromptTokens = EncodingManager.getInstance().countTokens(message.getPrompt());
int conversationTokens = EncodingManager.getInstance().countConversationTokens(conversation);
totalTokensPanel.updateConversationTokens(conversationTokens + userPromptTokens);
}
private ResponsePanel createResponsePanel(Message message, ConversationType conversationType) {
return new ResponsePanel()
.withReloadAction(() -> reloadMessage(message, conversation, conversationType))

View file

@ -18,7 +18,6 @@ import ee.carlrobert.codegpt.toolwindow.chat.ui.textarea.TotalTokensPanel;
import ee.carlrobert.codegpt.ui.OverlayUtil;
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import javax.swing.SwingUtilities;
abstract class ToolWindowCompletionResponseEventListener implements
CompletionResponseEventListener {
@ -54,17 +53,16 @@ abstract class ToolWindowCompletionResponseEventListener implements
@Override
public void handleMessage(String partialMessage) {
try {
ApplicationManager.getApplication()
.invokeLater(() -> {
responseContainer.update(partialMessage);
messageBuilder.append(partialMessage);
responseContainer.update(partialMessage);
messageBuilder.append(partialMessage);
if (!completed) {
var ongoingTokens = encodingManager.countTokens(messageBuilder.toString());
totalTokensPanel.update(
totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens);
}
});
if (!completed) {
var ongoingTokens = encodingManager.countTokens(messageBuilder.toString());
ApplicationManager.getApplication().invokeLater(() -> {
totalTokensPanel.update(
totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens);
});
}
} catch (Exception e) {
responseContainer.displayError("Something went wrong.");
throw new RuntimeException("Error while updating the content", e);
@ -73,7 +71,7 @@ abstract class ToolWindowCompletionResponseEventListener implements
@Override
public void handleError(ErrorDetails error, Throwable ex) {
SwingUtilities.invokeLater(() -> {
ApplicationManager.getApplication().invokeLater(() -> {
try {
if ("insufficient_quota".equals(error.getCode())) {
responseContainer.displayQuotaExceeded();
@ -90,7 +88,7 @@ abstract class ToolWindowCompletionResponseEventListener implements
@Override
public void handleTokensExceeded(Conversation conversation, Message message) {
SwingUtilities.invokeLater(() -> {
ApplicationManager.getApplication().invokeLater(() -> {
var answer = OverlayUtil.showTokenLimitExceededDialog();
if (answer == OK) {
TelemetryAction.IDE_ACTION.createActionMessage()
@ -110,7 +108,7 @@ abstract class ToolWindowCompletionResponseEventListener implements
public void handleCompleted(String fullMessage, CallParameters callParameters) {
conversationService.saveMessage(fullMessage, callParameters);
SwingUtilities.invokeLater(() -> {
ApplicationManager.getApplication().invokeLater(() -> {
try {
responsePanel.enableActions();
totalTokensPanel.updateUserPromptTokens(textArea.getText());
@ -123,7 +121,8 @@ abstract class ToolWindowCompletionResponseEventListener implements
@Override
public void handleCodeGPTEvent(CodeGPTEvent event) {
responseContainer.displayWebSearchItem(event.getEvent().getDetails());
ApplicationManager.getApplication().invokeLater(() ->
responseContainer.displayWebSearchItem(event.getEvent().getDetails()));
}
private void stopStreaming(ChatMessageResponseBody responseContainer) {

View file

@ -6,6 +6,7 @@ import static java.lang.String.format;
import static javax.swing.event.HyperlinkEvent.EventType.ACTIVATED;
import com.intellij.openapi.Disposable;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.fileEditor.FileEditorManager;
import com.intellij.openapi.options.ShowSettingsUtil;
import com.intellij.openapi.project.Project;
@ -99,59 +100,65 @@ public class ChatMessageResponseBody extends JPanel {
}
public void displayMissingCredential() {
var message = "API key not provided. Open <a href=\"#\">Settings</a> to set one.";
currentlyProcessedTextPane.setText(
format("<html><p style=\"margin-top: 4px; margin-bottom: 8px;\">%s</p></html>", message));
currentlyProcessedTextPane.addHyperlinkListener(e -> {
if (e.getEventType() == ACTIVATED) {
ShowSettingsUtil.getInstance()
.showSettingsDialog(project, GeneralSettingsConfigurable.class);
ApplicationManager.getApplication().invokeLater(() -> {
var message = "API key not provided. Open <a href=\"#\">Settings</a> to set one.";
currentlyProcessedTextPane.setText(
format("<html><p style=\"margin-top: 4px; margin-bottom: 8px;\">%s</p></html>", message));
currentlyProcessedTextPane.addHyperlinkListener(e -> {
if (e.getEventType() == ACTIVATED) {
ShowSettingsUtil.getInstance()
.showSettingsDialog(project, GeneralSettingsConfigurable.class);
}
});
hideCaret();
if (webpageListPanel != null) {
webpageListPanel.setVisible(false);
}
});
hideCaret();
if (webpageListPanel != null) {
webpageListPanel.setVisible(false);
}
}
public void displayQuotaExceeded() {
currentlyProcessedTextPane.setText("<html>"
+ "<p style=\"margin-top: 4px; margin-bottom: 8px;\">"
+ "You exceeded your current quota, please check your plan and billing details, "
+ "or <a href=\"#CHANGE_PROVIDER\">change</a> to a different LLM provider.</p>"
+ "</html>");
ApplicationManager.getApplication().invokeLater(() -> {
currentlyProcessedTextPane.setText("<html>"
+ "<p style=\"margin-top: 4px; margin-bottom: 8px;\">"
+ "You exceeded your current quota, please check your plan and billing details, "
+ "or <a href=\"#CHANGE_PROVIDER\">change</a> to a different LLM provider.</p>"
+ "</html>");
currentlyProcessedTextPane.addHyperlinkListener(e -> {
if (e.getEventType() == ACTIVATED) {
ShowSettingsUtil.getInstance()
.showSettingsDialog(project, GeneralSettingsConfigurable.class);
TelemetryAction.IDE_ACTION.createActionMessage()
.property("action", ActionType.CHANGE_PROVIDER.name())
.send();
currentlyProcessedTextPane.addHyperlinkListener(e -> {
if (e.getEventType() == ACTIVATED) {
ShowSettingsUtil.getInstance()
.showSettingsDialog(project, GeneralSettingsConfigurable.class);
TelemetryAction.IDE_ACTION.createActionMessage()
.property("action", ActionType.CHANGE_PROVIDER.name())
.send();
}
});
hideCaret();
if (webpageListPanel != null) {
webpageListPanel.setVisible(false);
}
});
hideCaret();
if (webpageListPanel != null) {
webpageListPanel.setVisible(false);
}
}
public void displayError(String message) {
var errorText = format(
"<html><p style=\"margin-top: 4px; margin-bottom: 8px;\">%s</p></html>",
message);
if (responseReceived) {
add(createTextPane(errorText, false));
} else {
currentlyProcessedTextPane.setText(errorText);
}
hideCaret();
ApplicationManager.getApplication().invokeLater(() -> {
var errorText = format(
"<html><p style=\"margin-top: 4px; margin-bottom: 8px;\">%s</p></html>",
message);
if (responseReceived) {
add(createTextPane(errorText, false));
} else {
currentlyProcessedTextPane.setText(errorText);
}
hideCaret();
if (webpageListPanel != null) {
webpageListPanel.setVisible(false);
}
if (webpageListPanel != null) {
webpageListPanel.setVisible(false);
}
});
}
public void displayWebSearchItem(Details details) {
@ -196,19 +203,23 @@ public class ChatMessageResponseBody extends JPanel {
var codeBlock = ((FencedCodeBlock) child);
var code = codeBlock.getContentChars().unescape();
if (!code.isEmpty()) {
if (currentlyProcessedEditorPanel == null) {
prepareProcessingCode(code, codeBlock.getInfo().unescape());
}
EditorUtil.updateEditorDocument(currentlyProcessedEditorPanel.getEditor(), code);
ApplicationManager.getApplication().invokeLater(() -> {
if (currentlyProcessedEditorPanel == null) {
prepareProcessingCode(code, codeBlock.getInfo().unescape());
}
EditorUtil.updateEditorDocument(currentlyProcessedEditorPanel.getEditor(), code);
});
}
}
}
private void processText(String markdownText, boolean caretVisible) {
if (currentlyProcessedTextPane == null) {
prepareProcessingText(caretVisible);
}
currentlyProcessedTextPane.setText(convertMdToHtml(markdownText));
ApplicationManager.getApplication().invokeLater(() -> {
if (currentlyProcessedTextPane == null) {
prepareProcessingText(caretVisible);
}
currentlyProcessedTextPane.setText(convertMdToHtml(markdownText));
});
}
private void prepareProcessingText(boolean caretVisible) {
@ -244,19 +255,17 @@ public class ChatMessageResponseBody extends JPanel {
}
private static JPanel createWebpageListPanel(WebpageList webpageList) {
var panel = new JPanel(new BorderLayout());
var title = new JPanel(new BorderLayout());
title.setOpaque(false);
title.setBorder(JBUI.Borders.empty(8, 0));
title.add(new JBLabel(CodeGPTBundle.get("chatMessageResponseBody.webPagesTitle"))
.withFont(JBUI.Fonts.miniFont()), BorderLayout.LINE_START);
panel.add(title);
var listPanel = new JPanel(new BorderLayout());
listPanel.add(webpageList, BorderLayout.LINE_START);
panel.add(listPanel);
var panel = new JPanel(new BorderLayout());
panel.add(title);
panel.add(listPanel);
return panel;
}
}

View file

@ -1,10 +1,13 @@
package ee.carlrobert.codegpt.completions
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.conversations.Conversation
import ee.carlrobert.codegpt.conversations.ConversationService
import ee.carlrobert.codegpt.conversations.message.Message
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.persona.DEFAULT_PROMPT
import ee.carlrobert.codegpt.settings.persona.PersonaSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.groups.Tuple
@ -13,7 +16,7 @@ import testsupport.IntegrationTest
class CompletionRequestProviderTest : IntegrationTest() {
fun testChatCompletionRequestWithSystemPromptOverride() {
useOpenAIService()
useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code)
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val conversation = ConversationService.getInstance().startConversation()
val firstMessage = createDummyMessage(500)
@ -42,7 +45,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
}
fun testChatCompletionRequestWithoutSystemPromptOverride() {
useOpenAIService()
useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code)
service<PersonaSettings>().state.selectedPersona.instructions = DEFAULT_PROMPT
val conversation = ConversationService.getInstance().startConversation()
val firstMessage = createDummyMessage(500)
@ -71,7 +74,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
}
fun testChatCompletionRequestRetry() {
useOpenAIService()
useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code)
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val conversation = ConversationService.getInstance().startConversation()
val firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500)
@ -98,8 +101,9 @@ class CompletionRequestProviderTest : IntegrationTest() {
}
fun testReducedChatCompletionRequest() {
useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code)
service<PersonaSettings>().state.selectedPersona.instructions = DEFAULT_PROMPT
val conversation = ConversationService.getInstance().startConversation()
val conversation = Conversation()
conversation.addMessage(createDummyMessage(50))
conversation.addMessage(createDummyMessage(100))
conversation.addMessage(createDummyMessage(150))
@ -127,7 +131,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
}
fun testTotalUsageExceededException() {
useOpenAIService()
useOpenAIService(OpenAIChatCompletionModel.GPT_3_5.code)
val conversation = ConversationService.getInstance().startConversation()
conversation.addMessage(createDummyMessage(1500))
conversation.addMessage(createDummyMessage(1500))

View file

@ -61,6 +61,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
waitExpecting {
val messages = conversation.messages
messages.isNotEmpty() && "Hello!" == messages[0].response
&& panel.tokenDetails.conversationTokens > 0
}
val encodingManager = EncodingManager.getInstance()
assertThat(panel.tokenDetails).extracting(
@ -70,7 +71,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
"highlightedTokens")
.containsExactly(
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
encodingManager.countTokens(message.prompt),
encodingManager.countConversationTokens(conversation),
0,
0)
assertThat(panel.conversation)
@ -146,6 +147,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
waitExpecting {
val messages = conversation.messages
messages.isNotEmpty() && "Hello!" == messages[0].response
&& panel.tokenDetails.conversationTokens > 0
}
val encodingManager = EncodingManager.getInstance()
assertThat(panel.tokenDetails).extracting(
@ -155,7 +157,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
"highlightedTokens")
.containsExactly(
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
encodingManager.countTokens(message.prompt),
encodingManager.countConversationTokens(conversation),
0,
0)
assertThat(panel.conversation)
@ -219,6 +221,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
waitExpecting {
val messages = conversation.messages
messages.isNotEmpty() && "Hello!" == messages[0].response
&& panel.tokenDetails.conversationTokens > 0
}
val encodingManager = EncodingManager.getInstance()
assertThat(panel.tokenDetails).extracting(
@ -228,7 +231,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
"highlightedTokens")
.containsExactly(
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
encodingManager.countTokens(message.prompt),
encodingManager.countConversationTokens(conversation),
0,
0)
assertThat(panel.conversation)
@ -309,6 +312,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
waitExpecting {
val messages = conversation.messages
messages.isNotEmpty() && "Hello!" == messages[0].response
&& panel.tokenDetails.conversationTokens > 0
}
val encodingManager = EncodingManager.getInstance()
assertThat(panel.tokenDetails).extracting(
@ -318,7 +322,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
"highlightedTokens")
.containsExactly(
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
encodingManager.countTokens(message.prompt),
encodingManager.countConversationTokens(conversation),
0,
0)
assertThat(panel.conversation)
@ -393,6 +397,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
waitExpecting {
val messages = conversation.messages
messages.isNotEmpty() && "Hello!" == messages[0].response
&& panel.tokenDetails.conversationTokens > 0
}
assertThat(panel.conversation)
.isNotNull()