From d5afb4144b99fdacdb4ca21d10450474909ef484 Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Thu, 16 Feb 2023 02:45:58 +0000 Subject: [PATCH] 1.0.3 - Server-Sent Events --- build.gradle.kts | 2 +- .../carlrobert/chatgpt/action/AskAction.java | 4 +- .../carlrobert/chatgpt/action/BaseAction.java | 6 +- .../carlrobert/chatgpt/client/ApiClient.java | 77 ++++++------ .../carlrobert/chatgpt/client/Subscriber.java | 113 ++++++++++++++++++ .../chatgpt/toolwindow/ChatGptToolWindow.java | 4 +- .../chatgpt/toolwindow/ToolWindowService.java | 85 ++++++++----- .../chatgpt/toolwindow/ToolWindowUtil.java | 8 +- src/main/resources/META-INF/plugin.xml | 1 + src/main/resources/icons/sun-icon.png | Bin 0 -> 830 bytes 10 files changed, 221 insertions(+), 79 deletions(-) create mode 100644 src/main/java/ee/carlrobert/chatgpt/client/Subscriber.java create mode 100644 src/main/resources/icons/sun-icon.png diff --git a/build.gradle.kts b/build.gradle.kts index fef3bea9..f88222f1 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -4,7 +4,7 @@ plugins { } group = "ee.carlrobert" -version = "1.0.2" +version = "1.0.3" repositories { mavenCentral() diff --git a/src/main/java/ee/carlrobert/chatgpt/action/AskAction.java b/src/main/java/ee/carlrobert/chatgpt/action/AskAction.java index de902597..dc69b538 100644 --- a/src/main/java/ee/carlrobert/chatgpt/action/AskAction.java +++ b/src/main/java/ee/carlrobert/chatgpt/action/AskAction.java @@ -24,9 +24,11 @@ public class AskAction extends AnAction { var toolWindow = ToolWindowManager.getInstance(project).getToolWindow("ChatGPT"); if (toolWindow != null) { toolWindow.show(); + toolWindow.setTitle(""); var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class); ApiClient.getInstance().clearQueries(); - toolWindowService.getScrollablePanel().removeAll(); + toolWindowService.removeAll(); + toolWindowService.paintLandingView(); } } } diff --git a/src/main/java/ee/carlrobert/chatgpt/action/BaseAction.java b/src/main/java/ee/carlrobert/chatgpt/action/BaseAction.java index 6a1168ec..20b3abed 100644 --- a/src/main/java/ee/carlrobert/chatgpt/action/BaseAction.java +++ b/src/main/java/ee/carlrobert/chatgpt/action/BaseAction.java @@ -23,10 +23,10 @@ public abstract class BaseAction extends AnAction { initToolWindow(ToolWindowManager.getInstance(project).getToolWindow("ChatGPT")); var selectedText = editor.getSelectionModel().getSelectedText(); var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class); - var scrollablePanel = toolWindowService.getScrollablePanel(); ApiClient.getInstance().clearQueries(); - scrollablePanel.removeAll(); - toolWindowService.sendMessage(selectedText, getPrompt(selectedText)); + toolWindowService.removeAll(); + toolWindowService.paintUserMessage(selectedText); + toolWindowService.sendMessage(getPrompt(selectedText), null); } } diff --git a/src/main/java/ee/carlrobert/chatgpt/client/ApiClient.java b/src/main/java/ee/carlrobert/chatgpt/client/ApiClient.java index 3ccdf880..52f537b4 100644 --- a/src/main/java/ee/carlrobert/chatgpt/client/ApiClient.java +++ b/src/main/java/ee/carlrobert/chatgpt/client/ApiClient.java @@ -3,14 +3,9 @@ package ee.carlrobert.chatgpt.client; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import ee.carlrobert.chatgpt.settings.SettingsState; -import ee.carlrobert.chatgpt.client.response.ApiError; -import ee.carlrobert.chatgpt.client.response.ApiResponse; -import java.io.IOException; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -26,34 +21,31 @@ public final class ApiClient { private ApiClient() { } - public void getCompletionsAsync(String prompt, Consumer onSuccess, Consumer onError) { - /*var query = new StringBuilder( - "You are ChatGPT, a large language model trained by OpenAI. You answer as concisely as possible for each response (e.g. don’t be verbose). It is very important that you answer as concisely as possible, so please remember this.\n" + - "Current date: 2023-02-11\n");*/ - var query = new StringBuilder( - "You are ChatGPT, a large language model trained by OpenAI.\n"); - for (var entry : queries) { + public void getCompletionsAsync(String prompt, Consumer onMessage) { + try { + var query = new StringBuilder( + "You are ChatGPT, a large language model trained by OpenAI.\n"); + for (var entry : queries) { + query.append("User:\n") + .append(entry.getKey()) + .append("<|im_end|>\n") + .append("\n") + .append("ChatGPT:\n") + .append(entry.getValue()) + .append("<|im_end|>\n") + .append("\n"); + } query.append("User:\n") - .append(entry.getKey()) + .append(prompt) .append("<|im_end|>\n") .append("\n") - .append("ChatGPT:\n") - .append(entry.getValue()) - .append("<|im_end|>\n") - .append("\n"); - } - query.append("User:\n") - .append(prompt) - .append("<|im_end|>\n") - .append("\n") - .append("ChatGPT:\n"); - try { + .append("ChatGPT:\n"); - var request = HttpRequest.newBuilder() + var req = HttpRequest.newBuilder() .uri(URI.create("https://api.openai.com/v1/completions")) - .header("Authorization", "Bearer " + SettingsState.getInstance().secretKey) - .timeout(Duration.ofMinutes(1)) + .header("Accept", "text/event-stream") .header("Content-Type", "application/json") + .header("Authorization", "Bearer " + SettingsState.getInstance().secretKey) .POST(HttpRequest.BodyPublishers.ofString(objectMapper .writerWithDefaultPrettyPrinter() .writeValueAsString(Map.of( @@ -61,24 +53,29 @@ public final class ApiClient { "stop", List.of("<|im_end|>"), "prompt", query.toString(), "max_tokens", 400, - "temperature", 1.0 + "temperature", 1.0, + "stream", true )))) .build(); - client.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenAccept(response -> { - try { - var mappedResponse = objectMapper.readValue(response.body(), ApiResponse.class); - if (mappedResponse.getError() == null) { - queries.add(Map.entry(prompt, mappedResponse.getChoices().get(0).getText())); - onSuccess.accept(mappedResponse); - } else { - onError.accept(mappedResponse.getError()); - } - } catch (JsonProcessingException e) { - throw new RuntimeException(e); + this.client.sendAsync(req, respInfo -> + { + if (respInfo.statusCode() == 200) { + return new Subscriber((messageData -> + onMessage.accept(messageData.getChoices().get(0).getText())), + (finalMsg) -> queries.add(Map.entry(prompt, finalMsg))); + } else if (respInfo.statusCode() == 401) { + onMessage.accept("Incorrect API key provided.\n" + + "You can find your API key at https://platform.openai.com/account/api-keys."); + throw new IllegalArgumentException(); + } else { + onMessage.accept("Something went wrong. Please try again later."); + clearQueries(); + throw new RuntimeException(); } }); - } catch (IOException e) { + } catch (JsonProcessingException e) { + onMessage.accept("Something went wrong. Please try again later."); throw new RuntimeException(e); } } diff --git a/src/main/java/ee/carlrobert/chatgpt/client/Subscriber.java b/src/main/java/ee/carlrobert/chatgpt/client/Subscriber.java new file mode 100644 index 00000000..d1794edb --- /dev/null +++ b/src/main/java/ee/carlrobert/chatgpt/client/Subscriber.java @@ -0,0 +1,113 @@ +package ee.carlrobert.chatgpt.client; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.fasterxml.jackson.databind.ObjectMapper; +import ee.carlrobert.chatgpt.client.response.ApiResponse; +import java.net.http.HttpResponse; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Flow; +import java.util.function.Consumer; +import java.util.regex.Pattern; + +public class Subscriber implements HttpResponse.BodySubscriber { + + protected static final Pattern dataLinePattern = Pattern.compile("^data: ?(.*)$"); + + protected static ApiResponse extractMessageData(String[] messageLines) { + var responseBuilder = new StringBuilder(); + for (var line : messageLines) { + var matcher = dataLinePattern.matcher(line); + if (matcher.matches()) { + responseBuilder.append(matcher.group(1)); + } + } + + try { + return new ObjectMapper().readValue(responseBuilder.toString(), ApiResponse.class); + } catch (Exception e) { + throw new RuntimeException("Couldn't read the payload", e); + } + } + + protected final Consumer messageDataConsumer; + protected final CompletableFuture future; + protected volatile Flow.Subscription subscription; + protected volatile String deferredText; + private final Consumer onComplete; + private final StringBuilder msgBuilder = new StringBuilder(); + + public Subscriber(Consumer messageDataConsumer, Consumer onComplete) { + this.messageDataConsumer = messageDataConsumer; + this.future = new CompletableFuture<>(); + this.subscription = null; + this.deferredText = null; + this.onComplete = onComplete; + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + try { + this.deferredText = ""; + this.subscription.request(1); + } catch (Exception e) { + this.future.completeExceptionally(e); + this.subscription.cancel(); + } + } + + @Override + public void onNext(List buffers) { + try { + var deferredText = this.deferredText; + + for (var buffer : buffers) { + var s = deferredText + UTF_8.decode(buffer); + var tokens = s.split("\n\n", -1); + + for (var i = 0; i < tokens.length - 1; i++) { + var message = tokens[i]; + var data = extractMessageData(message.split("\n")); + var choice = data.getChoices().get(0); // TODO: Is there only one choice per response? + if ("stop".equals(choice.getFinish_reason())) { + onComplete(); + } else { + msgBuilder.append(choice.getText()); + } + this.messageDataConsumer.accept(data); + } + deferredText = tokens[tokens.length - 1]; + } + + this.deferredText = deferredText; + this.subscription.request(1); + } catch (Exception e) { + this.future.completeExceptionally(e); + this.subscription.cancel(); + } + } + + @Override + public void onError(Throwable e) { + this.future.completeExceptionally(e); + } + + @Override + public void onComplete() { + try { + this.future.complete(null); + this.onComplete.accept(msgBuilder.toString()); + } catch (Exception e) { + this.future.completeExceptionally(e); + } + } + + @Override + public CompletionStage getBody() { + return this.future; + } +} diff --git a/src/main/java/ee/carlrobert/chatgpt/toolwindow/ChatGptToolWindow.java b/src/main/java/ee/carlrobert/chatgpt/toolwindow/ChatGptToolWindow.java index fe660b13..d704f9f6 100644 --- a/src/main/java/ee/carlrobert/chatgpt/toolwindow/ChatGptToolWindow.java +++ b/src/main/java/ee/carlrobert/chatgpt/toolwindow/ChatGptToolWindow.java @@ -26,7 +26,8 @@ public class ChatGptToolWindow { public void handleSubmit() { var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class); var searchText = textField.getText(); - toolWindowService.sendMessage(searchText, searchText, this::scrollToBottom); + toolWindowService.paintUserMessage(searchText); + toolWindowService.sendMessage(searchText, this::scrollToBottom); textField.setText(""); scrollToBottom(); } @@ -60,5 +61,6 @@ public class ChatGptToolWindow { var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class); toolWindowService.setScrollablePanel(scrollablePanel); + toolWindowService.paintLandingView(); } } diff --git a/src/main/java/ee/carlrobert/chatgpt/toolwindow/ToolWindowService.java b/src/main/java/ee/carlrobert/chatgpt/toolwindow/ToolWindowService.java index 92e31c76..1274a95f 100644 --- a/src/main/java/ee/carlrobert/chatgpt/toolwindow/ToolWindowService.java +++ b/src/main/java/ee/carlrobert/chatgpt/toolwindow/ToolWindowService.java @@ -5,16 +5,22 @@ import static ee.carlrobert.chatgpt.toolwindow.ToolWindowUtil.createTextArea; import static ee.carlrobert.chatgpt.toolwindow.ToolWindowUtil.justifyLeft; import com.intellij.openapi.roots.ui.componentsList.components.ScrollablePanel; -import ee.carlrobert.chatgpt.client.ApiClient; import ee.carlrobert.chatgpt.EmptyCallback; -import ee.carlrobert.chatgpt.toolwindow.components.Loader; +import ee.carlrobert.chatgpt.client.ApiClient; +import java.awt.GridBagLayout; +import java.util.List; import java.util.Objects; import javax.annotation.Nullable; import javax.swing.Box; +import javax.swing.ImageIcon; +import javax.swing.JLabel; +import javax.swing.JPanel; +import javax.swing.SwingConstants; public class ToolWindowService { private ScrollablePanel scrollablePanel; + private boolean isLandingViewVisible; public void setScrollablePanel(ScrollablePanel scrollablePanel) { this.scrollablePanel = scrollablePanel; @@ -24,38 +30,61 @@ public class ToolWindowService { return scrollablePanel; } - public void sendMessage(String userMessage, String prompt) { - sendMessage(userMessage, prompt, null); + public void paintUserMessage(String userMessage) { + if (isLandingViewVisible) { + removeAll(); + } + scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/user-icon.png")), "User:"))); + scrollablePanel.add(Box.createVerticalStrut(8)); + scrollablePanel.add(createTextArea(userMessage, true)); } - public void sendMessage(String userMessage, String prompt, @Nullable EmptyCallback onSuccess) { - scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/user-icon.png")), "User"))); - scrollablePanel.add(Box.createVerticalStrut(8)); - scrollablePanel.add(createTextArea(userMessage, true, true)); + public void sendMessage(String prompt, @Nullable EmptyCallback scrollToBottom) { scrollablePanel.add(Box.createVerticalStrut(16)); - scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/chatgpt-icon.png")), "ChatGPT"))); + scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/chatgpt-icon.png")), "ChatGPT:"))); + scrollablePanel.add(Box.createVerticalStrut(8)); - var loader = new Loader(); - scrollablePanel.add(justifyLeft(loader.getComponent())); - loader.startLoading(); - scrollablePanel.add(Box.createVerticalStrut(4)); + var textArea = createTextArea("", false); + scrollablePanel.add(textArea); - ApiClient.getInstance().getCompletionsAsync(prompt, response -> { - loader.stopLoading(); - scrollablePanel.add(Box.createVerticalStrut(4)); - for (var choice : response.getChoices()) { - scrollablePanel.add(createTextArea(choice.getText().trim(), false, true)); + ApiClient.getInstance().getCompletionsAsync(prompt, (message) -> { + textArea.append(message); + if (scrollToBottom != null) { + scrollToBottom.call(); } - scrollablePanel.add(Box.createVerticalStrut(32)); - - if (onSuccess != null) { - onSuccess.call(); - } - }, apiError -> { - loader.stopLoading(); - scrollablePanel.add(Box.createVerticalStrut(4)); - scrollablePanel.add(createTextArea(apiError.getMessage(), false, true)); - scrollablePanel.add(Box.createVerticalStrut(32)); }); + scrollablePanel.add(Box.createVerticalStrut(16)); + } + + public void paintLandingView() { + isLandingViewVisible = true; + + var imageIconPanel = new JPanel(); + imageIconPanel.setLayout(new GridBagLayout()); + var imageIconLabel = new JLabel(new ImageIcon(Objects.requireNonNull(getClass().getResource("/icons/sun-icon.png")))); + imageIconLabel.setHorizontalAlignment(JLabel.CENTER); + imageIconPanel.add(imageIconLabel); + scrollablePanel.add(imageIconPanel); + + scrollablePanel.add(Box.createVerticalStrut(16)); + + var questions = List.of("How do I make an HTTP request in Javascript?", + "What is the difference between px, dip, dp, and sp?", + "How do I undo the most recent local commits in Git?", + "What is the difference between stack and heap?"); + for (var question : questions) { + var panel = new JPanel(); + panel.setLayout(new GridBagLayout()); + var label = new JLabel(question, SwingConstants.CENTER); + label.setHorizontalAlignment(JLabel.CENTER); + panel.add(label); + scrollablePanel.add(panel); + scrollablePanel.add(Box.createVerticalStrut(16)); + } + } + + public void removeAll() { + isLandingViewVisible = false; + scrollablePanel.removeAll(); } } diff --git a/src/main/java/ee/carlrobert/chatgpt/toolwindow/ToolWindowUtil.java b/src/main/java/ee/carlrobert/chatgpt/toolwindow/ToolWindowUtil.java index 66506a73..c341e348 100644 --- a/src/main/java/ee/carlrobert/chatgpt/toolwindow/ToolWindowUtil.java +++ b/src/main/java/ee/carlrobert/chatgpt/toolwindow/ToolWindowUtil.java @@ -11,16 +11,14 @@ import javax.swing.JTextArea; public class ToolWindowUtil { - public static JTextArea createTextArea(String selectedText, boolean isItalicFont, boolean transparentBackground) { + public static JTextArea createTextArea(String selectedText, boolean isItalicFont) { var textArea = new JTextArea(); textArea.append(selectedText); textArea.setLineWrap(true); textArea.setEditable(false); - textArea.setFont(createFont(isItalicFont, textArea.getFont().getSize())); + textArea.setFont(new Font("Tahoma", isItalicFont ? Font.ITALIC : Font.PLAIN, textArea.getFont().getSize())); textArea.setWrapStyleWord(true); - if (transparentBackground) { - textArea.setBackground(JBColor.background()); - } + textArea.setBackground(JBColor.PanelBackground); // textArea.setBorder(new MatteBorder(0, 2, 0, 0, JBColor.RED)); return textArea; } diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index d20770ea..f9529525 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -31,6 +31,7 @@ +
  • 1.0.3 Server-Sent Events
  • 1.0.2 Visual changes
  • 1.0.1 Design updates and simple missing key notification
  • 1.0.0 First release
  • diff --git a/src/main/resources/icons/sun-icon.png b/src/main/resources/icons/sun-icon.png new file mode 100644 index 0000000000000000000000000000000000000000..d8a44b772cf3b90834063bdac4a7d6506353b1b1 GIT binary patch literal 830 zcmV-E1Ht@>P)y6ZUz_<0ouv>nMtqDLccL57;3JAg3ag=Qokw?W_jIHWGz>kr>elJ1 zu3J@YOpY8mMhcP-NIo)~9VFM~GMgRDiDV3!^9X>w0Jf(^q}gn4tyC&|jB64kKauQs zo>xqoi064ll8;G#s8*}HjVm-7jcovy#UKDrqL3lk2Jmp$^`7SyWe?qJSv(Wl@If1p zd=6m0(+&P6@hpiw5+{=yiC(UclK{dH9dSd~1sNgnIlL)=F#yK^ENXf3hU8oQeal*V zRMv!K_L2M|I#VMNdWsXld)C_H0G9kr)$5WqLN=st!i9G?43akh{8VE*04@WV0%oO6CWE)5%&n z0072GzSMAX8>N4J!b6-RBHw2iKEaRU-;<_pH@wyHv-=xOPM)60Txp zbGa|Zq)7x3s_CKNXDm7F;_3i!#Cd3K;AOvJVgkT5C344%3f9_-N@PV+X;gWU0q|6b zOq)?bx;K)q)R5Vz@{+byw(GnTj0ynf)hlaKb_^?Ltv#hOQr?_3Spdn3Pe#g)eqpAk z`COoATrgPyfUy9Xb&dOZ$w-aAZ`mf4Bf#hBJFKn=QjO6=NTiwv`VaY}Lt=Ufo za_I~ok{sEjF*i$X?&Kv+MqX+eD8w9@M*3pq{Yj=%uh(}NS16au2T6WUd5GNs@D0EL z;}QTS)BLP}-*dA`lNDs{Lr+}fM6!ne$s=v&J$8;9IZRai0u6~Cy<5WSX8-^I07*qo IM6N<$g2+N|Jpcdz literal 0 HcmV?d00001