mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-09 19:45:16 +00:00
feat: streaming support for o1 models
This commit is contained in:
parent
55172e53ff
commit
e23aac95b4
3 changed files with 17 additions and 67 deletions
|
|
@ -3,9 +3,6 @@ package ee.carlrobert.codegpt.completions;
|
|||
import com.intellij.openapi.application.ApplicationManager;
|
||||
import com.intellij.openapi.components.Service;
|
||||
import com.intellij.openapi.diagnostic.Logger;
|
||||
import com.intellij.openapi.progress.ProgressIndicator;
|
||||
import com.intellij.openapi.progress.ProgressManager;
|
||||
import com.intellij.openapi.progress.Task;
|
||||
import ee.carlrobert.codegpt.completions.factory.CustomOpenAIRequest;
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore;
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey;
|
||||
|
|
@ -16,10 +13,8 @@ import ee.carlrobert.codegpt.settings.service.google.GoogleSettings;
|
|||
import ee.carlrobert.llm.client.DeserializationUtil;
|
||||
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
|
||||
import ee.carlrobert.llm.client.codegpt.request.chat.ChatCompletionRequest;
|
||||
import ee.carlrobert.llm.client.codegpt.response.CodeGPTException;
|
||||
import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest;
|
||||
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
|
||||
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
|
||||
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
|
||||
import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener;
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
|
||||
|
|
@ -30,15 +25,12 @@ import ee.carlrobert.llm.completion.CompletionEventListener;
|
|||
import ee.carlrobert.llm.completion.CompletionRequest;
|
||||
import java.io.IOException;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Stream;
|
||||
import javax.swing.SwingUtilities;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSources;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
@Service
|
||||
public final class CompletionRequestService {
|
||||
|
|
@ -100,13 +92,8 @@ public final class CompletionRequestService {
|
|||
CompletionEventListener<String> eventListener) {
|
||||
if (request instanceof OpenAIChatCompletionRequest completionRequest) {
|
||||
return switch (GeneralSettings.getSelectedService()) {
|
||||
case OPENAI -> {
|
||||
if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) {
|
||||
yield getO1ChatCompletionAsync(completionRequest, eventListener);
|
||||
}
|
||||
yield CompletionClientProvider.getOpenAIClient()
|
||||
.getChatCompletionAsync(completionRequest, eventListener);
|
||||
}
|
||||
case OPENAI -> CompletionClientProvider.getOpenAIClient()
|
||||
.getChatCompletionAsync(completionRequest, eventListener);
|
||||
case AZURE -> CompletionClientProvider.getAzureClient()
|
||||
.getChatCompletionAsync(completionRequest, eventListener);
|
||||
case OLLAMA -> CompletionClientProvider.getOllamaClient()
|
||||
|
|
@ -115,9 +102,6 @@ public final class CompletionRequestService {
|
|||
};
|
||||
}
|
||||
if (request instanceof ChatCompletionRequest completionRequest) {
|
||||
if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) {
|
||||
return getO1ChatCompletionAsync(completionRequest, eventListener);
|
||||
}
|
||||
return CompletionClientProvider.getCodeGPTClient()
|
||||
.getChatCompletionAsync(completionRequest, eventListener);
|
||||
}
|
||||
|
|
@ -146,38 +130,6 @@ public final class CompletionRequestService {
|
|||
throw new IllegalStateException("Unknown request type: " + request.getClass());
|
||||
}
|
||||
|
||||
private EventSource getO1ChatCompletionAsync(
|
||||
CompletionRequest request,
|
||||
CompletionEventListener<String> eventListener) {
|
||||
ProgressManager.getInstance()
|
||||
.run(new Task.Backgroundable(null, "CodeGPT: Processing o1 request") {
|
||||
@Override
|
||||
public void run(@NotNull ProgressIndicator indicator) {
|
||||
indicator.setIndeterminate(true);
|
||||
try {
|
||||
var response = CompletionRequestService.getInstance().getChatCompletion(request);
|
||||
SwingUtilities.invokeLater(
|
||||
() -> eventListener.onComplete(new StringBuilder(response)));
|
||||
} catch (CodeGPTException e) {
|
||||
SwingUtilities.invokeLater(
|
||||
() -> eventListener.onError(new ErrorDetails(e.getDetail()), e));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return new EventSource() {
|
||||
@Override
|
||||
public @NotNull Request request() {
|
||||
return new Request.Builder().build(); // dummy
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cancel() {
|
||||
eventListener.onCancelled(new StringBuilder("Cancelled"));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public String getChatCompletion(CompletionRequest request) {
|
||||
if (request instanceof OpenAIChatCompletionRequest completionRequest) {
|
||||
var response = switch (GeneralSettings.getSelectedService()) {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
|
|||
ChatCompletionRequest.Builder(buildOpenAIMessages(model, params))
|
||||
.setModel(model)
|
||||
.setSessionId(params.sessionId)
|
||||
.setStream(true)
|
||||
.setMetadata(
|
||||
Metadata(
|
||||
CodeGPTPlugin.getVersion(),
|
||||
|
|
@ -29,12 +30,10 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
|
|||
|
||||
if ("o1-mini" == model || "o1-preview" == model) {
|
||||
requestBuilder
|
||||
.setStream(false)
|
||||
.setMaxTokens(null)
|
||||
.setTemperature(null)
|
||||
} else {
|
||||
requestBuilder
|
||||
.setStream(true)
|
||||
.setMaxTokens(configuration.maxTokens)
|
||||
.setTemperature(configuration.temperature.toDouble())
|
||||
}
|
||||
|
|
@ -66,7 +65,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
|
|||
): ChatCompletionRequest {
|
||||
val model = service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
|
||||
if (model == "o1-mini" || model == "o1-preview") {
|
||||
return buildBasicO1Request(model, userPrompt, systemPrompt, maxTokens)
|
||||
return buildBasicO1Request(model, userPrompt, systemPrompt, maxTokens, stream = stream)
|
||||
}
|
||||
|
||||
return ChatCompletionRequest.Builder(
|
||||
|
|
@ -84,7 +83,8 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
|
|||
model: String,
|
||||
prompt: String,
|
||||
systemPrompt: String = "",
|
||||
maxCompletionTokens: Int = 4096
|
||||
maxCompletionTokens: Int = 4096,
|
||||
stream: Boolean = false
|
||||
): ChatCompletionRequest {
|
||||
val messages = if (systemPrompt.isEmpty()) {
|
||||
listOf(OpenAIChatCompletionStandardMessage("user", prompt))
|
||||
|
|
@ -97,7 +97,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
|
|||
return ChatCompletionRequest.Builder(messages)
|
||||
.setModel(model)
|
||||
.setMaxTokens(maxCompletionTokens)
|
||||
.setStream(false)
|
||||
.setStream(stream)
|
||||
.setTemperature(null)
|
||||
.build()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,19 +24,16 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
val requestBuilder: OpenAIChatCompletionRequest.Builder =
|
||||
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, params))
|
||||
.setModel(model)
|
||||
.setStream(true)
|
||||
.setMaxTokens(null)
|
||||
.setMaxCompletionTokens(configuration.maxTokens)
|
||||
if ("o1-mini" == model || "o1-preview" == model) {
|
||||
requestBuilder
|
||||
.setMaxCompletionTokens(configuration.maxTokens)
|
||||
.setStream(false)
|
||||
.setMaxTokens(null)
|
||||
.setTemperature(null)
|
||||
.setPresencePenalty(null)
|
||||
.setFrequencyPenalty(null)
|
||||
} else {
|
||||
requestBuilder
|
||||
.setStream(true)
|
||||
.setMaxTokens(configuration.maxTokens)
|
||||
.setTemperature(configuration.temperature.toDouble())
|
||||
requestBuilder.setTemperature(configuration.temperature.toDouble())
|
||||
}
|
||||
|
||||
return requestBuilder.build()
|
||||
|
|
@ -48,7 +45,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
val systemPrompt = service<PromptsSettings>().state.coreActions.editCode.instructions
|
||||
?: CoreActionsState.DEFAULT_EDIT_CODE_PROMPT
|
||||
if (model == "o1-mini" || model == "o1-preview") {
|
||||
return buildBasicO1Request(model, prompt, systemPrompt)
|
||||
return buildBasicO1Request(model, prompt, systemPrompt, stream = true)
|
||||
}
|
||||
return createBasicCompletionRequest(systemPrompt, prompt, model, true)
|
||||
}
|
||||
|
|
@ -57,7 +54,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
val model = service<OpenAISettings>().state.model
|
||||
val (gitDiff, systemPrompt) = params
|
||||
if (model == "o1-mini" || model == "o1-preview") {
|
||||
return buildBasicO1Request(model, gitDiff, systemPrompt)
|
||||
return buildBasicO1Request(model, gitDiff, systemPrompt, stream = true)
|
||||
}
|
||||
return createBasicCompletionRequest(systemPrompt, gitDiff, model, true)
|
||||
}
|
||||
|
|
@ -84,7 +81,8 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
model: String,
|
||||
prompt: String,
|
||||
systemPrompt: String = "",
|
||||
maxCompletionTokens: Int = 4096
|
||||
maxCompletionTokens: Int = 4096,
|
||||
stream: Boolean = false,
|
||||
): OpenAIChatCompletionRequest {
|
||||
val messages = if (systemPrompt.isEmpty()) {
|
||||
listOf(OpenAIChatCompletionStandardMessage("user", prompt))
|
||||
|
|
@ -97,11 +95,11 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
return OpenAIChatCompletionRequest.Builder(messages)
|
||||
.setModel(model)
|
||||
.setMaxCompletionTokens(maxCompletionTokens)
|
||||
.setStream(false)
|
||||
.setMaxTokens(null)
|
||||
.setStream(stream)
|
||||
.setTemperature(null)
|
||||
.setFrequencyPenalty(null)
|
||||
.setPresencePenalty(null)
|
||||
.setMaxTokens(null)
|
||||
.build()
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue