Code refactoring, add readme

This commit is contained in:
Carl-Robert Linnupuu 2023-02-21 23:08:31 +00:00
parent 53e8ceeb06
commit d71edd19b3
4 changed files with 74 additions and 50 deletions

View file

@ -6,6 +6,8 @@ import ee.carlrobert.chatgpt.ide.settings.SettingsState;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse.BodySubscriber;
import java.net.http.HttpResponse.ResponseInfo;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@ -21,27 +23,44 @@ public final class ApiClient {
private ApiClient() {
}
public void getCompletionsAsync(String prompt, Consumer<String> onMessage) {
try {
var query = new StringBuilder(
"You are ChatGPT, a large language model trained by OpenAI.\n");
for (var entry : queries) {
query.append("User:\n")
.append(entry.getKey())
public static ApiClient getInstance() {
if (instance == null) {
instance = new ApiClient();
}
return instance;
}
public void getCompletionsAsync(String userPrompt, Consumer<String> onMessage) {
var prompt = buildCompletePrompt(userPrompt);
this.client.sendAsync(buildHttpRequest(prompt), respInfo -> subscribe(respInfo, userPrompt, onMessage));
}
public void clearQueries() {
queries.clear();
}
private String buildCompletePrompt(String prompt) {
var basePrompt = new StringBuilder("You are ChatGPT, a large language model trained by OpenAI.\n");
queries.forEach(query ->
basePrompt.append("User:\n")
.append(query.getKey())
.append("<|im_end|>\n")
.append("\n")
.append("ChatGPT:\n")
.append(entry.getValue())
.append(query.getValue())
.append("<|im_end|>\n")
.append("\n");
}
query.append("User:\n")
.append(prompt)
.append("<|im_end|>\n")
.append("\n")
.append("ChatGPT:\n");
.append("\n"));
basePrompt.append("User:\n")
.append(prompt)
.append("<|im_end|>\n")
.append("\n")
.append("ChatGPT:\n");
return basePrompt.toString();
}
var req = HttpRequest.newBuilder()
private HttpRequest buildHttpRequest(String prompt) {
try {
return HttpRequest.newBuilder()
.uri(URI.create("https://api.openai.com/v1/completions"))
.header("Accept", "text/event-stream")
.header("Content-Type", "application/json")
@ -51,46 +70,33 @@ public final class ApiClient {
.writeValueAsString(Map.of(
"model", "text-davinci-003",
"stop", List.of("<|im_end|>"),
"prompt", query.toString(),
"prompt", prompt,
"max_tokens", 1000,
"temperature", 1.0,
"stream", true
))))
.build();
this.client.sendAsync(req, respInfo ->
{
if (respInfo.statusCode() == 200) {
return new Subscriber((messageData ->
onMessage.accept(messageData.getChoices().get(0).getText())),
(finalMsg) -> queries.add(Map.entry(prompt, finalMsg)));
} else if (respInfo.statusCode() == 401) {
onMessage.accept("Incorrect API key provided.\n" +
"You can find your API key at https://platform.openai.com/account/api-keys.");
throw new IllegalArgumentException();
} else if (respInfo.statusCode() == 429) {
onMessage.accept("You exceeded your current quota, please check your plan and billing details.");
throw new RuntimeException("Insufficient quota");
} else {
onMessage.accept("Something went wrong. Please try again later.");
clearQueries();
throw new RuntimeException();
}
});
} catch (JsonProcessingException e) {
throw new RuntimeException("Unable to serialize request payload", e);
}
}
private BodySubscriber<Void> subscribe(ResponseInfo responseInfo, String userPrompt, Consumer<String> onMessage) {
if (responseInfo.statusCode() == 200) {
return new Subscriber((messageData ->
onMessage.accept(messageData.getChoices().get(0).getText())),
(finalMsg) -> queries.add(Map.entry(userPrompt, finalMsg)));
} else if (responseInfo.statusCode() == 401) {
onMessage.accept("Incorrect API key provided.\n" +
"You can find your API key at https://platform.openai.com/account/api-keys.");
throw new IllegalArgumentException();
} else if (responseInfo.statusCode() == 429) {
onMessage.accept("You exceeded your current quota, please check your plan and billing details.");
throw new RuntimeException("Insufficient quota");
} else {
onMessage.accept("Something went wrong. Please try again later.");
throw new RuntimeException(e);
clearQueries();
throw new RuntimeException();
}
}
public static ApiClient getInstance() {
if (instance == null) {
instance = new ApiClient();
}
return instance;
}
public void clearQueries() {
queries.clear();
}
}

View file

@ -74,7 +74,7 @@ public class Subscriber implements HttpResponse.BodySubscriber<Void> {
var data = extractMessageData(message.split("\n"));
var choice = data.getChoices().get(0);
if ("stop".equals(choice.getFinish_reason())) {
onComplete();
// onComplete();
} else {
msgBuilder.append(choice.getText());
}