ProxyAI/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java
Rene Leonhardt f7702286f3 Update to latest 233 platform 2023.3.6 (#439)
* Update to latest 233 platform 2023.3.6

* Use first non-blank choice from response
2024-04-12 17:31:59 +03:00

246 lines
11 KiB
Java

package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.settings.service.ServiceType.ANTHROPIC;
import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE;
import static ee.carlrobert.codegpt.settings.service.ServiceType.CUSTOM_OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import com.intellij.openapi.diagnostic.Logger;
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestProvider;
import ee.carlrobert.codegpt.codecompletions.InfillRequestDetails;
import ee.carlrobert.codegpt.completions.llama.LlamaModel;
import ee.carlrobert.codegpt.completions.llama.PromptTemplate;
import ee.carlrobert.codegpt.credentials.CredentialsStore;
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionStandardMessage;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage;
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponse;
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoice;
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoiceDelta;
import ee.carlrobert.llm.completion.CompletionEventListener;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import okhttp3.Request;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;
/**
 * Service that routes completion requests (chat, code completion, commit message generation and
 * lookup completion) to the client matching the service currently selected in
 * {@link GeneralSettings}.
 *
 * <p>The instance is obtained from the IntelliJ application container via {@link #getInstance()}.
 */
@Service
public final class CompletionRequestService {

  private static final Logger LOG = Logger.getInstance(CompletionRequestService.class);

  private CompletionRequestService() {
  }

  /**
   * Returns the instance registered with the IntelliJ application.
   */
  public static CompletionRequestService getInstance() {
    return ApplicationManager.getApplication().getService(CompletionRequestService.class);
  }

  /**
   * Opens a server-sent-events stream for an already built custom OpenAI-compatible request.
   *
   * @param customRequest HTTP request to execute
   * @param eventListener listener wrapped in an {@link OpenAIChatCompletionEventSourceListener}
   * @return the created event source
   */
  public EventSource getCustomOpenAIChatCompletionAsync(
      Request customRequest,
      CompletionEventListener<String> eventListener) {
    var httpClient = CompletionClientProvider.getDefaultClientBuilder().build();
    return EventSources.createFactory(httpClient).newEventSource(
        customRequest,
        new OpenAIChatCompletionEventSourceListener(eventListener));
  }

  /**
   * Starts an asynchronous chat completion using the currently selected service.
   *
   * @param callParameters conversation, message and conversation type of the call
   * @param eventListener  listener handed to the selected client
   * @return the event source of the started request
   * @throws IllegalArgumentException if the selected service has no chat completion branch
   */
  public EventSource getChatCompletionAsync(
      CallParameters callParameters,
      CompletionEventListener<String> eventListener) {
    var requestProvider = new CompletionRequestProvider(callParameters.getConversation());
    return switch (GeneralSettings.getCurrentState().getSelectedService()) {
      case OPENAI -> CompletionClientProvider.getOpenAIClient().getChatCompletionAsync(
          requestProvider.buildOpenAIChatCompletionRequest(
              OpenAISettings.getCurrentState().getModel(),
              callParameters),
          eventListener);
      case CUSTOM_OPENAI -> getCustomOpenAIChatCompletionAsync(
          requestProvider.buildCustomOpenAIChatCompletionRequest(
              CustomServiceSettings.getCurrentState(),
              callParameters),
          eventListener);
      case ANTHROPIC -> CompletionClientProvider.getClaudeClient().getCompletionAsync(
          requestProvider.buildAnthropicChatCompletionRequest(callParameters),
          eventListener);
      // Azure passes no model name to the request builder.
      case AZURE -> CompletionClientProvider.getAzureClient().getChatCompletionAsync(
          requestProvider.buildOpenAIChatCompletionRequest(null, callParameters),
          eventListener);
      case YOU -> CompletionClientProvider.getYouClient().getChatCompletionAsync(
          requestProvider.buildYouCompletionRequest(callParameters.getMessage()),
          eventListener);
      case LLAMA_CPP -> CompletionClientProvider.getLlamaClient().getChatCompletionAsync(
          requestProvider.buildLlamaCompletionRequest(
              callParameters.getMessage(),
              callParameters.getConversationType()),
          eventListener);
      default ->
          throw new IllegalArgumentException("Chat completion not supported for selected service");
    };
  }

  /**
   * Starts an asynchronous code (infill) completion. Only OpenAI and llama.cpp are handled.
   *
   * @param requestDetails details of the infill request
   * @param eventListener  listener handed to the selected client
   * @return the event source of the started request
   * @throws IllegalArgumentException if the selected service is neither OpenAI nor llama.cpp
   */
  public EventSource getCodeCompletionAsync(
      InfillRequestDetails requestDetails,
      CompletionEventListener<String> eventListener) {
    var requestProvider = new CodeCompletionRequestProvider(requestDetails);
    return switch (GeneralSettings.getCurrentState().getSelectedService()) {
      case OPENAI -> CompletionClientProvider.getOpenAIClient()
          .getCompletionAsync(requestProvider.buildOpenAIRequest(), eventListener);
      case LLAMA_CPP -> CompletionClientProvider.getLlamaClient()
          .getChatCompletionAsync(requestProvider.buildLlamaRequest(), eventListener);
      default ->
          throw new IllegalArgumentException("Code completion not supported for selected service");
    };
  }

  /**
   * Starts an asynchronous commit message generation with the currently selected service.
   *
   * <p>The configured commit message prompt is sent as the system prompt and {@code prompt} as
   * the user message. Services without a branch below are only logged at debug level; no request
   * is started for them.
   *
   * @param prompt        user message sent alongside the configured commit message prompt
   * @param eventListener listener handed to the selected client
   */
  public void generateCommitMessageAsync(
      String prompt,
      CompletionEventListener<String> eventListener) {
    var configuration = ConfigurationSettings.getCurrentState();
    var commitMessagePrompt = configuration.getCommitMessagePrompt();
    // Shared by the OpenAI and Azure branches.
    var openaiRequest = new OpenAIChatCompletionRequest.Builder(List.of(
        new OpenAIChatCompletionStandardMessage("system", commitMessagePrompt),
        new OpenAIChatCompletionStandardMessage("user", prompt)))
        .setModel(OpenAISettings.getCurrentState().getModel())
        .build();
    var selectedService = GeneralSettings.getCurrentState().getSelectedService();

    switch (selectedService) {
      case OPENAI -> CompletionClientProvider.getOpenAIClient()
          .getChatCompletionAsync(openaiRequest, eventListener);
      // Reuse the shared event-source wiring instead of duplicating it here.
      case CUSTOM_OPENAI -> getCustomOpenAIChatCompletionAsync(
          CompletionRequestProvider.buildCustomOpenAICompletionRequest(
              commitMessagePrompt,
              prompt),
          eventListener);
      case ANTHROPIC -> {
        var anthropicSettings = AnthropicSettings.getCurrentState();
        var claudeRequest = new ClaudeCompletionRequest();
        claudeRequest.setSystem(commitMessagePrompt);
        claudeRequest.setStream(true);
        claudeRequest.setMaxTokens(configuration.getMaxTokens());
        claudeRequest.setModel(anthropicSettings.getModel());
        claudeRequest.setMessages(List.of(new ClaudeCompletionStandardMessage("user", prompt)));
        CompletionClientProvider.getClaudeClient()
            .getCompletionAsync(claudeRequest, eventListener);
      }
      case AZURE -> CompletionClientProvider.getAzureClient()
          .getChatCompletionAsync(openaiRequest, eventListener);
      case LLAMA_CPP -> {
        var settings = LlamaSettings.getCurrentState();
        // Local server: custom model template or the template of the chosen HuggingFace model.
        // Remote server: the configured remote template.
        PromptTemplate promptTemplate;
        if (settings.isRunLocalServer()) {
          promptTemplate = settings.isUseCustomModel()
              ? settings.getLocalModelPromptTemplate()
              : LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel())
                  .getPromptTemplate();
        } else {
          promptTemplate = settings.getRemoteModelPromptTemplate();
        }
        var finalPrompt = promptTemplate.buildPrompt(commitMessagePrompt, prompt, List.of());
        CompletionClientProvider.getLlamaClient().getChatCompletionAsync(
            new LlamaCompletionRequest.Builder(finalPrompt)
                .setN_predict(configuration.getMaxTokens())
                .setTemperature(configuration.getTemperature())
                .setTop_k(settings.getTopK())
                .setTop_p(settings.getTopP())
                .setMin_p(settings.getMinP())
                .setRepeat_penalty(settings.getRepeatPenalty())
                .build(), eventListener);
      }
      // The IntelliJ Logger appends details to the message and does not expand "{}"
      // placeholders, so the message is built by concatenation.
      default -> LOG.debug("Unknown service: " + selectedService);
    }
  }

  /**
   * Requests a blocking lookup completion for the given prompt.
   *
   * <p>Only OpenAI, Azure and custom OpenAI-compatible services are queried. Every other service
   * yields {@code Optional.empty()} without issuing a request, so a request is never sent to a
   * client that does not belong to the selected service.
   *
   * @param prompt prompt passed to the lookup request builders
   * @return first non-blank content of the response, or {@code Optional.empty()}
   * @throws RuntimeException wrapping the {@link IOException} if the custom OpenAI call fails
   */
  public Optional<String> getLookupCompletion(String prompt) {
    var selectedService = GeneralSettings.getCurrentState().getSelectedService();

    if (selectedService == CUSTOM_OPENAI) {
      var request = CompletionRequestProvider.buildCustomOpenAILookupCompletionRequest(prompt);
      var httpClient = CompletionClientProvider.getDefaultClientBuilder().build();
      try (var response = httpClient.newCall(request).execute()) {
        return tryExtractContent(
            DeserializationUtil.mapResponse(response, OpenAIChatCompletionResponse.class));
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }

    if (selectedService == OPENAI) {
      return tryExtractContent(CompletionClientProvider.getOpenAIClient()
          .getChatCompletion(CompletionRequestProvider.buildOpenAILookupCompletionRequest(prompt)));
    }

    if (selectedService == AZURE) {
      return tryExtractContent(CompletionClientProvider.getAzureClient()
          .getChatCompletion(CompletionRequestProvider.buildOpenAILookupCompletionRequest(prompt)));
    }

    // YOU, LLAMA_CPP, ANTHROPIC and any other service: lookup completion is not available.
    return Optional.empty();
  }

  /**
   * Checks whether requests are allowed for the currently selected service.
   *
   * @return result of {@link #isRequestAllowed(ServiceType)} for the selected service
   */
  public boolean isRequestAllowed() {
    return isRequestAllowed(GeneralSettings.getCurrentState().getSelectedService());
  }

  /**
   * Checks whether requests are allowed for the given service.
   *
   * <ul>
   * <li>OpenAI: allowed when the OpenAI API key credential is set</li>
   * <li>Azure: allowed when the credential of the configured authentication method
   * (API key or Active Directory token) is set</li>
   * <li>llama.cpp, Anthropic, custom OpenAI: always allowed</li>
   * <li>anything else, including {@code null}: not allowed</li>
   * </ul>
   *
   * @param serviceType service to check, may be {@code null}
   * @return {@code true} if a request may be sent to the service
   */
  public static boolean isRequestAllowed(ServiceType serviceType) {
    if (serviceType == null) {
      return false;
    }
    return switch (serviceType) {
      case OPENAI -> CredentialsStore.INSTANCE.isCredentialSet(CredentialKey.OPENAI_API_KEY);
      case AZURE -> CredentialsStore.INSTANCE.isCredentialSet(
          AzureSettings.getCurrentState().isUseAzureApiKeyAuthentication()
              ? CredentialKey.AZURE_OPENAI_API_KEY
              : CredentialKey.AZURE_ACTIVE_DIRECTORY_TOKEN);
      case LLAMA_CPP, ANTHROPIC, CUSTOM_OPENAI -> true;
      default -> false;
    };
  }

  /**
   * Extracts the first usable content from a chat completion response.
   * <ul>
   * <li>Skip choices which are null</li>
   * <li>Skip messages which are null</li>
   * <li>Use first content which is not null or blank (whitespace)</li>
   * </ul>
   *
   * @param response response to inspect; a null response or null choices list yields empty
   * @return First non-blank content or {@code Optional.empty()}
   */
  private Optional<String> tryExtractContent(OpenAIChatCompletionResponse response) {
    if (response == null) {
      return Optional.empty();
    }
    var choices = response.getChoices();
    if (choices == null) {
      return Optional.empty();
    }
    return choices
        .stream()
        .filter(Objects::nonNull)
        .map(OpenAIChatCompletionResponseChoice::getMessage)
        .filter(Objects::nonNull)
        .map(OpenAIChatCompletionResponseChoiceDelta::getContent)
        .filter(c -> c != null && !c.isBlank())
        .findFirst();
  }
}