diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index ed494da2..a2ded89a 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -3,9 +3,6 @@ package ee.carlrobert.codegpt.completions; import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.Service; import com.intellij.openapi.diagnostic.Logger; -import com.intellij.openapi.progress.ProgressIndicator; -import com.intellij.openapi.progress.ProgressManager; -import com.intellij.openapi.progress.Task; import ee.carlrobert.codegpt.completions.factory.CustomOpenAIRequest; import ee.carlrobert.codegpt.credentials.CredentialsStore; import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey; @@ -16,10 +13,8 @@ import ee.carlrobert.codegpt.settings.service.google.GoogleSettings; import ee.carlrobert.llm.client.DeserializationUtil; import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest; import ee.carlrobert.llm.client.codegpt.request.chat.ChatCompletionRequest; -import ee.carlrobert.llm.client.codegpt.response.CodeGPTException; import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest; import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest; -import ee.carlrobert.llm.client.openai.completion.ErrorDetails; import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener; import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener; import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest; @@ -30,15 +25,12 @@ import ee.carlrobert.llm.completion.CompletionEventListener; import ee.carlrobert.llm.completion.CompletionRequest; import java.io.IOException; import java.util.Collection; -import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.stream.Stream; -import javax.swing.SwingUtilities; import okhttp3.Request; import okhttp3.sse.EventSource; import okhttp3.sse.EventSources; -import org.jetbrains.annotations.NotNull; @Service public final class CompletionRequestService { @@ -100,13 +92,8 @@ public final class CompletionRequestService { CompletionEventListener eventListener) { if (request instanceof OpenAIChatCompletionRequest completionRequest) { return switch (GeneralSettings.getSelectedService()) { - case OPENAI -> { - if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) { - yield getO1ChatCompletionAsync(completionRequest, eventListener); - } - yield CompletionClientProvider.getOpenAIClient() - .getChatCompletionAsync(completionRequest, eventListener); - } + case OPENAI -> CompletionClientProvider.getOpenAIClient() + .getChatCompletionAsync(completionRequest, eventListener); case AZURE -> CompletionClientProvider.getAzureClient() .getChatCompletionAsync(completionRequest, eventListener); case OLLAMA -> CompletionClientProvider.getOllamaClient() @@ -115,9 +102,6 @@ public final class CompletionRequestService { }; } if (request instanceof ChatCompletionRequest completionRequest) { - if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) { - return getO1ChatCompletionAsync(completionRequest, eventListener); - } return CompletionClientProvider.getCodeGPTClient() .getChatCompletionAsync(completionRequest, eventListener); } @@ -146,38 +130,6 @@ public final class CompletionRequestService { throw new IllegalStateException("Unknown request type: " + request.getClass()); } - private EventSource getO1ChatCompletionAsync( - CompletionRequest request, - CompletionEventListener eventListener) { - ProgressManager.getInstance() - .run(new Task.Backgroundable(null, "CodeGPT: Processing o1 request") { - @Override - public void run(@NotNull ProgressIndicator indicator) { - indicator.setIndeterminate(true); - try { - var response = CompletionRequestService.getInstance().getChatCompletion(request); - SwingUtilities.invokeLater( - () -> eventListener.onComplete(new StringBuilder(response))); - } catch (CodeGPTException e) { - SwingUtilities.invokeLater( - () -> eventListener.onError(new ErrorDetails(e.getDetail()), e)); - } - } - }); - - return new EventSource() { - @Override - public @NotNull Request request() { - return new Request.Builder().build(); // dummy - } - - @Override - public void cancel() { - eventListener.onCancelled(new StringBuilder("Cancelled")); - } - }; - } - public String getChatCompletion(CompletionRequest request) { if (request instanceof OpenAIChatCompletionRequest completionRequest) { var response = switch (GeneralSettings.getSelectedService()) { diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt index c3641fc0..f61792a3 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt @@ -20,6 +20,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() { ChatCompletionRequest.Builder(buildOpenAIMessages(model, params)) .setModel(model) .setSessionId(params.sessionId) + .setStream(true) .setMetadata( Metadata( CodeGPTPlugin.getVersion(), @@ -29,12 +30,10 @@ class CodeGPTRequestFactory : BaseRequestFactory() { if ("o1-mini" == model || "o1-preview" == model) { requestBuilder - .setStream(false) .setMaxTokens(null) .setTemperature(null) } else { requestBuilder - .setStream(true) .setMaxTokens(configuration.maxTokens) .setTemperature(configuration.temperature.toDouble()) } @@ -66,7 +65,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() { ): ChatCompletionRequest { val model = service().state.chatCompletionSettings.model if (model == "o1-mini" || model == "o1-preview") { - return buildBasicO1Request(model, userPrompt, systemPrompt, maxTokens) + return buildBasicO1Request(model, userPrompt, systemPrompt, maxTokens, stream = stream) } return ChatCompletionRequest.Builder( @@ -84,7 +83,8 @@ class CodeGPTRequestFactory : BaseRequestFactory() { model: String, prompt: String, systemPrompt: String = "", - maxCompletionTokens: Int = 4096 + maxCompletionTokens: Int = 4096, + stream: Boolean = false ): ChatCompletionRequest { val messages = if (systemPrompt.isEmpty()) { listOf(OpenAIChatCompletionStandardMessage("user", prompt)) @@ -97,7 +97,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() { return ChatCompletionRequest.Builder(messages) .setModel(model) .setMaxTokens(maxCompletionTokens) - .setStream(false) + .setStream(stream) .setTemperature(null) .build() } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt index 3c96dafa..6701b951 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt @@ -24,19 +24,16 @@ class OpenAIRequestFactory : CompletionRequestFactory { val requestBuilder: OpenAIChatCompletionRequest.Builder = OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, params)) .setModel(model) + .setStream(true) + .setMaxTokens(null) + .setMaxCompletionTokens(configuration.maxTokens) if ("o1-mini" == model || "o1-preview" == model) { requestBuilder - .setMaxCompletionTokens(configuration.maxTokens) - .setStream(false) - .setMaxTokens(null) .setTemperature(null) .setPresencePenalty(null) .setFrequencyPenalty(null) } else { - requestBuilder - .setStream(true) - .setMaxTokens(configuration.maxTokens) - .setTemperature(configuration.temperature.toDouble()) + requestBuilder.setTemperature(configuration.temperature.toDouble()) } return requestBuilder.build() @@ -48,7 +45,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { val systemPrompt = service().state.coreActions.editCode.instructions ?: CoreActionsState.DEFAULT_EDIT_CODE_PROMPT if (model == "o1-mini" || model == "o1-preview") { - return buildBasicO1Request(model, prompt, systemPrompt) + return buildBasicO1Request(model, prompt, systemPrompt, stream = true) } return createBasicCompletionRequest(systemPrompt, prompt, model, true) } @@ -57,7 +54,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { val model = service().state.model val (gitDiff, systemPrompt) = params if (model == "o1-mini" || model == "o1-preview") { - return buildBasicO1Request(model, gitDiff, systemPrompt) + return buildBasicO1Request(model, gitDiff, systemPrompt, stream = true) } return createBasicCompletionRequest(systemPrompt, gitDiff, model, true) } @@ -84,7 +81,8 @@ class OpenAIRequestFactory : CompletionRequestFactory { model: String, prompt: String, systemPrompt: String = "", - maxCompletionTokens: Int = 4096 + maxCompletionTokens: Int = 4096, + stream: Boolean = false, ): OpenAIChatCompletionRequest { val messages = if (systemPrompt.isEmpty()) { listOf(OpenAIChatCompletionStandardMessage("user", prompt)) @@ -97,11 +95,11 @@ class OpenAIRequestFactory : CompletionRequestFactory { return OpenAIChatCompletionRequest.Builder(messages) .setModel(model) .setMaxCompletionTokens(maxCompletionTokens) - .setStream(false) + .setMaxTokens(null) + .setStream(stream) .setTemperature(null) .setFrequencyPenalty(null) .setPresencePenalty(null) - .setMaxTokens(null) .build() }