1.0.3 - Server-Sent Events

This commit is contained in:
Carl-Robert Linnupuu 2023-02-16 02:45:58 +00:00
parent c0f340ecad
commit d5afb4144b
10 changed files with 221 additions and 79 deletions

View file

@ -24,9 +24,11 @@ public class AskAction extends AnAction {
var toolWindow = ToolWindowManager.getInstance(project).getToolWindow("ChatGPT");
if (toolWindow != null) {
toolWindow.show();
toolWindow.setTitle("");
var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class);
ApiClient.getInstance().clearQueries();
toolWindowService.getScrollablePanel().removeAll();
toolWindowService.removeAll();
toolWindowService.paintLandingView();
}
}
}

View file

@ -23,10 +23,10 @@ public abstract class BaseAction extends AnAction {
initToolWindow(ToolWindowManager.getInstance(project).getToolWindow("ChatGPT"));
var selectedText = editor.getSelectionModel().getSelectedText();
var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class);
var scrollablePanel = toolWindowService.getScrollablePanel();
ApiClient.getInstance().clearQueries();
scrollablePanel.removeAll();
toolWindowService.sendMessage(selectedText, getPrompt(selectedText));
toolWindowService.removeAll();
toolWindowService.paintUserMessage(selectedText);
toolWindowService.sendMessage(getPrompt(selectedText), null);
}
}

View file

