mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-06 08:02:13 +00:00
Code refactoring, add readme
This commit is contained in:
parent
53e8ceeb06
commit
d71edd19b3
4 changed files with 74 additions and 50 deletions
|
|
@ -6,6 +6,8 @@ import ee.carlrobert.chatgpt.ide.settings.SettingsState;
|
|||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse.BodySubscriber;
|
||||
import java.net.http.HttpResponse.ResponseInfo;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
|
@ -21,27 +23,44 @@ public final class ApiClient {
|
|||
private ApiClient() {
|
||||
}
|
||||
|
||||
public void getCompletionsAsync(String prompt, Consumer<String> onMessage) {
|
||||
try {
|
||||
var query = new StringBuilder(
|
||||
"You are ChatGPT, a large language model trained by OpenAI.\n");
|
||||
for (var entry : queries) {
|
||||
query.append("User:\n")
|
||||
.append(entry.getKey())
|
||||
public static ApiClient getInstance() {
|
||||
if (instance == null) {
|
||||
instance = new ApiClient();
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
|
||||
public void getCompletionsAsync(String userPrompt, Consumer<String> onMessage) {
|
||||
var prompt = buildCompletePrompt(userPrompt);
|
||||
this.client.sendAsync(buildHttpRequest(prompt), respInfo -> subscribe(respInfo, userPrompt, onMessage));
|
||||
}
|
||||
|
||||
public void clearQueries() {
|
||||
queries.clear();
|
||||
}
|
||||
|
||||
private String buildCompletePrompt(String prompt) {
|
||||
var basePrompt = new StringBuilder("You are ChatGPT, a large language model trained by OpenAI.\n");
|
||||
queries.forEach(query ->
|
||||
basePrompt.append("User:\n")
|
||||
.append(query.getKey())
|
||||
.append("<|im_end|>\n")
|
||||
.append("\n")
|
||||
.append("ChatGPT:\n")
|
||||
.append(entry.getValue())
|
||||
.append(query.getValue())
|
||||
.append("<|im_end|>\n")
|
||||
.append("\n");
|
||||
}
|
||||
query.append("User:\n")
|
||||
.append(prompt)
|
||||
.append("<|im_end|>\n")
|
||||
.append("\n")
|
||||
.append("ChatGPT:\n");
|
||||
.append("\n"));
|
||||
basePrompt.append("User:\n")
|
||||
.append(prompt)
|
||||
.append("<|im_end|>\n")
|
||||
.append("\n")
|
||||
.append("ChatGPT:\n");
|
||||
return basePrompt.toString();
|
||||
}
|
||||
|
||||
var req = HttpRequest.newBuilder()
|
||||
private HttpRequest buildHttpRequest(String prompt) {
|
||||
try {
|
||||
return HttpRequest.newBuilder()
|
||||
.uri(URI.create("https://api.openai.com/v1/completions"))
|
||||
.header("Accept", "text/event-stream")
|
||||
.header("Content-Type", "application/json")
|
||||
|
|
@ -51,46 +70,33 @@ public final class ApiClient {
|
|||
.writeValueAsString(Map.of(
|
||||
"model", "text-davinci-003",
|
||||
"stop", List.of("<|im_end|>"),
|
||||
"prompt", query.toString(),
|
||||
"prompt", prompt,
|
||||
"max_tokens", 1000,
|
||||
"temperature", 1.0,
|
||||
"stream", true
|
||||
))))
|
||||
.build();
|
||||
|
||||
this.client.sendAsync(req, respInfo ->
|
||||
{
|
||||
if (respInfo.statusCode() == 200) {
|
||||
return new Subscriber((messageData ->
|
||||
onMessage.accept(messageData.getChoices().get(0).getText())),
|
||||
(finalMsg) -> queries.add(Map.entry(prompt, finalMsg)));
|
||||
} else if (respInfo.statusCode() == 401) {
|
||||
onMessage.accept("Incorrect API key provided.\n" +
|
||||
"You can find your API key at https://platform.openai.com/account/api-keys.");
|
||||
throw new IllegalArgumentException();
|
||||
} else if (respInfo.statusCode() == 429) {
|
||||
onMessage.accept("You exceeded your current quota, please check your plan and billing details.");
|
||||
throw new RuntimeException("Insufficient quota");
|
||||
} else {
|
||||
onMessage.accept("Something went wrong. Please try again later.");
|
||||
clearQueries();
|
||||
throw new RuntimeException();
|
||||
}
|
||||
});
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException("Unable to serialize request payload", e);
|
||||
}
|
||||
}
|
||||
|
||||
private BodySubscriber<Void> subscribe(ResponseInfo responseInfo, String userPrompt, Consumer<String> onMessage) {
|
||||
if (responseInfo.statusCode() == 200) {
|
||||
return new Subscriber((messageData ->
|
||||
onMessage.accept(messageData.getChoices().get(0).getText())),
|
||||
(finalMsg) -> queries.add(Map.entry(userPrompt, finalMsg)));
|
||||
} else if (responseInfo.statusCode() == 401) {
|
||||
onMessage.accept("Incorrect API key provided.\n" +
|
||||
"You can find your API key at https://platform.openai.com/account/api-keys.");
|
||||
throw new IllegalArgumentException();
|
||||
} else if (responseInfo.statusCode() == 429) {
|
||||
onMessage.accept("You exceeded your current quota, please check your plan and billing details.");
|
||||
throw new RuntimeException("Insufficient quota");
|
||||
} else {
|
||||
onMessage.accept("Something went wrong. Please try again later.");
|
||||
throw new RuntimeException(e);
|
||||
clearQueries();
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
public static ApiClient getInstance() {
|
||||
if (instance == null) {
|
||||
instance = new ApiClient();
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
|
||||
public void clearQueries() {
|
||||
queries.clear();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ public class Subscriber implements HttpResponse.BodySubscriber<Void> {
|
|||
var data = extractMessageData(message.split("\n"));
|
||||
var choice = data.getChoices().get(0);
|
||||
if ("stop".equals(choice.getFinish_reason())) {
|
||||
onComplete();
|
||||
// onComplete();
|
||||
} else {
|
||||
msgBuilder.append(choice.getText());
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue