feat: support lookup completions for custom openai service

This commit is contained in:
Carl-Robert Linnupuu 2024-02-24 14:38:51 +02:00
parent 557f9b0ca0
commit eeda43b0e4
3 changed files with 82 additions and 36 deletions

View file

@ -1,2 +1 @@
Given an existing function or method body, generate five alternative names for the function.
The response must be a comma-separated list of names. Exclude any additional information.
Provide five alternative names for a given function or method body. Your response should be a list of names, separated by commas, without any extra information.

View file

@ -21,6 +21,7 @@ import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.IncludedFilesSettings;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettingsState;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
@ -88,8 +89,7 @@ public class CompletionRequestProvider {
.replace("{QUESTION}", userPrompt);
}
public static OpenAIChatCompletionRequest buildOpenAILookupCompletionRequest(
String context) {
public static OpenAIChatCompletionRequest buildOpenAILookupCompletionRequest(String context) {
return new OpenAIChatCompletionRequest.Builder(
List.of(
new OpenAIChatCompletionMessage("system",
@ -100,6 +100,17 @@ public class CompletionRequestProvider {
.build();
}
public static Request buildCustomOpenAILookupCompletionRequest(String context) {
return buildCustomOpenAIChatCompletionRequest(
CustomServiceSettings.getCurrentState(),
List.of(
new OpenAIChatCompletionMessage(
"system",
getResourceContent("/prompts/method-name-generator.txt")),
new OpenAIChatCompletionMessage("user", context)),
false);
}
public static LlamaCompletionRequest buildLlamaLookupCompletionRequest(String context) {
return new LlamaCompletionRequest.Builder(PromptTemplate.LLAMA
.buildPrompt(getResourceContent("/prompts/method-name-generator.txt"), context, List.of()))
@ -175,6 +186,51 @@ public class CompletionRequestProvider {
return builder.build();
}
public Request buildCustomOpenAIChatCompletionRequest(
CustomServiceSettingsState customConfiguration,
CallParameters callParameters) {
return buildCustomOpenAIChatCompletionRequest(
customConfiguration,
buildMessages(callParameters, false),
true);
}
private static Request buildCustomOpenAIChatCompletionRequest(
CustomServiceSettingsState customConfiguration,
List<OpenAIChatCompletionMessage> messages,
boolean streamRequest) {
var requestBuilder = new Request.Builder().url(customConfiguration.getUrl().trim());
for (var entry : customConfiguration.getHeaders().entrySet()) {
String value = entry.getValue();
if (value.contains("$CUSTOM_SERVICE_API_KEY")) {
value = value.replace("$CUSTOM_SERVICE_API_KEY",
CustomServiceCredentialManager.getInstance().getCredential());
}
requestBuilder.addHeader(entry.getKey(), value);
}
var body = customConfiguration.getBody().entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> {
if (!streamRequest && "stream".equals(entry.getKey())) {
return false;
}
return processEntryValue(entry.getValue(), messages);
}
));
try {
var requestBody = RequestBody.create(new ObjectMapper()
.writerWithDefaultPrettyPrinter()
.writeValueAsString(body)
.getBytes(StandardCharsets.UTF_8));
return requestBuilder.post(requestBody).build();
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
public List<OpenAIChatCompletionMessage> buildMessages(
CallParameters callParameters,
boolean useContextualSearch) {
@ -257,37 +313,9 @@ public class CompletionRequestProvider {
return messages.stream().filter(Objects::nonNull).collect(toList());
}
public Request buildCustomOpenAIChatCompletionRequest(
CustomServiceSettingsState customConfiguration,
CallParameters callParameters) {
var requestBuilder = new Request.Builder().url(customConfiguration.getUrl().trim());
for (var entry : customConfiguration.getHeaders().entrySet()) {
String value = entry.getValue();
if (value.contains("$CUSTOM_SERVICE_API_KEY")) {
value = value.replace("$CUSTOM_SERVICE_API_KEY",
CustomServiceCredentialManager.getInstance().getCredential());
}
requestBuilder.addHeader(entry.getKey(), value);
}
var body = customConfiguration.getBody().entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> processEntryValue(entry.getValue(), callParameters)
));
try {
var requestBody = RequestBody.create(new ObjectMapper()
.writerWithDefaultPrettyPrinter()
.writeValueAsString(body)
.getBytes(StandardCharsets.UTF_8));
return requestBuilder.post(requestBody).build();
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
private Object processEntryValue(Object value, CallParameters callParameters) {
private static Object processEntryValue(
Object value,
List<OpenAIChatCompletionMessage> messages) {
if (!(value instanceof String)) {
return value;
}
@ -295,7 +323,7 @@ public class CompletionRequestProvider {
String stringValue = (String) value;
switch (stringValue.toLowerCase().trim()) {
case "$openai_messages":
return buildMessages(callParameters, false);
return messages;
case "true":
case "false":
return Boolean.parseBoolean(stringValue);

View file

@ -1,6 +1,7 @@
package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE;
import static ee.carlrobert.codegpt.settings.service.ServiceType.CUSTOM_OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU;
@ -19,11 +20,14 @@ import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponse;
import ee.carlrobert.llm.completion.CompletionEventListener;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import okhttp3.Request;
@ -160,10 +164,25 @@ public final class CompletionRequestService {
return Optional.empty();
}
if (selectedService == CUSTOM_OPENAI) {
var request = CompletionRequestProvider.buildCustomOpenAILookupCompletionRequest(prompt);
var httpClient = CompletionClientProvider.getDefaultClientBuilder().build();
try (var response = httpClient.newCall(request).execute()) {
return tryExtractContent(
DeserializationUtil.mapResponse(response, OpenAIChatCompletionResponse.class));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
var request = CompletionRequestProvider.buildOpenAILookupCompletionRequest(prompt);
var response = selectedService == OPENAI
? CompletionClientProvider.getOpenAIClient().getChatCompletion(request)
: CompletionClientProvider.getAzureClient().getChatCompletion(request);
return tryExtractContent(response);
}
private Optional<String> tryExtractContent(OpenAIChatCompletionResponse response) {
return response
.getChoices()
.stream()