@ -3,14 +3,9 @@ package ee.carlrobert.chatgpt.client;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import ee.carlrobert.chatgpt.settings.SettingsState;
import ee.carlrobert.chatgpt.client.response.ApiError;
import ee.carlrobert.chatgpt.client.response.ApiResponse;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@ -26,34 +21,31 @@ public final class ApiClient {
private ApiClient() {
}
public void getCompletionsAsync(String prompt, Consumer<ApiResponse> onSuccess, Consumer<ApiError> onError) {
/*var query = new StringBuilder(
"You are ChatGPT, a large language model trained by OpenAI. You answer as concisely as possible for each response (e.g. dont be verbose). It is very important that you answer as concisely as possible, so please remember this.\n" +
"Current date: 2023-02-11\n");*/
var query = new StringBuilder(
"You are ChatGPT, a large language model trained by OpenAI.\n");
for (var entry : queries) {
public void getCompletionsAsync(String prompt, Consumer<String> onMessage) {
try {
var query = new StringBuilder(
"You are ChatGPT, a large language model trained by OpenAI.\n");
for (var entry : queries) {
query.append("User:\n")
.append(entry.getKey())
.append("<|im_end|>\n")
.append("\n")
.append("ChatGPT:\n")
.append(entry.getValue())
.append("<|im_end|>\n")
.append("\n");
}
query.append("User:\n")
.append(entry.getKey())
.append(prompt)
.append("<|im_end|>\n")
.append("\n")
.append("ChatGPT:\n")
.append(entry.getValue())
.append("<|im_end|>\n")
.append("\n");
}
query.append("User:\n")
.append(prompt)
.append("<|im_end|>\n")
.append("\n")
.append("ChatGPT:\n");
try {
.append("ChatGPT:\n");
var request = HttpRequest.newBuilder()
var req = HttpRequest.newBuilder()
.uri(URI.create("https://api.openai.com/v1/completions"))
.header("Authorization", "Bearer " + SettingsState.getInstance().secretKey)
.timeout(Duration.ofMinutes(1))
.header("Accept", "text/event-stream")
.header("Content-Type", "application/json")
.header("Authorization", "Bearer " + SettingsState.getInstance().secretKey)
.POST(HttpRequest.BodyPublishers.ofString(objectMapper
.writerWithDefaultPrettyPrinter()
.writeValueAsString(Map.of(
@ -61,24 +53,29 @@ public final class ApiClient {
"stop", List.of("<|im_end|>"),
"prompt", query.toString(),
"max_tokens", 400,
"temperature", 1.0
"temperature", 1.0,
"stream", true
))))
.build();
client.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenAccept(response -> {
try {
var mappedResponse = objectMapper.readValue(response.body(), ApiResponse.class);
if (mappedResponse.getError() == null) {
queries.add(Map.entry(prompt, mappedResponse.getChoices().get(0).getText()));
onSuccess.accept(mappedResponse);
} else {
onError.accept(mappedResponse.getError());
}
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
this.client.sendAsync(req, respInfo ->
{
if (respInfo.statusCode() == 200) {
return new Subscriber((messageData ->
onMessage.accept(messageData.getChoices().get(0).getText())),
(finalMsg) -> queries.add(Map.entry(prompt, finalMsg)));
} else if (respInfo.statusCode() == 401) {
onMessage.accept("Incorrect API key provided.\n" +
"You can find your API key at https://platform.openai.com/account/api-keys.");
throw new IllegalArgumentException();
} else {
onMessage.accept("Something went wrong. Please try again later.");
clearQueries();
throw new RuntimeException();
}
});
} catch (IOException e) {
} catch (JsonProcessingException e) {
onMessage.accept("Something went wrong. Please try again later.");
throw new RuntimeException(e);
}
}

View file

@ -0,0 +1,113 @@
package ee.carlrobert.chatgpt.client;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.fasterxml.jackson.databind.ObjectMapper;
import ee.carlrobert.chatgpt.client.response.ApiResponse;
import java.net.http.HttpResponse;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Flow;
import java.util.function.Consumer;
import java.util.regex.Pattern;
public class Subscriber implements HttpResponse.BodySubscriber<Void> {
protected static final Pattern dataLinePattern = Pattern.compile("^data: ?(.*)$");
protected static ApiResponse extractMessageData(String[] messageLines) {
var responseBuilder = new StringBuilder();
for (var line : messageLines) {
var matcher = dataLinePattern.matcher(line);
if (matcher.matches()) {
responseBuilder.append(matcher.group(1));
}
}
try {
return new ObjectMapper().readValue(responseBuilder.toString(), ApiResponse.class);
} catch (Exception e) {
throw new RuntimeException("Couldn't read the payload", e);
}
}
protected final Consumer<? super ApiResponse> messageDataConsumer;
protected final CompletableFuture<Void> future;
protected volatile Flow.Subscription subscription;
protected volatile String deferredText;
private final Consumer<String> onComplete;
private final StringBuilder msgBuilder = new StringBuilder();
public Subscriber(Consumer<? super ApiResponse> messageDataConsumer, Consumer<String> onComplete) {
this.messageDataConsumer = messageDataConsumer;
this.future = new CompletableFuture<>();
this.subscription = null;
this.deferredText = null;
this.onComplete = onComplete;
}
@Override
public void onSubscribe(Flow.Subscription subscription) {
this.subscription = subscription;
try {
this.deferredText = "";
this.subscription.request(1);
} catch (Exception e) {
this.future.completeExceptionally(e);
this.subscription.cancel();
}
}
@Override
public void onNext(List<ByteBuffer> buffers) {
try {
var deferredText = this.deferredText;
for (var buffer : buffers) {
var s = deferredText + UTF_8.decode(buffer);
var tokens = s.split("\n\n", -1);
for (var i = 0; i < tokens.length - 1; i++) {
var message = tokens[i];
var data = extractMessageData(message.split("\n"));
var choice = data.getChoices().get(0); // TODO: Is there only one choice per response?
if ("stop".equals(choice.getFinish_reason())) {
onComplete();
} else {
msgBuilder.append(choice.getText());
}
this.messageDataConsumer.accept(data);
}
deferredText = tokens[tokens.length - 1];
}
this.deferredText = deferredText;
this.subscription.request(1);
} catch (Exception e) {
this.future.completeExceptionally(e);
this.subscription.cancel();
}
}
@Override
public void onError(Throwable e) {
this.future.completeExceptionally(e);
}
@Override
public void onComplete() {
try {
this.future.complete(null);
this.onComplete.accept(msgBuilder.toString());
} catch (Exception e) {
this.future.completeExceptionally(e);
}
}
@Override
public CompletionStage<Void> getBody() {
return this.future;
}
}

View file

@ -26,7 +26,8 @@ public class ChatGptToolWindow {
public void handleSubmit() {
var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class);
var searchText = textField.getText();
toolWindowService.sendMessage(searchText, searchText, this::scrollToBottom);
toolWindowService.paintUserMessage(searchText);
toolWindowService.sendMessage(searchText, this::scrollToBottom);
textField.setText("");
scrollToBottom();
}
@ -60,5 +61,6 @@ public class ChatGptToolWindow {
var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class);
toolWindowService.setScrollablePanel(scrollablePanel);
toolWindowService.paintLandingView();
}
}

View file

@ -5,16 +5,22 @@ import static ee.carlrobert.chatgpt.toolwindow.ToolWindowUtil.createTextArea;
import static ee.carlrobert.chatgpt.toolwindow.ToolWindowUtil.justifyLeft;
import com.intellij.openapi.roots.ui.componentsList.components.ScrollablePanel;
import ee.carlrobert.chatgpt.client.ApiClient;
import ee.carlrobert.chatgpt.EmptyCallback;
import ee.carlrobert.chatgpt.toolwindow.components.Loader;
import ee.carlrobert.chatgpt.client.ApiClient;
import java.awt.GridBagLayout;
import java.util.List;
import java.util.Objects;
import javax.annotation.Nullable;
import javax.swing.Box;
import javax.swing.ImageIcon;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.SwingConstants;
public class ToolWindowService {
private ScrollablePanel scrollablePanel;
private boolean isLandingViewVisible;
public void setScrollablePanel(ScrollablePanel scrollablePanel) {
this.scrollablePanel = scrollablePanel;
@ -24,38 +30,61 @@ public class ToolWindowService {
return scrollablePanel;
}
public void sendMessage(String userMessage, String prompt) {
sendMessage(userMessage, prompt, null);
public void paintUserMessage(String userMessage) {
if (isLandingViewVisible) {
removeAll();
}
scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/user-icon.png")), "User:")));
scrollablePanel.add(Box.createVerticalStrut(8));
scrollablePanel.add(createTextArea(userMessage, true));
}
public void sendMessage(String userMessage, String prompt, @Nullable EmptyCallback onSuccess) {
scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/user-icon.png")), "User")));
scrollablePanel.add(Box.createVerticalStrut(8));
scrollablePanel.add(createTextArea(userMessage, true, true));
public void sendMessage(String prompt, @Nullable EmptyCallback scrollToBottom) {
scrollablePanel.add(Box.createVerticalStrut(16));
scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/chatgpt-icon.png")), "ChatGPT")));
scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/chatgpt-icon.png")), "ChatGPT:")));
scrollablePanel.add(Box.createVerticalStrut(8));
var loader = new Loader();
scrollablePanel.add(justifyLeft(loader.getComponent()));
loader.startLoading();
scrollablePanel.add(Box.createVerticalStrut(4));
var textArea = createTextArea("", false);
scrollablePanel.add(textArea);
ApiClient.getInstance().getCompletionsAsync(prompt, response -> {
loader.stopLoading();
scrollablePanel.add(Box.createVerticalStrut(4));
for (var choice : response.getChoices()) {
scrollablePanel.add(createTextArea(choice.getText().trim(), false, true));
ApiClient.getInstance().getCompletionsAsync(prompt, (message) -> {
textArea.append(message);
if (scrollToBottom != null) {
scrollToBottom.call();
}
scrollablePanel.add(Box.createVerticalStrut(32));
if (onSuccess != null) {
onSuccess.call();
}
}, apiError -> {
loader.stopLoading();
scrollablePanel.add(Box.createVerticalStrut(4));
scrollablePanel.add(createTextArea(apiError.getMessage(), false, true));
scrollablePanel.add(Box.createVerticalStrut(32));
});
scrollablePanel.add(Box.createVerticalStrut(16));
}
public void paintLandingView() {
isLandingViewVisible = true;
var imageIconPanel = new JPanel();
imageIconPanel.setLayout(new GridBagLayout());
var imageIconLabel = new JLabel(new ImageIcon(Objects.requireNonNull(getClass().getResource("/icons/sun-icon.png"))));
imageIconLabel.setHorizontalAlignment(JLabel.CENTER);
imageIconPanel.add(imageIconLabel);
scrollablePanel.add(imageIconPanel);
scrollablePanel.add(Box.createVerticalStrut(16));
var questions = List.of("How do I make an HTTP request in Javascript?",
"What is the difference between px, dip, dp, and sp?",
"How do I undo the most recent local commits in Git?",
"What is the difference between stack and heap?");
for (var question : questions) {
var panel = new JPanel();
panel.setLayout(new GridBagLayout());
var label = new JLabel(question, SwingConstants.CENTER);
label.setHorizontalAlignment(JLabel.CENTER);
panel.add(label);
scrollablePanel.add(panel);
scrollablePanel.add(Box.createVerticalStrut(16));
}
}
public void removeAll() {
isLandingViewVisible = false;
scrollablePanel.removeAll();
}
}

View file

@ -11,16 +11,14 @@ import javax.swing.JTextArea;
public class ToolWindowUtil {
public static JTextArea createTextArea(String selectedText, boolean isItalicFont, boolean transparentBackground) {
public static JTextArea createTextArea(String selectedText, boolean isItalicFont) {
var textArea = new JTextArea();
textArea.append(selectedText);
textArea.setLineWrap(true);
textArea.setEditable(false);
textArea.setFont(createFont(isItalicFont, textArea.getFont().getSize()));
textArea.setFont(new Font("Tahoma", isItalicFont ? Font.ITALIC : Font.PLAIN, textArea.getFont().getSize()));
textArea.setWrapStyleWord(true);
if (transparentBackground) {
textArea.setBackground(JBColor.background());
}
textArea.setBackground(JBColor.PanelBackground);
// textArea.setBorder(new MatteBorder(0, 2, 0, 0, JBColor.RED));
return textArea;
}