feat: web search support (#641)

* feat: web search support

* fix: enable web search only for codegpt provider

* fix: checkstyle

* feat: improve list cell design
This commit is contained in:
Carl-Robert 2024-07-30 15:53:45 +03:00 committed by GitHub
parent 1f28bc6217
commit 05f146c405
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 290 additions and 77 deletions

View file

@ -1,11 +1,12 @@
package ee.carlrobert.codegpt.completions;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import ee.carlrobert.codegpt.events.CodeGPTEvent;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.GeneralSettingsState;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.you.completion.YouCompletionEventListener;
import ee.carlrobert.llm.client.you.completion.YouSerpResult;
import ee.carlrobert.llm.completion.CompletionEventListener;
import java.util.List;
import javax.swing.SwingWorker;
@ -67,7 +68,7 @@ public class CompletionRequestHandler {
protected Void doInBackground() {
var settings = GeneralSettings.getCurrentState();
try {
eventSource = startCall(callParameters, new YouRequestCompletionEventListener());
eventSource = startCall(callParameters, new RequestCompletionEventListener());
} catch (TotalUsageExceededException e) {
completionResponseEventListener.handleTokensExceeded(
callParameters.getConversation(),
@ -86,11 +87,16 @@ public class CompletionRequestHandler {
}
}
class YouRequestCompletionEventListener implements YouCompletionEventListener {
class RequestCompletionEventListener implements CompletionEventListener<String> {
@Override
public void onSerpResults(List<YouSerpResult> results) {
completionResponseEventListener.handleSerpResults(results, callParameters.getMessage());
public void onEvent(String data) {
try {
var event = new ObjectMapper().readValue(data, CodeGPTEvent.class);
completionResponseEventListener.handleCodeGPTEvent(event);
} catch (JsonProcessingException e) {
// ignore
}
}
@Override

View file

@ -217,11 +217,17 @@ public class CompletionRequestProvider {
@Nullable String model,
CallParameters callParameters) {
var configuration = ConfigurationSettings.getCurrentState();
return new OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters))
var requestBuilder = new OpenAIChatCompletionRequest.Builder(
buildOpenAIMessages(model, callParameters))
.setModel(model)
.setMaxTokens(configuration.getMaxTokens())
.setStream(true)
.setTemperature(configuration.getTemperature()).build();
.setTemperature(configuration.getTemperature());
if (callParameters.getMessage().isWebSearchIncluded()) {
// tri-state boolean
requestBuilder.setWebSearchIncluded(true);
}
return requestBuilder.build();
}
public GoogleCompletionRequest buildGoogleChatCompletionRequest(

View file

@ -2,9 +2,8 @@ package ee.carlrobert.codegpt.completions;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.events.CodeGPTEvent;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.you.completion.YouSerpResult;
import java.util.List;
public interface CompletionResponseEventListener {
@ -20,6 +19,6 @@ public interface CompletionResponseEventListener {
default void handleCompleted(String fullMessage, CallParameters callParameters) {
}
default void handleSerpResults(List<YouSerpResult> results, Message message) {
default void handleCodeGPTEvent(CodeGPTEvent event) {
}
}

View file

@ -17,6 +17,7 @@ public class Message {
private List<YouSerpResult> serpResults;
private List<String> referencedFilePaths;
private @Nullable String imageFilePath;
private boolean webSearchIncluded;
public Message(String prompt, String response) {
this(prompt);
@ -81,6 +82,14 @@ public class Message {
this.imageFilePath = imageFilePath;
}
public boolean isWebSearchIncluded() {
return webSearchIncluded;
}
public void setWebSearchIncluded(boolean webSearchIncluded) {
this.webSearchIncluded = webSearchIncluded;
}
@Override
public boolean equals(Object obj) {
if (obj == this) {

View file

@ -173,7 +173,8 @@ public class ChatToolWindowTabPanel implements Disposable {
return new ResponsePanel()
.withReloadAction(() -> reloadMessage(message, conversation, conversationType))
.withDeleteAction(() -> removeMessage(message.getId(), conversation))
.addContent(new ChatMessageResponseBody(project, true, this));
.addContent(
new ChatMessageResponseBody(project, true, false, message.isWebSearchIncluded(), this));
}
private void reloadMessage(
@ -244,7 +245,7 @@ public class ChatToolWindowTabPanel implements Disposable {
requestHandler.call(callParameters);
}
private Unit handleSubmit(String text) {
private Unit handleSubmit(String text, boolean webSearchIncluded) {
var message = new Message(text);
var editor = EditorUtil.getSelectedEditor(project);
if (editor != null) {
@ -257,6 +258,7 @@ public class ChatToolWindowTabPanel implements Disposable {
}
}
message.setUserMessage(text);
message.setWebSearchIncluded(webSearchIncluded);
sendMessage(message, ConversationType.DEFAULT);
return Unit.INSTANCE;
}

View file

@ -10,6 +10,7 @@ import ee.carlrobert.codegpt.completions.CompletionResponseEventListener;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.events.CodeGPTEvent;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.codegpt.toolwindow.chat.ui.ChatMessageResponseBody;
import ee.carlrobert.codegpt.toolwindow.chat.ui.ResponsePanel;
@ -17,11 +18,6 @@ import ee.carlrobert.codegpt.toolwindow.chat.ui.textarea.TotalTokensPanel;
import ee.carlrobert.codegpt.ui.OverlayUtil;
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.you.completion.YouSerpResult;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import javax.swing.SwingUtilities;
abstract class ToolWindowCompletionResponseEventListener implements
@ -31,7 +27,6 @@ abstract class ToolWindowCompletionResponseEventListener implements
ToolWindowCompletionResponseEventListener.class);
private final StringBuilder messageBuilder = new StringBuilder();
private final Map<UUID, List<YouSerpResult>> serpResultsMapping = new HashMap<>();
private final EncodingManager encodingManager;
private final ConversationService conversationService;
private final ResponsePanel responsePanel;
@ -113,20 +108,11 @@ abstract class ToolWindowCompletionResponseEventListener implements
@Override
public void handleCompleted(String fullMessage, CallParameters callParameters) {
var message = callParameters.getMessage();
conversationService.saveMessage(fullMessage, callParameters);
var serpResults = serpResultsMapping.get(message.getId());
var containsResults = serpResults != null && !serpResults.isEmpty();
if (containsResults) {
message.setSerpResults(serpResults);
}
SwingUtilities.invokeLater(() -> {
try {
responsePanel.enableActions();
if (containsResults) {
responseContainer.displaySerpResults(serpResults);
}
totalTokensPanel.updateUserPromptTokens(textArea.getText());
totalTokensPanel.updateConversationTokens(callParameters.getConversation());
} finally {
@ -136,8 +122,8 @@ abstract class ToolWindowCompletionResponseEventListener implements
}
@Override
public void handleSerpResults(List<YouSerpResult> results, Message message) {
serpResultsMapping.put(message.getId(), results);
public void handleCodeGPTEvent(CodeGPTEvent event) {
responseContainer.displayWebSearchItem(event.getEvent().getDetails());
}
private void stopStreaming(ChatMessageResponseBody responseContainer) {

View file

@ -12,23 +12,25 @@ import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.io.FileUtil;
import com.intellij.openapi.vfs.LocalFileSystem;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.ui.components.JBLabel;
import com.intellij.util.ui.JBUI;
import com.vladsch.flexmark.ast.FencedCodeBlock;
import com.vladsch.flexmark.parser.Parser;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.actions.ActionType;
import ee.carlrobert.codegpt.events.Details;
import ee.carlrobert.codegpt.settings.GeneralSettingsConfigurable;
import ee.carlrobert.codegpt.telemetry.TelemetryAction;
import ee.carlrobert.codegpt.toolwindow.chat.StreamParser;
import ee.carlrobert.codegpt.toolwindow.chat.editor.ResponseEditorPanel;
import ee.carlrobert.codegpt.toolwindow.ui.WebpageList;
import ee.carlrobert.codegpt.ui.UIUtil;
import ee.carlrobert.codegpt.util.EditorUtil;
import ee.carlrobert.codegpt.util.MarkdownUtil;
import ee.carlrobert.llm.client.you.completion.YouSerpResult;
import java.awt.BorderLayout;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.swing.BoxLayout;
import javax.swing.DefaultListModel;
import javax.swing.JPanel;
import javax.swing.JTextPane;
@ -38,6 +40,7 @@ public class ChatMessageResponseBody extends JPanel {
private final Disposable parentDisposable;
private final StreamParser streamParser;
private final boolean readOnly;
private final DefaultListModel<Details> webpageListModel = new DefaultListModel<>();
private ResponseEditorPanel currentlyProcessedEditorPanel;
private JTextPane currentlyProcessedTextPane;
private boolean responseReceived;
@ -50,13 +53,14 @@ public class ChatMessageResponseBody extends JPanel {
Project project,
boolean withGhostText,
Disposable parentDisposable) {
this(project, withGhostText, false, parentDisposable);
this(project, withGhostText, false, false, parentDisposable);
}
public ChatMessageResponseBody(
Project project,
boolean withGhostText,
boolean readOnly,
boolean webSearchIncluded,
Disposable parentDisposable) {
super(new BorderLayout());
this.project = project;
@ -66,6 +70,18 @@ public class ChatMessageResponseBody extends JPanel {
setLayout(new BoxLayout(this, BoxLayout.PAGE_AXIS));
setOpaque(false);
if (webSearchIncluded) {
var title = new JPanel(new BorderLayout());
title.setOpaque(false);
title.setBorder(JBUI.Borders.empty(8, 0));
title.add(new JBLabel(CodeGPTBundle.get("chatMessageResponseBody.webPagesTitle"))
.withFont(JBUI.Fonts.miniFont()), BorderLayout.LINE_START);
add(title);
var listPanel = new JPanel(new BorderLayout());
listPanel.add(new WebpageList(webpageListModel), BorderLayout.LINE_START);
add(listPanel);
}
if (withGhostText) {
prepareProcessingText(!readOnly);
currentlyProcessedTextPane.setText(
@ -136,18 +152,6 @@ public class ChatMessageResponseBody extends JPanel {
}
}
public void displaySerpResults(List<YouSerpResult> serpResults) {
var html = getSearchResultsHtml(serpResults);
if (responseReceived) {
add(createTextPane(html, false));
} else {
if (currentlyProcessedTextPane == null) {
prepareProcessingText(false);
}
currentlyProcessedTextPane.setText(html);
}
}
public void clear() {
removeAll();
@ -161,21 +165,6 @@ public class ChatMessageResponseBody extends JPanel {
revalidate();
}
private String getSearchResultsHtml(List<YouSerpResult> serpResults) {
var titles = serpResults.stream()
.map(result -> format(
"<li style=\"margin-bottom: 4px;\"><a href=\"%s\">%s</a></li>",
result.getUrl(),
result.getName()))
.collect(Collectors.joining());
return format(
"<html>"
+ "<p><strong>Search results:</strong></p>"
+ "<ol>%s</ol>"
+ "</html>",
titles);
}
private void processResponse(String markdownInput, boolean codeResponse, boolean caretVisible) {
responseReceived = true;
@ -239,4 +228,8 @@ public class ChatMessageResponseBody extends JPanel {
textPane.setBorder(JBUI.Borders.empty());
return textPane;
}
public void displayWebSearchItem(Details details) {
webpageListModel.addElement(details);
}
}

View file

@ -59,7 +59,13 @@ public class UserMessagePanel extends JPanel {
Project project,
String prompt,
Disposable parentDisposable) {
return new ChatMessageResponseBody(project, false, true, parentDisposable).withResponse(prompt);
return new ChatMessageResponseBody(
project,
false,
true,
false,
parentDisposable)
.withResponse(prompt);
}
private JBLabel createDisplayNameLabel() {