feat: streaming support for o1 models

This commit is contained in:
Carl-Robert Linnupuu 2024-12-05 20:46:03 +00:00
parent 55172e53ff
commit e23aac95b4
3 changed files with 17 additions and 67 deletions

View file

@@ -3,9 +3,6 @@ package ee.carlrobert.codegpt.completions;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.progress.ProgressIndicator;
import com.intellij.openapi.progress.ProgressManager;
import com.intellij.openapi.progress.Task;
import ee.carlrobert.codegpt.completions.factory.CustomOpenAIRequest;
import ee.carlrobert.codegpt.credentials.CredentialsStore;
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey;
@@ -16,10 +13,8 @@ import ee.carlrobert.codegpt.settings.service.google.GoogleSettings;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
import ee.carlrobert.llm.client.codegpt.request.chat.ChatCompletionRequest;
import ee.carlrobert.llm.client.codegpt.response.CodeGPTException;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
@@ -30,15 +25,12 @@ import ee.carlrobert.llm.completion.CompletionEventListener;
import ee.carlrobert.llm.completion.CompletionRequest;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;
import javax.swing.SwingUtilities;
import okhttp3.Request;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;
import org.jetbrains.annotations.NotNull;
@Service
public final class CompletionRequestService {
@@ -100,13 +92,8 @@ public final class CompletionRequestService {
CompletionEventListener<String> eventListener) {
if (request instanceof OpenAIChatCompletionRequest completionRequest) {
return switch (GeneralSettings.getSelectedService()) {
case OPENAI -> {
if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) {
yield getO1ChatCompletionAsync(completionRequest, eventListener);
}
yield CompletionClientProvider.getOpenAIClient()
.getChatCompletionAsync(completionRequest, eventListener);
}
case OPENAI -> CompletionClientProvider.getOpenAIClient()
.getChatCompletionAsync(completionRequest, eventListener);
case AZURE -> CompletionClientProvider.getAzureClient()
.getChatCompletionAsync(completionRequest, eventListener);
case OLLAMA -> CompletionClientProvider.getOllamaClient()
@@ -115,9 +102,6 @@ public final class CompletionRequestService {
};
}
if (request instanceof ChatCompletionRequest completionRequest) {
if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) {
return getO1ChatCompletionAsync(completionRequest, eventListener);
}
return CompletionClientProvider.getCodeGPTClient()
.getChatCompletionAsync(completionRequest, eventListener);
}
@@ -146,38 +130,6 @@ public final class CompletionRequestService {
throw new IllegalStateException("Unknown request type: " + request.getClass());
}
private EventSource getO1ChatCompletionAsync(
CompletionRequest request,
CompletionEventListener<String> eventListener) {
ProgressManager.getInstance()
.run(new Task.Backgroundable(null, "CodeGPT: Processing o1 request") {
@Override
public void run(@NotNull ProgressIndicator indicator) {
indicator.setIndeterminate(true);
try {
var response = CompletionRequestService.getInstance().getChatCompletion(request);
SwingUtilities.invokeLater(
() -> eventListener.onComplete(new StringBuilder(response)));
} catch (CodeGPTException e) {
SwingUtilities.invokeLater(
() -> eventListener.onError(new ErrorDetails(e.getDetail()), e));
}
}
});
return new EventSource() {
@Override
public @NotNull Request request() {
return new Request.Builder().build(); // dummy
}
@Override
public void cancel() {
eventListener.onCancelled(new StringBuilder("Cancelled"));
}
};
}
public String getChatCompletion(CompletionRequest request) {
if (request instanceof OpenAIChatCompletionRequest completionRequest) {
var response = switch (GeneralSettings.getSelectedService()) {

View file

@@ -20,6 +20,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
ChatCompletionRequest.Builder(buildOpenAIMessages(model, params))
.setModel(model)
.setSessionId(params.sessionId)
.setStream(true)
.setMetadata(
Metadata(
CodeGPTPlugin.getVersion(),
@@ -29,12 +30,10 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
if ("o1-mini" == model || "o1-preview" == model) {
requestBuilder
.setStream(false)
.setMaxTokens(null)
.setTemperature(null)
} else {
requestBuilder
.setStream(true)
.setMaxTokens(configuration.maxTokens)
.setTemperature(configuration.temperature.toDouble())
}
@@ -66,7 +65,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
): ChatCompletionRequest {
val model = service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
if (model == "o1-mini" || model == "o1-preview") {
return buildBasicO1Request(model, userPrompt, systemPrompt, maxTokens)
return buildBasicO1Request(model, userPrompt, systemPrompt, maxTokens, stream = stream)
}
return ChatCompletionRequest.Builder(
@@ -84,7 +83,8 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
model: String,
prompt: String,
systemPrompt: String = "",
maxCompletionTokens: Int = 4096
maxCompletionTokens: Int = 4096,
stream: Boolean = false
): ChatCompletionRequest {
val messages = if (systemPrompt.isEmpty()) {
listOf(OpenAIChatCompletionStandardMessage("user", prompt))
@@ -97,7 +97,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
return ChatCompletionRequest.Builder(messages)
.setModel(model)
.setMaxTokens(maxCompletionTokens)
.setStream(false)
.setStream(stream)
.setTemperature(null)
.build()
}

View file

@@ -24,19 +24,16 @@ class OpenAIRequestFactory : CompletionRequestFactory {
val requestBuilder: OpenAIChatCompletionRequest.Builder =
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, params))
.setModel(model)
.setStream(true)
.setMaxTokens(null)
.setMaxCompletionTokens(configuration.maxTokens)
if ("o1-mini" == model || "o1-preview" == model) {
requestBuilder
.setMaxCompletionTokens(configuration.maxTokens)
.setStream(false)
.setMaxTokens(null)
.setTemperature(null)
.setPresencePenalty(null)
.setFrequencyPenalty(null)
} else {
requestBuilder
.setStream(true)
.setMaxTokens(configuration.maxTokens)
.setTemperature(configuration.temperature.toDouble())
requestBuilder.setTemperature(configuration.temperature.toDouble())
}
return requestBuilder.build()
@@ -48,7 +45,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
val systemPrompt = service<PromptsSettings>().state.coreActions.editCode.instructions
?: CoreActionsState.DEFAULT_EDIT_CODE_PROMPT
if (model == "o1-mini" || model == "o1-preview") {
return buildBasicO1Request(model, prompt, systemPrompt)
return buildBasicO1Request(model, prompt, systemPrompt, stream = true)
}
return createBasicCompletionRequest(systemPrompt, prompt, model, true)
}
@@ -57,7 +54,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
val model = service<OpenAISettings>().state.model
val (gitDiff, systemPrompt) = params
if (model == "o1-mini" || model == "o1-preview") {
return buildBasicO1Request(model, gitDiff, systemPrompt)
return buildBasicO1Request(model, gitDiff, systemPrompt, stream = true)
}
return createBasicCompletionRequest(systemPrompt, gitDiff, model, true)
}
@@ -84,7 +81,8 @@ class OpenAIRequestFactory : CompletionRequestFactory {
model: String,
prompt: String,
systemPrompt: String = "",
maxCompletionTokens: Int = 4096
maxCompletionTokens: Int = 4096,
stream: Boolean = false,
): OpenAIChatCompletionRequest {
val messages = if (systemPrompt.isEmpty()) {
listOf(OpenAIChatCompletionStandardMessage("user", prompt))
@@ -97,11 +95,11 @@ class OpenAIRequestFactory : CompletionRequestFactory {
return OpenAIChatCompletionRequest.Builder(messages)
.setModel(model)
.setMaxCompletionTokens(maxCompletionTokens)
.setStream(false)
.setMaxTokens(null)
.setStream(stream)
.setTemperature(null)
.setFrequencyPenalty(null)
.setPresencePenalty(null)
.setMaxTokens(null)
.build()
}