Add interactive total token count label, codebase refactoring

This commit is contained in:
Carl-Robert Linnupuu 2023-11-14 13:27:15 +02:00
parent d8e5e18998
commit ec3120a5e6
31 changed files with 804 additions and 322 deletions

View file

@ -1,28 +1,28 @@
package ee.carlrobert.codegpt.toolwindow.chat;
import static com.intellij.openapi.ui.Messages.OK;
import static ee.carlrobert.codegpt.util.SwingUtils.createScrollPaneWithSmartScroller;
import static ee.carlrobert.codegpt.util.ThemeUtils.getPanelBackgroundColor;
import static java.lang.String.format;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.editor.EditorFactory;
import com.intellij.openapi.editor.event.EditorFactoryEvent;
import com.intellij.openapi.editor.event.EditorFactoryListener;
import com.intellij.openapi.editor.event.SelectionEvent;
import com.intellij.openapi.editor.event.SelectionListener;
import com.intellij.openapi.editor.impl.EditorImpl;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.roots.ui.componentsList.components.ScrollablePanel;
import com.intellij.openapi.roots.ui.componentsList.layout.VerticalStackLayout;
import com.intellij.ui.DocumentAdapter;
import com.intellij.ui.JBColor;
import com.intellij.ui.ScrollPaneFactory;
import com.intellij.ui.components.JBCheckBox;
import com.intellij.util.messages.MessageBusConnection;
import com.intellij.util.ui.JBUI;
import com.intellij.util.ui.JBUI.Borders;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.actions.ActionType;
import ee.carlrobert.codegpt.completions.CompletionRequestHandler;
import ee.carlrobert.codegpt.completions.CompletionRequestService;
import ee.carlrobert.codegpt.completions.ToolWindowCompletionEventListener;
import ee.carlrobert.codegpt.completions.you.YouSubscriptionNotifier;
import ee.carlrobert.codegpt.completions.you.YouUserManager;
import ee.carlrobert.codegpt.completions.you.auth.SignedOutNotifier;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
@ -31,34 +31,29 @@ import ee.carlrobert.codegpt.settings.state.OpenAISettingsState;
import ee.carlrobert.codegpt.settings.state.SettingsState;
import ee.carlrobert.codegpt.settings.state.YouSettingsState;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.codegpt.toolwindow.ModelIconLabel;
import ee.carlrobert.codegpt.toolwindow.chat.components.ChatMessageResponseBody;
import ee.carlrobert.codegpt.toolwindow.chat.components.ResponsePanel;
import ee.carlrobert.codegpt.toolwindow.chat.components.SmartScroller;
import ee.carlrobert.codegpt.toolwindow.chat.components.TotalTokensPanel;
import ee.carlrobert.codegpt.toolwindow.chat.components.UserMessagePanel;
import ee.carlrobert.codegpt.toolwindow.chat.components.UserPromptTextArea;
import ee.carlrobert.codegpt.toolwindow.chat.components.UserPromptTextAreaHeader;
import ee.carlrobert.codegpt.toolwindow.chat.components.YouProCheckbox;
import ee.carlrobert.codegpt.util.EditorUtils;
import ee.carlrobert.codegpt.util.OverlayUtils;
import ee.carlrobert.codegpt.util.SwingUtils;
import ee.carlrobert.codegpt.util.file.FileUtils;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.you.completion.YouSerpResult;
import java.awt.BorderLayout;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import javax.swing.BoxLayout;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTextPane;
import javax.swing.ScrollPaneConstants;
import javax.swing.SwingUtilities;
import javax.swing.event.DocumentEvent;
import javax.swing.text.BadLocationException;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@ -67,18 +62,18 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
private static final Logger LOG = Logger.getInstance(BaseChatToolWindowTabPanel.class);
private final SettingsState settings;
private final YouUserManager youUserManager;
private final boolean useContextualSearch;
private final JPanel rootPanel;
private final ScrollablePanel scrollablePanel;
private final Map<UUID, JPanel> visibleMessagePanels = new HashMap<>();
private final Map<UUID, List<YouSerpResult>> serpResultsMapping = new HashMap<>();
private final JBCheckBox gpt4CheckBox;
protected final TotalTokensPanel totalTokensPanel;
protected final Project project;
protected final UserPromptTextArea userPromptTextArea;
protected final ConversationService conversationService;
protected final ChatToolWindowScrollablePanel toolWindowScrollablePanel;
private final EncodingManager encodingManager;
private boolean streaming;
protected @Nullable Conversation conversation;
protected abstract JComponent getLandingView();
@ -86,22 +81,24 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
public BaseChatToolWindowTabPanel(@NotNull Project project, boolean useContextualSearch) {
this.project = project;
this.useContextualSearch = useContextualSearch;
this.conversationService = ConversationService.getInstance();
this.scrollablePanel = new ScrollablePanel(new VerticalStackLayout());
this.userPromptTextArea = new UserPromptTextArea(this::handleSubmit);
this.gpt4CheckBox = new YouProCheckbox(project);
this.settings = SettingsState.getInstance();
this.youUserManager = YouUserManager.getInstance();
this.rootPanel = createRootPanel();
conversationService = ConversationService.getInstance();
encodingManager = EncodingManager.getInstance();
settings = SettingsState.getInstance();
toolWindowScrollablePanel = new ChatToolWindowScrollablePanel(settings);
gpt4CheckBox = new YouProCheckbox(project);
userPromptTextArea = new UserPromptTextArea(this::handleSubmit, getUserPromptDocumentAdapter());
totalTokensPanel = new TotalTokensPanel(
null,
userPromptTextArea.getText(),
EditorUtils.getSelectedEditorSelectedText(project));
rootPanel = createRootPanel();
addSelectionListeners();
userPromptTextArea.requestFocusInWindow();
userPromptTextArea.requestFocus();
}
public void requestFocusForTextArea() {
userPromptTextArea.focus();
}
@Override
public JPanel getContent() {
return rootPanel;
@ -119,13 +116,7 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
@Override
public void displayLandingView() {
scrollablePanel.removeAll();
scrollablePanel.add(getLandingView());
if (settings.getSelectedService() == ServiceType.YOU &&
(!youUserManager.isAuthenticated() || !youUserManager.isSubscribed())) {
scrollablePanel.add(new ResponsePanel().addContent(createYouCouponTextPane()));
}
revalidateScrollablePanel();
toolWindowScrollablePanel.displayLandingView(getLandingView());
}
@Override
@ -136,36 +127,53 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
@Override
public void sendMessage(Message message) {
streaming = true;
if (conversation == null) {
conversation = conversationService.startConversation();
}
var messageWrapper = createNewMessageWrapper(message.getId());
messageWrapper.add(new UserMessagePanel(project, message, this));
var messagePanel = toolWindowScrollablePanel.addMessage(message.getId());
messagePanel.add(new UserMessagePanel(project, message, this));
var responsePanel = new ResponsePanel()
.withReloadAction(() -> reloadMessage(message, conversation))
.withDeleteAction(() -> removeMessage(message.getId(), messageWrapper, conversation))
.withDeleteAction(() -> removeMessage(message.getId(), conversation))
.addContent(new ChatMessageResponseBody(project, true, this));
messageWrapper.add(responsePanel);
messagePanel.add(responsePanel);
totalTokensPanel.updateUserPromptTokens(message.getPrompt());
call(conversation, message, responsePanel, false);
}
@Override
public TokenDetails getTokenDetails() {
return totalTokensPanel.getTokenDetails();
}
@Override
public void dispose() {
}
public void requestFocusForTextArea() {
userPromptTextArea.focus();
}
public void updateConversationTokens() {
totalTokensPanel.updateConversationTokens(conversation);
}
public boolean isStreaming() {
return streaming;
}
protected void reloadMessage(Message message, Conversation conversation) {
ResponsePanel responsePanel = null;
try {
responsePanel = (ResponsePanel) Arrays.stream(
visibleMessagePanels.get(message.getId()).getComponents())
.filter(component -> component instanceof ResponsePanel)
.findFirst().orElseThrow();
responsePanel = toolWindowScrollablePanel.getMessageResponsePanel(message.getId());
((ChatMessageResponseBody) responsePanel.getContent()).clear();
revalidateScrollablePanel();
toolWindowScrollablePanel.update();
} catch (Exception e) {
throw new RuntimeException("Couldn't delete the existing message component", e);
throw new RuntimeException("Could not delete the existing message component", e);
} finally {
LOG.debug("Reloading message: " + message.getId());
@ -175,22 +183,16 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
call(conversation, message, responsePanel, true);
}
totalTokensPanel.updateConversationTokens(conversation);
TelemetryAction.IDE_ACTION.createActionMessage()
.property("action", ActionType.RELOAD_MESSAGE.name())
.send();
}
}
private void revalidateScrollablePanel() {
scrollablePanel.repaint();
scrollablePanel.revalidate();
}
protected void removeMessage(UUID messageId, JPanel messageWrapper, Conversation conversation) {
scrollablePanel.remove(messageWrapper);
revalidateScrollablePanel();
visibleMessagePanels.remove(messageId);
protected void removeMessage(UUID messageId, Conversation conversation) {
toolWindowScrollablePanel.removeMessage(messageId);
conversation.removeMessage(messageId);
conversationService.saveConversation(conversation);
@ -201,18 +203,10 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
}
}
protected JPanel createNewMessageWrapper(UUID messageId) {
var messageWrapper = new JPanel();
messageWrapper.setLayout(new BoxLayout(messageWrapper, BoxLayout.PAGE_AXIS));
scrollablePanel.add(messageWrapper);
revalidateScrollablePanel();
visibleMessagePanels.put(messageId, messageWrapper);
return messageWrapper;
}
protected void clearWindow() {
scrollablePanel.removeAll();
revalidateScrollablePanel();
toolWindowScrollablePanel.clearAll();
totalTokensPanel.updateConversationTokens(conversation);
updateConversationTokens();
}
private void call(
@ -257,11 +251,17 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
}
}
private static JScrollPane createScrollPane(ScrollablePanel scrollablePanel) {
var scrollPane = ScrollPaneFactory.createScrollPane(scrollablePanel, true);
scrollPane.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
new SmartScroller(scrollPane);
return scrollPane;
private JPanel createUserPromptPanel() {
var panel = new JPanel(new BorderLayout());
panel.setBorder(JBUI.Borders.compound(
JBUI.Borders.customLine(JBColor.border(), 1, 0, 0, 0),
JBUI.Borders.empty(8)));
panel.setBackground(getPanelBackgroundColor());
panel.add(
new UserPromptTextAreaHeader(project, settings, totalTokensPanel, gpt4CheckBox),
BorderLayout.NORTH);
panel.add(userPromptTextArea, BorderLayout.SOUTH);
return panel;
}
private JPanel createRootPanel() {
@ -272,88 +272,20 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
gbc.weightx = 1;
gbc.gridx = 0;
gbc.gridy = 0;
rootPanel.add(createScrollPane(scrollablePanel), gbc);
var wrapper = new JPanel(new BorderLayout());
wrapper.setBorder(JBUI.Borders.compound(
JBUI.Borders.customLine(JBColor.border(), 1, 0, 0, 0),
JBUI.Borders.empty(8)));
wrapper.setBackground(getPanelBackgroundColor());
wrapper.add(createPromptTextAreaHeader(), BorderLayout.NORTH);
wrapper.add(userPromptTextArea, BorderLayout.SOUTH);
rootPanel.add(createScrollPaneWithSmartScroller(toolWindowScrollablePanel), gbc);
gbc.weighty = 0;
gbc.fill = GridBagConstraints.HORIZONTAL;
gbc.gridy = 1;
rootPanel.add(wrapper, gbc);
rootPanel.add(createUserPromptPanel(), gbc);
return rootPanel;
}
private JPanel createPromptTextAreaHeader() {
var header = new JPanel(new BorderLayout());
header.setBackground(getPanelBackgroundColor());
header.setBorder(JBUI.Borders.emptyBottom(8));
var model = settings.getModel();
if ("YouCode".equals(model)) {
var messageBusConnection = ApplicationManager.getApplication().getMessageBus().connect();
subscribeToYouModelChangeTopic();
subscribeToYouSubscriptionTopic(messageBusConnection);
subscribeToSignedOutTopic(messageBusConnection);
header.add(gpt4CheckBox, BorderLayout.LINE_START);
}
header.add(JBUI.Panels
.simplePanel(
new ModelIconLabel(settings.getSelectedService().getCompletionCode(),
model))
.withBorder(Borders.emptyRight(4))
.withBackground(getPanelBackgroundColor()), BorderLayout.LINE_END);
return header;
}
private void subscribeToYouModelChangeTopic() {
project.getMessageBus()
.connect()
.subscribe(
YouModelChangeNotifier.YOU_MODEL_CHANGE_NOTIFIER_TOPIC,
(YouModelChangeNotifier) gpt4CheckBox::setSelected);
}
private void subscribeToSignedOutTopic(MessageBusConnection messageBusConnection) {
messageBusConnection.subscribe(
SignedOutNotifier.SIGNED_OUT_TOPIC,
(SignedOutNotifier) () -> gpt4CheckBox.setEnabled(false));
}
private void subscribeToYouSubscriptionTopic(MessageBusConnection messageBusConnection) {
messageBusConnection.subscribe(
YouSubscriptionNotifier.SUBSCRIPTION_TOPIC,
(YouSubscriptionNotifier) () -> {
displayLandingView();
gpt4CheckBox.setEnabled(true);
});
}
private JTextPane createYouCouponTextPane() {
var textPane = SwingUtils.createTextPane(
"<html>\n"
+ "<body>\n"
+ " <p style=\"margin: 4px 0;\">Use CodeGPT coupon for free month of GPT-4.</p>\n"
+ " <p style=\"margin: 4px 0;\">\n"
+ " <a href=\"https://you.com/plans\">Sign up here</a>\n"
+ " </p>\n"
+ "</body>\n"
+ "</html>"
);
textPane.setBackground(getPanelBackgroundColor());
textPane.setFocusable(false);
return textPane;
}
private class ChatToolWindowCompletionEventListener implements ToolWindowCompletionEventListener {
private final Logger LOG = Logger.getInstance(ChatToolWindowCompletionEventListener.class);
private final StringBuilder messageBuilder = new StringBuilder();
private final ResponsePanel responsePanel;
private final ChatMessageResponseBody responseContainer;
@ -363,11 +295,18 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
}
@Override
public void handleMessage(String message) {
public void handleMessage(String partialMessage) {
try {
LOG.debug(message);
LOG.debug(partialMessage);
ApplicationManager.getApplication()
.invokeLater(() -> responseContainer.update(message));
.invokeLater(() -> {
responseContainer.update(partialMessage);
messageBuilder.append(partialMessage);
var ongoingTokens = encodingManager.countTokens(messageBuilder.toString());
totalTokensPanel.update(
totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens);
});
} catch (Exception e) {
responseContainer.displayDefaultError();
throw new RuntimeException("Error while updating the content", e);
@ -414,18 +353,24 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
Message message,
Conversation conversation,
boolean isRetry) {
responsePanel.enableActions();
conversationService.saveMessage(fullMessage, message, conversation, isRetry);
stopStreaming(responseContainer);
try {
responsePanel.enableActions();
conversationService.saveMessage(fullMessage, message, conversation, isRetry);
var serpResults = serpResultsMapping.get(message.getId());
var containsResults = serpResults != null && !serpResults.isEmpty();
if (YouSettingsState.getInstance().isDisplayWebSearchResults() && containsResults) {
responseContainer.displaySerpResults(serpResults);
}
var serpResults = serpResultsMapping.get(message.getId());
var containsResults = serpResults != null && !serpResults.isEmpty();
if (YouSettingsState.getInstance().isDisplayWebSearchResults() && containsResults) {
responseContainer.displaySerpResults(serpResults);
}
if (containsResults) {
message.setSerpResults(serpResults);
if (containsResults) {
message.setSerpResults(serpResults);
}
totalTokensPanel.updateUserPromptTokens(userPromptTextArea.getText());
totalTokensPanel.updateConversationTokens(conversation);
} finally {
stopStreaming(responseContainer);
}
}
@ -435,10 +380,51 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan
}
private void stopStreaming(ChatMessageResponseBody responseContainer) {
SwingUtilities.invokeLater(() -> {
userPromptTextArea.setSubmitEnabled(true);
responseContainer.hideCarets();
});
streaming = false;
userPromptTextArea.setSubmitEnabled(true);
responseContainer.hideCarets();
}
}
private void addSelectionListeners() {
var editorFactory = EditorFactory.getInstance();
for (var editor : editorFactory.getAllEditors()) {
editor.getSelectionModel().addSelectionListener(getSelectionListener());
}
editorFactory.addEditorFactoryListener(new EditorFactoryListener() {
@Override
public void editorCreated(@NotNull EditorFactoryEvent event) {
event.getEditor().getSelectionModel().addSelectionListener(getSelectionListener());
}
}, this);
}
private SelectionListener getSelectionListener() {
return new SelectionListener() {
@Override
public void selectionChanged(@NotNull SelectionEvent e) {
var selectedText = e.getEditor().getDocument().getText(e.getNewRange());
totalTokensPanel.updateHighlightedTokens(selectedText);
}
};
}
private DocumentAdapter getUserPromptDocumentAdapter() {
return new DocumentAdapter() {
@Override
protected void textChanged(@NotNull DocumentEvent event) {
try {
if (!streaming) {
var document = event.getDocument();
var text = document.getText(
document.getStartPosition().getOffset(),
document.getEndPosition().getOffset() - 1);
totalTokensPanel.updateUserPromptTokens(text);
}
} catch (BadLocationException ex) {
LOG.error("Something went wrong while processing user input tokens", ex);
}
}
};
}
}