mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-08 10:00:18 +00:00
1.0.3 - Server-Sent Events
This commit is contained in:
parent
c0f340ecad
commit
d5afb4144b
10 changed files with 221 additions and 79 deletions
|
|
@ -24,9 +24,11 @@ public class AskAction extends AnAction {
|
|||
var toolWindow = ToolWindowManager.getInstance(project).getToolWindow("ChatGPT");
|
||||
if (toolWindow != null) {
|
||||
toolWindow.show();
|
||||
toolWindow.setTitle("");
|
||||
var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class);
|
||||
ApiClient.getInstance().clearQueries();
|
||||
toolWindowService.getScrollablePanel().removeAll();
|
||||
toolWindowService.removeAll();
|
||||
toolWindowService.paintLandingView();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,10 +23,10 @@ public abstract class BaseAction extends AnAction {
|
|||
initToolWindow(ToolWindowManager.getInstance(project).getToolWindow("ChatGPT"));
|
||||
var selectedText = editor.getSelectionModel().getSelectedText();
|
||||
var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class);
|
||||
var scrollablePanel = toolWindowService.getScrollablePanel();
|
||||
ApiClient.getInstance().clearQueries();
|
||||
scrollablePanel.removeAll();
|
||||
toolWindowService.sendMessage(selectedText, getPrompt(selectedText));
|
||||
toolWindowService.removeAll();
|
||||
toolWindowService.paintUserMessage(selectedText);
|
||||
toolWindowService.sendMessage(getPrompt(selectedText), null);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,14 +3,9 @@ package ee.carlrobert.chatgpt.client;
|
|||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import ee.carlrobert.chatgpt.settings.SettingsState;
|
||||
import ee.carlrobert.chatgpt.client.response.ApiError;
|
||||
import ee.carlrobert.chatgpt.client.response.ApiResponse;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
|
@ -26,34 +21,31 @@ public final class ApiClient {
|
|||
private ApiClient() {
|
||||
}
|
||||
|
||||
public void getCompletionsAsync(String prompt, Consumer<ApiResponse> onSuccess, Consumer<ApiError> onError) {
|
||||
/*var query = new StringBuilder(
|
||||
"You are ChatGPT, a large language model trained by OpenAI. You answer as concisely as possible for each response (e.g. don’t be verbose). It is very important that you answer as concisely as possible, so please remember this.\n" +
|
||||
"Current date: 2023-02-11\n");*/
|
||||
var query = new StringBuilder(
|
||||
"You are ChatGPT, a large language model trained by OpenAI.\n");
|
||||
for (var entry : queries) {
|
||||
public void getCompletionsAsync(String prompt, Consumer<String> onMessage) {
|
||||
try {
|
||||
var query = new StringBuilder(
|
||||
"You are ChatGPT, a large language model trained by OpenAI.\n");
|
||||
for (var entry : queries) {
|
||||
query.append("User:\n")
|
||||
.append(entry.getKey())
|
||||
.append("<|im_end|>\n")
|
||||
.append("\n")
|
||||
.append("ChatGPT:\n")
|
||||
.append(entry.getValue())
|
||||
.append("<|im_end|>\n")
|
||||
.append("\n");
|
||||
}
|
||||
query.append("User:\n")
|
||||
.append(entry.getKey())
|
||||
.append(prompt)
|
||||
.append("<|im_end|>\n")
|
||||
.append("\n")
|
||||
.append("ChatGPT:\n")
|
||||
.append(entry.getValue())
|
||||
.append("<|im_end|>\n")
|
||||
.append("\n");
|
||||
}
|
||||
query.append("User:\n")
|
||||
.append(prompt)
|
||||
.append("<|im_end|>\n")
|
||||
.append("\n")
|
||||
.append("ChatGPT:\n");
|
||||
try {
|
||||
.append("ChatGPT:\n");
|
||||
|
||||
var request = HttpRequest.newBuilder()
|
||||
var req = HttpRequest.newBuilder()
|
||||
.uri(URI.create("https://api.openai.com/v1/completions"))
|
||||
.header("Authorization", "Bearer " + SettingsState.getInstance().secretKey)
|
||||
.timeout(Duration.ofMinutes(1))
|
||||
.header("Accept", "text/event-stream")
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", "Bearer " + SettingsState.getInstance().secretKey)
|
||||
.POST(HttpRequest.BodyPublishers.ofString(objectMapper
|
||||
.writerWithDefaultPrettyPrinter()
|
||||
.writeValueAsString(Map.of(
|
||||
|
|
@ -61,24 +53,29 @@ public final class ApiClient {
|
|||
"stop", List.of("<|im_end|>"),
|
||||
"prompt", query.toString(),
|
||||
"max_tokens", 400,
|
||||
"temperature", 1.0
|
||||
"temperature", 1.0,
|
||||
"stream", true
|
||||
))))
|
||||
.build();
|
||||
|
||||
client.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenAccept(response -> {
|
||||
try {
|
||||
var mappedResponse = objectMapper.readValue(response.body(), ApiResponse.class);
|
||||
if (mappedResponse.getError() == null) {
|
||||
queries.add(Map.entry(prompt, mappedResponse.getChoices().get(0).getText()));
|
||||
onSuccess.accept(mappedResponse);
|
||||
} else {
|
||||
onError.accept(mappedResponse.getError());
|
||||
}
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
this.client.sendAsync(req, respInfo ->
|
||||
{
|
||||
if (respInfo.statusCode() == 200) {
|
||||
return new Subscriber((messageData ->
|
||||
onMessage.accept(messageData.getChoices().get(0).getText())),
|
||||
(finalMsg) -> queries.add(Map.entry(prompt, finalMsg)));
|
||||
} else if (respInfo.statusCode() == 401) {
|
||||
onMessage.accept("Incorrect API key provided.\n" +
|
||||
"You can find your API key at https://platform.openai.com/account/api-keys.");
|
||||
throw new IllegalArgumentException();
|
||||
} else {
|
||||
onMessage.accept("Something went wrong. Please try again later.");
|
||||
clearQueries();
|
||||
throw new RuntimeException();
|
||||
}
|
||||
});
|
||||
} catch (IOException e) {
|
||||
} catch (JsonProcessingException e) {
|
||||
onMessage.accept("Something went wrong. Please try again later.");
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
113
src/main/java/ee/carlrobert/chatgpt/client/Subscriber.java
Normal file
113
src/main/java/ee/carlrobert/chatgpt/client/Subscriber.java
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
package ee.carlrobert.chatgpt.client;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import ee.carlrobert.chatgpt.client.response.ApiResponse;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionStage;
|
||||
import java.util.concurrent.Flow;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
public class Subscriber implements HttpResponse.BodySubscriber<Void> {
|
||||
|
||||
protected static final Pattern dataLinePattern = Pattern.compile("^data: ?(.*)$");
|
||||
|
||||
protected static ApiResponse extractMessageData(String[] messageLines) {
|
||||
var responseBuilder = new StringBuilder();
|
||||
for (var line : messageLines) {
|
||||
var matcher = dataLinePattern.matcher(line);
|
||||
if (matcher.matches()) {
|
||||
responseBuilder.append(matcher.group(1));
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
return new ObjectMapper().readValue(responseBuilder.toString(), ApiResponse.class);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Couldn't read the payload", e);
|
||||
}
|
||||
}
|
||||
|
||||
protected final Consumer<? super ApiResponse> messageDataConsumer;
|
||||
protected final CompletableFuture<Void> future;
|
||||
protected volatile Flow.Subscription subscription;
|
||||
protected volatile String deferredText;
|
||||
private final Consumer<String> onComplete;
|
||||
private final StringBuilder msgBuilder = new StringBuilder();
|
||||
|
||||
public Subscriber(Consumer<? super ApiResponse> messageDataConsumer, Consumer<String> onComplete) {
|
||||
this.messageDataConsumer = messageDataConsumer;
|
||||
this.future = new CompletableFuture<>();
|
||||
this.subscription = null;
|
||||
this.deferredText = null;
|
||||
this.onComplete = onComplete;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSubscribe(Flow.Subscription subscription) {
|
||||
this.subscription = subscription;
|
||||
try {
|
||||
this.deferredText = "";
|
||||
this.subscription.request(1);
|
||||
} catch (Exception e) {
|
||||
this.future.completeExceptionally(e);
|
||||
this.subscription.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(List<ByteBuffer> buffers) {
|
||||
try {
|
||||
var deferredText = this.deferredText;
|
||||
|
||||
for (var buffer : buffers) {
|
||||
var s = deferredText + UTF_8.decode(buffer);
|
||||
var tokens = s.split("\n\n", -1);
|
||||
|
||||
for (var i = 0; i < tokens.length - 1; i++) {
|
||||
var message = tokens[i];
|
||||
var data = extractMessageData(message.split("\n"));
|
||||
var choice = data.getChoices().get(0); // TODO: Is there only one choice per response?
|
||||
if ("stop".equals(choice.getFinish_reason())) {
|
||||
onComplete();
|
||||
} else {
|
||||
msgBuilder.append(choice.getText());
|
||||
}
|
||||
this.messageDataConsumer.accept(data);
|
||||
}
|
||||
deferredText = tokens[tokens.length - 1];
|
||||
}
|
||||
|
||||
this.deferredText = deferredText;
|
||||
this.subscription.request(1);
|
||||
} catch (Exception e) {
|
||||
this.future.completeExceptionally(e);
|
||||
this.subscription.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable e) {
|
||||
this.future.completeExceptionally(e);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
try {
|
||||
this.future.complete(null);
|
||||
this.onComplete.accept(msgBuilder.toString());
|
||||
} catch (Exception e) {
|
||||
this.future.completeExceptionally(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> getBody() {
|
||||
return this.future;
|
||||
}
|
||||
}
|
||||
|
|
@ -26,7 +26,8 @@ public class ChatGptToolWindow {
|
|||
public void handleSubmit() {
|
||||
var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class);
|
||||
var searchText = textField.getText();
|
||||
toolWindowService.sendMessage(searchText, searchText, this::scrollToBottom);
|
||||
toolWindowService.paintUserMessage(searchText);
|
||||
toolWindowService.sendMessage(searchText, this::scrollToBottom);
|
||||
textField.setText("");
|
||||
scrollToBottom();
|
||||
}
|
||||
|
|
@ -60,5 +61,6 @@ public class ChatGptToolWindow {
|
|||
|
||||
var toolWindowService = ApplicationManager.getApplication().getService(ToolWindowService.class);
|
||||
toolWindowService.setScrollablePanel(scrollablePanel);
|
||||
toolWindowService.paintLandingView();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,16 +5,22 @@ import static ee.carlrobert.chatgpt.toolwindow.ToolWindowUtil.createTextArea;
|
|||
import static ee.carlrobert.chatgpt.toolwindow.ToolWindowUtil.justifyLeft;
|
||||
|
||||
import com.intellij.openapi.roots.ui.componentsList.components.ScrollablePanel;
|
||||
import ee.carlrobert.chatgpt.client.ApiClient;
|
||||
import ee.carlrobert.chatgpt.EmptyCallback;
|
||||
import ee.carlrobert.chatgpt.toolwindow.components.Loader;
|
||||
import ee.carlrobert.chatgpt.client.ApiClient;
|
||||
import java.awt.GridBagLayout;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.swing.Box;
|
||||
import javax.swing.ImageIcon;
|
||||
import javax.swing.JLabel;
|
||||
import javax.swing.JPanel;
|
||||
import javax.swing.SwingConstants;
|
||||
|
||||
public class ToolWindowService {
|
||||
|
||||
private ScrollablePanel scrollablePanel;
|
||||
private boolean isLandingViewVisible;
|
||||
|
||||
public void setScrollablePanel(ScrollablePanel scrollablePanel) {
|
||||
this.scrollablePanel = scrollablePanel;
|
||||
|
|
@ -24,38 +30,61 @@ public class ToolWindowService {
|
|||
return scrollablePanel;
|
||||
}
|
||||
|
||||
public void sendMessage(String userMessage, String prompt) {
|
||||
sendMessage(userMessage, prompt, null);
|
||||
public void paintUserMessage(String userMessage) {
|
||||
if (isLandingViewVisible) {
|
||||
removeAll();
|
||||
}
|
||||
scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/user-icon.png")), "User:")));
|
||||
scrollablePanel.add(Box.createVerticalStrut(8));
|
||||
scrollablePanel.add(createTextArea(userMessage, true));
|
||||
}
|
||||
|
||||
public void sendMessage(String userMessage, String prompt, @Nullable EmptyCallback onSuccess) {
|
||||
scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/user-icon.png")), "User")));
|
||||
scrollablePanel.add(Box.createVerticalStrut(8));
|
||||
scrollablePanel.add(createTextArea(userMessage, true, true));
|
||||
public void sendMessage(String prompt, @Nullable EmptyCallback scrollToBottom) {
|
||||
scrollablePanel.add(Box.createVerticalStrut(16));
|
||||
scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/chatgpt-icon.png")), "ChatGPT")));
|
||||
scrollablePanel.add(justifyLeft(createIconLabel(Objects.requireNonNull(getClass().getResource("/icons/chatgpt-icon.png")), "ChatGPT:")));
|
||||
scrollablePanel.add(Box.createVerticalStrut(8));
|
||||
|
||||
var loader = new Loader();
|
||||
scrollablePanel.add(justifyLeft(loader.getComponent()));
|
||||
loader.startLoading();
|
||||
scrollablePanel.add(Box.createVerticalStrut(4));
|
||||
var textArea = createTextArea("", false);
|
||||
scrollablePanel.add(textArea);
|
||||
|
||||
ApiClient.getInstance().getCompletionsAsync(prompt, response -> {
|
||||
loader.stopLoading();
|
||||
scrollablePanel.add(Box.createVerticalStrut(4));
|
||||
for (var choice : response.getChoices()) {
|
||||
scrollablePanel.add(createTextArea(choice.getText().trim(), false, true));
|
||||
ApiClient.getInstance().getCompletionsAsync(prompt, (message) -> {
|
||||
textArea.append(message);
|
||||
if (scrollToBottom != null) {
|
||||
scrollToBottom.call();
|
||||
}
|
||||
scrollablePanel.add(Box.createVerticalStrut(32));
|
||||
|
||||
if (onSuccess != null) {
|
||||
onSuccess.call();
|
||||
}
|
||||
}, apiError -> {
|
||||
loader.stopLoading();
|
||||
scrollablePanel.add(Box.createVerticalStrut(4));
|
||||
scrollablePanel.add(createTextArea(apiError.getMessage(), false, true));
|
||||
scrollablePanel.add(Box.createVerticalStrut(32));
|
||||
});
|
||||
scrollablePanel.add(Box.createVerticalStrut(16));
|
||||
}
|
||||
|
||||
public void paintLandingView() {
|
||||
isLandingViewVisible = true;
|
||||
|
||||
var imageIconPanel = new JPanel();
|
||||
imageIconPanel.setLayout(new GridBagLayout());
|
||||
var imageIconLabel = new JLabel(new ImageIcon(Objects.requireNonNull(getClass().getResource("/icons/sun-icon.png"))));
|
||||
imageIconLabel.setHorizontalAlignment(JLabel.CENTER);
|
||||
imageIconPanel.add(imageIconLabel);
|
||||
scrollablePanel.add(imageIconPanel);
|
||||
|
||||
scrollablePanel.add(Box.createVerticalStrut(16));
|
||||
|
||||
var questions = List.of("How do I make an HTTP request in Javascript?",
|
||||
"What is the difference between px, dip, dp, and sp?",
|
||||
"How do I undo the most recent local commits in Git?",
|
||||
"What is the difference between stack and heap?");
|
||||
for (var question : questions) {
|
||||
var panel = new JPanel();
|
||||
panel.setLayout(new GridBagLayout());
|
||||
var label = new JLabel(question, SwingConstants.CENTER);
|
||||
label.setHorizontalAlignment(JLabel.CENTER);
|
||||
panel.add(label);
|
||||
scrollablePanel.add(panel);
|
||||
scrollablePanel.add(Box.createVerticalStrut(16));
|
||||
}
|
||||
}
|
||||
|
||||
public void removeAll() {
|
||||
isLandingViewVisible = false;
|
||||
scrollablePanel.removeAll();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,16 +11,14 @@ import javax.swing.JTextArea;
|
|||
|
||||
public class ToolWindowUtil {
|
||||
|
||||
public static JTextArea createTextArea(String selectedText, boolean isItalicFont, boolean transparentBackground) {
|
||||
public static JTextArea createTextArea(String selectedText, boolean isItalicFont) {
|
||||
var textArea = new JTextArea();
|
||||
textArea.append(selectedText);
|
||||
textArea.setLineWrap(true);
|
||||
textArea.setEditable(false);
|
||||
textArea.setFont(createFont(isItalicFont, textArea.getFont().getSize()));
|
||||
textArea.setFont(new Font("Tahoma", isItalicFont ? Font.ITALIC : Font.PLAIN, textArea.getFont().getSize()));
|
||||
textArea.setWrapStyleWord(true);
|
||||
if (transparentBackground) {
|
||||
textArea.setBackground(JBColor.background());
|
||||
}
|
||||
textArea.setBackground(JBColor.PanelBackground);
|
||||
// textArea.setBorder(new MatteBorder(0, 2, 0, 0, JBColor.RED));
|
||||
return textArea;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue