diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index a68e1b61..d7c79ab7 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -12,7 +12,7 @@ jsoup = "1.17.2" jtokkit = "1.0.0" junit = "5.10.2" kotlin = "1.9.24" -llm-client = "0.8.1" +llm-client = "0.8.2" okio = "3.9.0" tree-sitter = "0.22.5" diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java index 376df97e..9b41a453 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java @@ -112,13 +112,19 @@ public class CompletionRequestProvider { } public static OpenAIChatCompletionRequest buildOpenAILookupCompletionRequest(String context) { + return buildOpenAILookupCompletionRequest(context, OpenAISettings.getCurrentState().getModel()); + } + + public static OpenAIChatCompletionRequest buildOpenAILookupCompletionRequest( + String context, + String model) { return new OpenAIChatCompletionRequest.Builder( List.of( new OpenAIChatCompletionStandardMessage( "system", getResourceContent("/prompts/method-name-generator.txt")), new OpenAIChatCompletionStandardMessage("user", context))) - .setModel(OpenAISettings.getCurrentState().getModel()) + .setModel(model) .setStream(false) .build(); } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index 8c079b4b..06a517c8 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -1,10 +1,5 @@ package ee.carlrobert.codegpt.completions; -import static ee.carlrobert.codegpt.settings.service.ServiceType.CUSTOM_OPENAI; -import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP; -import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI; -import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU; - import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.Service; import com.intellij.openapi.diagnostic.Logger; @@ -45,9 +40,11 @@ import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionR import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoiceDelta; import ee.carlrobert.llm.completion.CompletionEventListener; import java.io.IOException; +import java.util.Collection; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.stream.Stream; import okhttp3.Request; import okhttp3.sse.EventSource; import okhttp3.sse.EventSources; @@ -164,20 +161,29 @@ public final class CompletionRequestService { String gitDiff, CompletionEventListener eventListener) { var configuration = ConfigurationSettings.getCurrentState(); - var openaiRequest = new Builder(List.of( + var openaiRequestBuilder = new Builder(List.of( new OpenAIChatCompletionStandardMessage("system", systemPrompt), new OpenAIChatCompletionStandardMessage("user", gitDiff))) - .setModel(OpenAISettings.getCurrentState().getModel()) - .build(); + .setModel(OpenAISettings.getCurrentState().getModel()); var selectedService = GeneralSettings.getCurrentState().getSelectedService(); switch (selectedService) { case CODEGPT: - CompletionClientProvider.getCodeGPTClient() - .getChatCompletionAsync(openaiRequest, eventListener); + CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync( + openaiRequestBuilder + .setModel( + ApplicationManager.getApplication().getService(CodeGPTServiceSettings.class) + .getState() + .getChatCompletionSettings() + .getModel()) + .build(), + eventListener); break; case OPENAI: - CompletionClientProvider.getOpenAIClient() - .getChatCompletionAsync(openaiRequest, eventListener); + CompletionClientProvider.getOpenAIClient().getChatCompletionAsync( + openaiRequestBuilder + .setModel(OpenAISettings.getCurrentState().getModel()) + .build(), + eventListener); break; case CUSTOM_OPENAI: var httpClient = CompletionClientProvider.getDefaultClientBuilder().build(); @@ -200,7 +206,7 @@ public final class CompletionRequestService { break; case AZURE: CompletionClientProvider.getAzureClient() - .getChatCompletionAsync(openaiRequest, eventListener); + .getChatCompletionAsync(openaiRequestBuilder.build(), eventListener); break; case LLAMA_CPP: var settings = LlamaSettings.getCurrentState(); @@ -260,27 +266,35 @@ public final class CompletionRequestService { } public Optional getLookupCompletion(String prompt) { + var openaiRequest = CompletionRequestProvider.buildOpenAILookupCompletionRequest(prompt); var selectedService = GeneralSettings.getCurrentState().getSelectedService(); - if (selectedService == YOU || selectedService == LLAMA_CPP) { - return Optional.empty(); - } - - if (selectedService == CUSTOM_OPENAI) { - var request = CompletionRequestProvider.buildCustomOpenAILookupCompletionRequest(prompt); - var httpClient = CompletionClientProvider.getDefaultClientBuilder().build(); - try (var response = httpClient.newCall(request).execute()) { + switch (selectedService) { + case CODEGPT: + var model = ApplicationManager.getApplication().getService(CodeGPTServiceSettings.class) + .getState() + .getChatCompletionSettings() + .getModel(); return tryExtractContent( - DeserializationUtil.mapResponse(response, OpenAIChatCompletionResponse.class)); - } catch (IOException e) { - throw new RuntimeException(e); - } + CompletionClientProvider.getCodeGPTClient().getChatCompletion( + CompletionRequestProvider.buildOpenAILookupCompletionRequest(prompt, model))); + case OPENAI: + return tryExtractContent( + CompletionClientProvider.getOpenAIClient().getChatCompletion(openaiRequest)); + case AZURE: + return tryExtractContent( + CompletionClientProvider.getAzureClient().getChatCompletion(openaiRequest)); + case CUSTOM_OPENAI: + var request = CompletionRequestProvider.buildCustomOpenAILookupCompletionRequest(prompt); + var httpClient = CompletionClientProvider.getDefaultClientBuilder().build(); + try (var response = httpClient.newCall(request).execute()) { + return tryExtractContent( + DeserializationUtil.mapResponse(response, OpenAIChatCompletionResponse.class)); + } catch (IOException e) { + throw new RuntimeException(e); + } + default: + return Optional.empty(); } - - var request = CompletionRequestProvider.buildOpenAILookupCompletionRequest(prompt); - var response = selectedService == OPENAI - ? CompletionClientProvider.getOpenAIClient().getChatCompletion(request) - : CompletionClientProvider.getAzureClient().getChatCompletion(request); - return tryExtractContent(response); } public boolean isRequestAllowed() { @@ -311,9 +325,8 @@ public final class CompletionRequestService { * @return First non-blank content or {@code Optional.empty()} */ private Optional tryExtractContent(OpenAIChatCompletionResponse response) { - return response - .getChoices() - .stream() + return Stream.ofNullable(response.getChoices()) + .flatMap(Collection::stream) .filter(Objects::nonNull) .map(OpenAIChatCompletionResponseChoice::getMessage) .filter(Objects::nonNull)