diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 5de7ed07..b0102be8 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -12,7 +12,7 @@ jsoup = "1.17.2" jtokkit = "1.1.0" junit = "5.11.0" kotlin = "2.0.0" -llm-client = "0.8.18" +llm-client = "0.8.19" okio = "3.9.0" tree-sitter = "0.22.6a" diff --git a/src/main/java/ee/carlrobert/codegpt/Icons.java b/src/main/java/ee/carlrobert/codegpt/Icons.java index ca0adce3..d86d3d88 100644 --- a/src/main/java/ee/carlrobert/codegpt/Icons.java +++ b/src/main/java/ee/carlrobert/codegpt/Icons.java @@ -15,10 +15,12 @@ public final class Icons { public static final Icon Azure = IconLoader.getIcon("/icons/azure.svg", Icons.class); public static final Icon Databricks = IconLoader.getIcon("/icons/dbrx.svg", Icons.class); public static final Icon DeepSeek = IconLoader.getIcon("/icons/deepseek.png", Icons.class); + public static final Icon Qwen = IconLoader.getIcon("/icons/qwen.png", Icons.class); public static final Icon Google = IconLoader.getIcon("/icons/google.svg", Icons.class); public static final Icon Llama = IconLoader.getIcon("/icons/llama.svg", Icons.class); public static final Icon OpenAI = IconLoader.getIcon("/icons/openai.svg", Icons.class); public static final Icon Meta = IconLoader.getIcon("/icons/meta.svg", Icons.class); + public static final Icon Mistral = IconLoader.getIcon("/icons/mistral.svg", Icons.class); public static final Icon Send = IconLoader.getIcon("/icons/send.svg", Icons.class); public static final Icon Sparkle = IconLoader.getIcon("/icons/sparkle.svg", Icons.class); public static final Icon You = IconLoader.getIcon("/icons/you.svg", Icons.class); diff --git a/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java b/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java index 2ed218df..0b771932 100644 --- a/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java +++ b/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java @@ -18,6 +18,7 @@ import com.intellij.vcs.commit.CommitWorkflowUi; import ee.carlrobert.codegpt.CodeGPTBundle; import ee.carlrobert.codegpt.EncodingManager; import ee.carlrobert.codegpt.Icons; +import ee.carlrobert.codegpt.completions.CommitMessageRequestParameters; import ee.carlrobert.codegpt.completions.CompletionRequestService; import ee.carlrobert.codegpt.settings.configuration.CommitMessageTemplate; import ee.carlrobert.codegpt.ui.OverlayUtil; @@ -85,8 +86,9 @@ public class GenerateGitCommitMessageAction extends AnAction { var commitWorkflowUi = event.getData(VcsDataKeys.COMMIT_WORKFLOW_UI); if (commitWorkflowUi != null) { CompletionRequestService.getInstance().getCommitMessageAsync( - project.getService(CommitMessageTemplate.class).getSystemPrompt(), - gitDiff, + new CommitMessageRequestParameters( + gitDiff, + project.getService(CommitMessageTemplate.class).getSystemPrompt()), getEventListener(project, commitWorkflowUi)); } } @@ -162,11 +164,22 @@ public class GenerateGitCommitMessageAction extends AnAction { @Override public void onMessage(String message, EventSource eventSource) { messageBuilder.append(message); - var application = ApplicationManager.getApplication(); - application.invokeLater(() -> - application.runWriteAction(() -> - WriteCommandAction.runWriteCommandAction(project, () -> - commitWorkflowUi.getCommitMessageUi().setText(messageBuilder.toString())))); + updateCommitMessage(messageBuilder.toString()); + } + + @Override + public void onComplete(StringBuilder result) { + if (messageBuilder.isEmpty()) { + updateCommitMessage(result.toString()); + } + } + + private void updateCommitMessage(String message) { + ApplicationManager.getApplication().invokeLater(() -> + WriteCommandAction.runWriteCommandAction(project, () -> + commitWorkflowUi.getCommitMessageUi().setText(message) + ) + ); } @Override diff --git a/src/main/java/ee/carlrobert/codegpt/completions/ChatCompletionEventListener.java b/src/main/java/ee/carlrobert/codegpt/completions/ChatCompletionEventListener.java new file mode 100644 index 00000000..65725f87 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/completions/ChatCompletionEventListener.java @@ -0,0 +1,74 @@ +package ee.carlrobert.codegpt.completions; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import ee.carlrobert.codegpt.events.CodeGPTEvent; +import ee.carlrobert.codegpt.telemetry.TelemetryAction; +import ee.carlrobert.llm.client.openai.completion.ErrorDetails; +import ee.carlrobert.llm.completion.CompletionEventListener; +import okhttp3.sse.EventSource; + +public class ChatCompletionEventListener implements CompletionEventListener { + + private final CallParameters callParameters; + private final CompletionResponseEventListener eventListener; + private final StringBuilder messageBuilder = new StringBuilder(); + + public ChatCompletionEventListener( + CallParameters callParameters, + CompletionResponseEventListener eventListener) { + this.callParameters = callParameters; + this.eventListener = eventListener; + } + + @Override + public void onEvent(String data) { + try { + var event = new ObjectMapper().readValue(data, CodeGPTEvent.class); + eventListener.handleCodeGPTEvent(event); + } catch (JsonProcessingException e) { + // ignore + } + } + + @Override + public void onMessage(String message, EventSource eventSource) { + messageBuilder.append(message); + callParameters.getMessage().setResponse(messageBuilder.toString()); + eventListener.handleMessage(message); + } + + @Override + public void onComplete(StringBuilder messageBuilder) { + eventListener.handleCompleted(messageBuilder.toString(), callParameters); + } + + @Override + public void onCancelled(StringBuilder messageBuilder) { + eventListener.handleCompleted(messageBuilder.toString(), callParameters); + } + + @Override + public void onError(ErrorDetails error, Throwable ex) { + try { + eventListener.handleError(error, ex); + } finally { + sendError(error, ex); + } + } + + private void sendError(ErrorDetails error, Throwable ex) { + var telemetryMessage = TelemetryAction.COMPLETION_ERROR.createActionMessage(); + if ("insufficient_quota".equals(error.getCode())) { + telemetryMessage + .property("type", "USER") + .property("code", "INSUFFICIENT_QUOTA"); + } else { + telemetryMessage + .property("conversationId", callParameters.getConversation().getId().toString()) + .property("model", callParameters.getConversation().getModel()) + .error(new RuntimeException(error.toString(), ex)); + } + telemetryMessage.send(); + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java index fe7c796d..2989dcaf 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java @@ -89,7 +89,6 @@ public class CompletionClientProvider { return builder.build(getDefaultClientBuilder()); } - public static GoogleClient getGoogleClient() { return new GoogleClient.Builder(getCredential(CredentialKey.GOOGLE_API_KEY)) .build(getDefaultClientBuilder()); diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java deleted file mode 100644 index a66e8163..00000000 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java +++ /dev/null @@ -1,129 +0,0 @@ -package ee.carlrobert.codegpt.completions; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import ee.carlrobert.codegpt.events.CodeGPTEvent; -import ee.carlrobert.codegpt.settings.GeneralSettings; -import ee.carlrobert.codegpt.telemetry.TelemetryAction; -import ee.carlrobert.llm.client.openai.completion.ErrorDetails; -import ee.carlrobert.llm.completion.CompletionEventListener; -import okhttp3.sse.EventSource; - -public class CompletionRequestHandler { - - private final StringBuilder messageBuilder = new StringBuilder(); - private final CompletionResponseEventListener completionResponseEventListener; - private EventSource eventSource; - - public CompletionRequestHandler(CompletionResponseEventListener completionResponseEventListener) { - this.completionResponseEventListener = completionResponseEventListener; - } - - public void call(CallParameters callParameters) { - try { - eventSource = startCall(callParameters, new RequestCompletionEventListener(callParameters)); - } catch (TotalUsageExceededException e) { - completionResponseEventListener.handleTokensExceeded( - callParameters.getConversation(), - callParameters.getMessage()); - } finally { - sendInfo(callParameters); - } - } - - public void cancel() { - if (eventSource != null) { - eventSource.cancel(); - } - } - - private EventSource startCall( - CallParameters callParameters, - CompletionEventListener eventListener) { - try { - return CompletionRequestService.getInstance() - .getChatCompletionAsync(callParameters, eventListener); - } catch (Throwable ex) { - handleCallException(ex); - throw ex; - } - } - - private void handleCallException(Throwable ex) { - var errorMessage = "Something went wrong"; - if (ex instanceof TotalUsageExceededException) { - errorMessage = - "The length of the context exceeds the maximum limit that the model can handle. " - + "Try reducing the input message or maximum completion token size."; - } - completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex); - } - - class RequestCompletionEventListener implements CompletionEventListener { - - private final CallParameters callParameters; - - public RequestCompletionEventListener(CallParameters callParameters) { - this.callParameters = callParameters; - } - - @Override - public void onEvent(String data) { - try { - var event = new ObjectMapper().readValue(data, CodeGPTEvent.class); - completionResponseEventListener.handleCodeGPTEvent(event); - } catch (JsonProcessingException e) { - // ignore - } - } - - @Override - public void onMessage(String message, EventSource eventSource) { - messageBuilder.append(message); - callParameters.getMessage().setResponse(messageBuilder.toString()); - completionResponseEventListener.handleMessage(message); - } - - @Override - public void onComplete(StringBuilder messageBuilder) { - completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters); - } - - @Override - public void onCancelled(StringBuilder messageBuilder) { - completionResponseEventListener.handleCompleted(messageBuilder.toString(), callParameters); - } - - @Override - public void onError(ErrorDetails error, Throwable ex) { - try { - completionResponseEventListener.handleError(error, ex); - } finally { - sendError(error, ex); - } - } - - private void sendError(ErrorDetails error, Throwable ex) { - var telemetryMessage = TelemetryAction.COMPLETION_ERROR.createActionMessage(); - if ("insufficient_quota".equals(error.getCode())) { - telemetryMessage - .property("type", "USER") - .property("code", "INSUFFICIENT_QUOTA"); - } else { - telemetryMessage - .property("conversationId", callParameters.getConversation().getId().toString()) - .property("model", callParameters.getConversation().getModel()) - .error(new RuntimeException(error.toString(), ex)); - } - telemetryMessage.send(); - } - } - - private void sendInfo(CallParameters callParameters) { - TelemetryAction.COMPLETION.createActionMessage() - .property("conversationId", callParameters.getConversation().getId().toString()) - .property("model", callParameters.getConversation().getModel()) - .property("service", GeneralSettings.getSelectedService().getCode().toLowerCase()) - .send(); - } -} diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index 3ba8e44a..56c06c5b 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -3,7 +3,9 @@ package ee.carlrobert.codegpt.completions; import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.Service; import com.intellij.openapi.diagnostic.Logger; -import ee.carlrobert.codegpt.actions.editor.EditCodeRequestParams; +import com.intellij.openapi.progress.ProgressIndicator; +import com.intellij.openapi.progress.ProgressManager; +import com.intellij.openapi.progress.Task; import ee.carlrobert.codegpt.completions.factory.CustomOpenAIRequest; import ee.carlrobert.codegpt.credentials.CredentialsStore; import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey; @@ -26,12 +28,15 @@ import ee.carlrobert.llm.completion.CompletionEventListener; import ee.carlrobert.llm.completion.CompletionRequest; import java.io.IOException; import java.util.Collection; +import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.stream.Stream; +import javax.swing.SwingUtilities; import okhttp3.Request; import okhttp3.sse.EventSource; import okhttp3.sse.EventSources; +import org.jetbrains.annotations.NotNull; @Service public final class CompletionRequestService { @@ -63,50 +68,50 @@ public final class CompletionRequestService { new OpenAIChatCompletionEventSourceListener(eventListener)); } - public String getLookupCompletion(String prompt) { - return getChatCompletion( - CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) - .createLookupRequest(prompt)); + public String getLookupCompletion(LookupRequestCallParameters params) { + var request = CompletionRequestFactory + .getFactory(GeneralSettings.getSelectedService()) + .createLookupRequest(params); + return getChatCompletion(request); } public EventSource getCommitMessageAsync( - String systemPrompt, - String gitDiff, + CommitMessageRequestParameters params, CompletionEventListener eventListener) { - return getChatCompletionAsync( - CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) - .createCommitMessageRequest(systemPrompt, gitDiff), - eventListener); + var request = CompletionRequestFactory + .getFactory(GeneralSettings.getSelectedService()) + .createCommitMessageRequest(params); + return getChatCompletionAsync(request, eventListener); } public EventSource getEditCodeCompletionAsync( - EditCodeRequestParams params, + EditCodeRequestParameters params, CompletionEventListener eventListener) { - var input = "%s\n\n%s".formatted(params.getPrompt(), params.getSelectedText()); - return getChatCompletionAsync( - CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) - .createEditCodeRequest(input), - eventListener); + var request = CompletionRequestFactory + .getFactory(GeneralSettings.getSelectedService()) + .createEditCodeRequest(params); + return getChatCompletionAsync(request, eventListener); } public EventSource getChatCompletionAsync( - CallParameters callParameters, - CompletionEventListener eventListener) { - return getChatCompletionAsync( - CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) - .createChatRequest(callParameters), - eventListener); - } - - private EventSource getChatCompletionAsync( CompletionRequest request, CompletionEventListener eventListener) { if (request instanceof OpenAIChatCompletionRequest completionRequest) { return switch (GeneralSettings.getSelectedService()) { - case CODEGPT -> CompletionClientProvider.getCodeGPTClient() - .getChatCompletionAsync(completionRequest, eventListener); - case OPENAI -> CompletionClientProvider.getOpenAIClient() - .getChatCompletionAsync(completionRequest, eventListener); + case CODEGPT -> { + if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) { + yield getO1ChatCompletionAsync(completionRequest, eventListener); + } + yield CompletionClientProvider.getCodeGPTClient() + .getChatCompletionAsync(completionRequest, eventListener); + } + case OPENAI -> { + if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) { + yield getO1ChatCompletionAsync(completionRequest, eventListener); + } + yield CompletionClientProvider.getOpenAIClient() + .getChatCompletionAsync(completionRequest, eventListener); + } case AZURE -> CompletionClientProvider.getAzureClient() .getChatCompletionAsync(completionRequest, eventListener); default -> throw new RuntimeException("Unknown service selected"); @@ -142,7 +147,33 @@ public final class CompletionRequestService { throw new IllegalStateException("Unknown request type: " + request.getClass()); } - private String getChatCompletion(CompletionRequest request) { + private EventSource getO1ChatCompletionAsync( + OpenAIChatCompletionRequest request, + CompletionEventListener eventListener) { + ProgressManager.getInstance() + .run(new Task.Backgroundable(null, "CodeGPT: Processing o1 request") { + @Override + public void run(@NotNull ProgressIndicator indicator) { + indicator.setIndeterminate(true); + var response = CompletionRequestService.getInstance().getChatCompletion(request); + SwingUtilities.invokeLater(() -> eventListener.onComplete(new StringBuilder(response))); + } + }); + + return new EventSource() { + @Override + public @NotNull Request request() { + return new Request.Builder().build(); // dummy + } + + @Override + public void cancel() { + eventListener.onCancelled(new StringBuilder("Cancelled")); + } + }; + } + + public String getChatCompletion(CompletionRequest request) { if (request instanceof OpenAIChatCompletionRequest completionRequest) { var response = switch (GeneralSettings.getSelectedService()) { case CODEGPT -> CompletionClientProvider.getCodeGPTClient() diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionResponseEventListener.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionResponseEventListener.java index 367af192..b0538b59 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionResponseEventListener.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionResponseEventListener.java @@ -16,6 +16,9 @@ public interface CompletionResponseEventListener { default void handleTokensExceeded(Conversation conversation, Message message) { } + default void handleCompleted(String fullMessage) { + } + default void handleCompleted(String fullMessage, CallParameters callParameters) { } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java b/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java index 53e2f8ea..309bb1c4 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java @@ -56,7 +56,8 @@ public class MethodNameLookupListener implements LookupManagerListener { Application application, String prompt) { try { - var response = CompletionRequestService.getInstance().getLookupCompletion(prompt); + var response = CompletionRequestService.getInstance() + .getLookupCompletion(new LookupRequestCallParameters(prompt)); if (!response.isEmpty()) { for (var value : response.split(",")) { application.invokeLater(() -> application.runReadAction(() -> { diff --git a/src/main/java/ee/carlrobert/codegpt/completions/ToolwindowChatCompletionRequestHandler.java b/src/main/java/ee/carlrobert/codegpt/completions/ToolwindowChatCompletionRequestHandler.java new file mode 100644 index 00000000..d0c125bf --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/completions/ToolwindowChatCompletionRequestHandler.java @@ -0,0 +1,67 @@ +package ee.carlrobert.codegpt.completions; + +import ee.carlrobert.codegpt.settings.GeneralSettings; +import ee.carlrobert.codegpt.telemetry.TelemetryAction; +import ee.carlrobert.llm.client.openai.completion.ErrorDetails; +import okhttp3.sse.EventSource; + +public class ToolwindowChatCompletionRequestHandler { + + private final CompletionResponseEventListener completionResponseEventListener; + private EventSource eventSource; + + public ToolwindowChatCompletionRequestHandler( + CompletionResponseEventListener completionResponseEventListener) { + this.completionResponseEventListener = completionResponseEventListener; + } + + public void call(CallParameters callParameters) { + try { + eventSource = startCall(callParameters); + } catch (TotalUsageExceededException e) { + completionResponseEventListener.handleTokensExceeded( + callParameters.getConversation(), + callParameters.getMessage()); + } finally { + sendInfo(callParameters); + } + } + + public void cancel() { + if (eventSource != null) { + eventSource.cancel(); + } + } + + private EventSource startCall(CallParameters callParameters) { + try { + var request = CompletionRequestFactory + .getFactory(GeneralSettings.getSelectedService()) + .createChatRequest(new ChatCompletionRequestParameters(callParameters)); + return CompletionRequestService.getInstance().getChatCompletionAsync( + request, + new ChatCompletionEventListener(callParameters, completionResponseEventListener)); + } catch (Throwable ex) { + handleCallException(ex); + throw ex; + } + } + + private void handleCallException(Throwable ex) { + var errorMessage = "Something went wrong"; + if (ex instanceof TotalUsageExceededException) { + errorMessage = + "The length of the context exceeds the maximum limit that the model can handle. " + + "Try reducing the input message or maximum completion token size."; + } + completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex); + } + + private void sendInfo(CallParameters callParameters) { + TelemetryAction.COMPLETION.createActionMessage() + .property("conversationId", callParameters.getConversation().getId().toString()) + .property("model", callParameters.getConversation().getModel()) + .property("service", GeneralSettings.getSelectedService().getCode().toLowerCase()) + .send(); + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/settings/advanced/AdvancedSettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/advanced/AdvancedSettingsState.java index 9acdaa7c..21f1fa86 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/advanced/AdvancedSettingsState.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/advanced/AdvancedSettingsState.java @@ -11,8 +11,8 @@ public class AdvancedSettingsState { private boolean proxyAuthSelected; private String proxyUsername; private String proxyPassword; - private int connectTimeout = 30; - private int readTimeout = 30; + private int connectTimeout = 120; + private int readTimeout = 120; public String getProxyHost() { return proxyHost; diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java index ba271bf8..964d6c33 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java @@ -15,10 +15,10 @@ import ee.carlrobert.codegpt.CodeGPTKeys; import ee.carlrobert.codegpt.ReferencedFile; import ee.carlrobert.codegpt.actions.ActionType; import ee.carlrobert.codegpt.completions.CallParameters; -import ee.carlrobert.codegpt.completions.CompletionRequestHandler; import ee.carlrobert.codegpt.completions.CompletionRequestService; import ee.carlrobert.codegpt.completions.CompletionRequestUtil; import ee.carlrobert.codegpt.completions.ConversationType; +import ee.carlrobert.codegpt.completions.ToolwindowChatCompletionRequestHandler; import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.ConversationService; import ee.carlrobert.codegpt.conversations.message.Message; @@ -60,7 +60,7 @@ public class ChatToolWindowTabPanel implements Disposable { private final TotalTokensPanel totalTokensPanel; private final ChatToolWindowScrollablePanel toolWindowScrollablePanel; - private @Nullable CompletionRequestHandler requestHandler; + private @Nullable ToolwindowChatCompletionRequestHandler requestHandler; public ChatToolWindowTabPanel(@NotNull Project project, @NotNull Conversation conversation) { this.project = project; @@ -250,7 +250,7 @@ public class ChatToolWindowTabPanel implements Disposable { return; } - requestHandler = new CompletionRequestHandler( + requestHandler = new ToolwindowChatCompletionRequestHandler( new ToolWindowCompletionResponseEventListener( conversationService, responsePanel, diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java index 0f8b7fe3..72e2f1e9 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java @@ -112,6 +112,9 @@ abstract class ToolWindowCompletionResponseEventListener implements try { responsePanel.enableActions(); responseContainer.enableActions(); + if (!responseContainer.isResponseReceived() && !fullMessage.isEmpty()) { + responseContainer.withResponse(fullMessage); + } totalTokensPanel.updateUserPromptTokens(textArea.getText()); totalTokensPanel.updateConversationTokens(callParameters.getConversation()); } finally { diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java index e8694881..88e535d6 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java @@ -113,6 +113,10 @@ public class ChatMessageResponseBody extends JPanel { } public ChatMessageResponseBody withResponse(String response) { + if (!responseReceived) { + removeAll(); + } + for (var message : MarkdownUtil.splitCodeBlocks(response)) { currentlyProcessedEditorPanel = null; currentlyProcessedTextPane = null; @@ -362,4 +366,8 @@ public class ChatMessageResponseBody extends JPanel { panel.add(listPanel, BorderLayout.CENTER); return panel; } + + public boolean isResponseReceived() { + return responseReceived; + } } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java index 3d1bcef2..cb69bbdd 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java @@ -123,9 +123,10 @@ public class ModelComboBoxAction extends ComboBoxAction { var openaiGroup = DefaultActionGroup.createPopupGroup(() -> "OpenAI"); openaiGroup.getTemplatePresentation().setIcon(Icons.OpenAI); List.of( + OpenAIChatCompletionModel.O_1_PREVIEW, + OpenAIChatCompletionModel.O_1_MINI, OpenAIChatCompletionModel.GPT_4_O, OpenAIChatCompletionModel.GPT_4_O_MINI, - OpenAIChatCompletionModel.GPT_4_VISION_PREVIEW, OpenAIChatCompletionModel.GPT_4_0125_128k) .forEach(model -> openaiGroup.add(createOpenAIModelAction(model, presentation))); actionGroup.add(openaiGroup); diff --git a/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeCompletionListener.kt b/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeCompletionListener.kt index 9de71dee..a1f99582 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeCompletionListener.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeCompletionListener.kt @@ -34,7 +34,12 @@ class EditCodeCompletionListener( } override fun onComplete(messageBuilder: StringBuilder) { - runInEdt { cleanupAndFormat() } + runInEdt { + if (replacedLength == 0 && messageBuilder.isNotEmpty()) { + handleDiff(messageBuilder.toString()) + } + cleanupAndFormat() + } observableProperties.loading.set(false) } @@ -73,7 +78,6 @@ class EditCodeCompletionListener( val document = editor.document val startOffset = selectionTextRange.startOffset val endOffset = selectionTextRange.endOffset - runUndoTransparentWriteAction { val remainingOriginalLength = endOffset - (startOffset + replacedLength) if (remainingOriginalLength > 0) { diff --git a/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt b/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt index 656eab48..ad5f076b 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt @@ -9,10 +9,9 @@ import com.intellij.openapi.util.TextRange import com.intellij.openapi.util.text.StringUtil import com.jetbrains.rd.util.AtomicReference import ee.carlrobert.codegpt.completions.CompletionRequestService +import ee.carlrobert.codegpt.completions.EditCodeRequestParameters import ee.carlrobert.codegpt.ui.ObservableProperties -data class EditCodeRequestParams(val prompt: String, val selectedText: String) - class EditCodeSubmissionHandler( private val editor: Editor, private val observableProperties: ObservableProperties, @@ -36,7 +35,7 @@ class EditCodeSubmissionHandler( runInEdt { editor.selectionModel.removeSelection() } service().getEditCodeCompletionAsync( - EditCodeRequestParams(userPrompt, selectedText), + EditCodeRequestParameters(userPrompt, selectedText), EditCodeCompletionListener(editor, observableProperties, selectionTextRange) ) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionCallParameters.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionCallParameters.kt new file mode 100644 index 00000000..53736ad3 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionCallParameters.kt @@ -0,0 +1,19 @@ +package ee.carlrobert.codegpt.completions + +interface CompletionCallParameters + +data class ChatCompletionRequestParameters( + val callParameters: CallParameters +) : CompletionCallParameters + +data class CommitMessageRequestParameters( + val gitDiff: String, + val systemPrompt: String +) : CompletionCallParameters + +data class LookupRequestCallParameters(val prompt: String) : CompletionCallParameters + +data class EditCodeRequestParameters( + val prompt: String, + val selectedText: String +) : CompletionCallParameters \ No newline at end of file diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt index 3aa6a112..465b9a24 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt @@ -7,10 +7,10 @@ import ee.carlrobert.codegpt.settings.service.ServiceType import ee.carlrobert.llm.completion.CompletionRequest interface CompletionRequestFactory { - fun createChatRequest(callParameters: CallParameters): CompletionRequest - fun createEditCodeRequest(input: String): CompletionRequest - fun createCommitMessageRequest(systemPrompt: String, gitDiff: String): CompletionRequest - fun createLookupRequest(prompt: String): CompletionRequest + fun createChatRequest(params: ChatCompletionRequestParameters): CompletionRequest + fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest + fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest + fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest companion object { @JvmStatic @@ -30,24 +30,23 @@ interface CompletionRequestFactory { } abstract class BaseRequestFactory : CompletionRequestFactory { - override fun createEditCodeRequest(input: String): CompletionRequest { - return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, input, true) + override fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest { + val prompt = "${params.prompt}\n\n${params.selectedText}" + return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, prompt, 8192, true) } - override fun createCommitMessageRequest( - systemPrompt: String, - gitDiff: String - ): CompletionRequest { - return createBasicCompletionRequest(systemPrompt, gitDiff, true) + override fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest { + return createBasicCompletionRequest(params.systemPrompt, params.gitDiff, 512, true) } - override fun createLookupRequest(prompt: String): CompletionRequest { - return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt) + override fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest { + return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, params.prompt, 512) } abstract fun createBasicCompletionRequest( systemPrompt: String, userPrompt: String, + maxTokens: Int = 4096, stream: Boolean = false ): CompletionRequest } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt index 201ad51a..532d43a4 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt @@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest @@ -10,10 +10,10 @@ import ee.carlrobert.llm.completion.CompletionRequest class AzureRequestFactory : BaseRequestFactory() { - override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { + override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest { val configuration = service().state val requestBuilder: OpenAIChatCompletionRequest.Builder = - OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(null, callParameters)) + OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(null, params.callParameters)) .setMaxTokens(configuration.maxTokens) .setStream(true) .setTemperature(configuration.temperature.toDouble()) @@ -23,6 +23,7 @@ class AzureRequestFactory : BaseRequestFactory() { override fun createBasicCompletionRequest( systemPrompt: String, userPrompt: String, + maxTokens: Int, stream: Boolean ): CompletionRequest { return OpenAIRequestFactory.createBasicCompletionRequest( diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt index c37bedb4..ca3a4307 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt @@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.codegpt.settings.persona.PersonaSettings import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings @@ -11,7 +11,8 @@ import ee.carlrobert.llm.completion.CompletionRequest class ClaudeRequestFactory : BaseRequestFactory() { - override fun createChatRequest(callParameters: CallParameters): ClaudeCompletionRequest { + override fun createChatRequest(params: ChatCompletionRequestParameters): ClaudeCompletionRequest { + val (callParameters) = params return ClaudeCompletionRequest().apply { model = service().state.model maxTokens = service().state.maxTokens @@ -57,15 +58,16 @@ class ClaudeRequestFactory : BaseRequestFactory() { override fun createBasicCompletionRequest( systemPrompt: String, userPrompt: String, + maxTokens: Int, stream: Boolean ): CompletionRequest { return ClaudeCompletionRequest().apply { system = systemPrompt isStream = stream - maxTokens = service().state.maxTokens model = service().state.model messages = listOf(ClaudeCompletionStandardMessage("user", userPrompt)) + this.maxTokens = maxTokens } } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt index d53efa56..bb34586d 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt @@ -2,7 +2,8 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters +import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildBasicO1Request import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings @@ -11,15 +12,26 @@ import ee.carlrobert.llm.client.openai.completion.request.RequestDocumentationDe class CodeGPTRequestFactory : BaseRequestFactory() { - override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { + override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest { + val (callParameters) = params val model = service().state.chatCompletionSettings.model val configuration = service().state val requestBuilder: OpenAIChatCompletionRequest.Builder = OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters)) .setModel(model) - .setMaxTokens(configuration.maxTokens) + if ("o1-mini" == model || "o1-preview" == model) { + requestBuilder + .setMaxCompletionTokens(configuration.maxTokens) + .setStream(false) + .setMaxTokens(null) + .setTemperature(null) + } else { + requestBuilder .setStream(true) + .setMaxTokens(configuration.maxTokens) .setTemperature(configuration.temperature.toDouble()) + } + if (callParameters.message.isWebSearchIncluded) { requestBuilder.setWebSearchIncluded(true) } @@ -36,12 +48,17 @@ class CodeGPTRequestFactory : BaseRequestFactory() { override fun createBasicCompletionRequest( systemPrompt: String, userPrompt: String, + maxTokens: Int, stream: Boolean ): OpenAIChatCompletionRequest { + val model = service().state.chatCompletionSettings.model + if (model == "o1-mini" || model == "o1-preview") { + return buildBasicO1Request(model, userPrompt, systemPrompt, maxTokens) + } return OpenAIRequestFactory.createBasicCompletionRequest( systemPrompt, userPrompt, - service().state.chatCompletionSettings.model, + model, stream ) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt index a118572e..1532d449 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt @@ -3,7 +3,7 @@ package ee.carlrobert.codegpt.completions.factory import com.fasterxml.jackson.databind.ObjectMapper import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState @@ -19,7 +19,8 @@ class CustomOpenAIRequest(val request: Request) : CompletionRequest class CustomOpenAIRequestFactory : BaseRequestFactory() { - override fun createChatRequest(callParameters: CallParameters): CustomOpenAIRequest { + override fun createChatRequest(params: ChatCompletionRequestParameters): CustomOpenAIRequest { + val (callParameters) = params val request = buildCustomOpenAIChatCompletionRequest( service() .state @@ -34,6 +35,7 @@ class CustomOpenAIRequestFactory : BaseRequestFactory() { override fun createBasicCompletionRequest( systemPrompt: String, userPrompt: String, + maxTokens: Int, stream: Boolean ): CompletionRequest { val request = buildCustomOpenAIChatCompletionRequest( diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt index be66b669..5d8245eb 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt @@ -2,11 +2,8 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.EncodingManager -import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.* import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT -import ee.carlrobert.codegpt.completions.ConversationType -import ee.carlrobert.codegpt.completions.TotalUsageExceededException import ee.carlrobert.codegpt.conversations.ConversationsState import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.codegpt.settings.persona.PersonaSettings @@ -23,7 +20,8 @@ import java.nio.file.Path class GoogleRequestFactory : BaseRequestFactory() { - override fun createChatRequest(callParameters: CallParameters): GoogleCompletionRequest { + override fun createChatRequest(params: ChatCompletionRequestParameters): GoogleCompletionRequest { + val (callParameters) = params val configuration = service().state val messages = buildGoogleMessages(service().state.model, callParameters) return GoogleCompletionRequest.Builder(messages) @@ -38,6 +36,7 @@ class GoogleRequestFactory : BaseRequestFactory() { override fun createBasicCompletionRequest( systemPrompt: String, userPrompt: String, + maxTokens: Int, stream: Boolean ): GoogleCompletionRequest { val configuration = service().state @@ -50,7 +49,7 @@ class GoogleRequestFactory : BaseRequestFactory() { ) .generationConfig( GoogleGenerationConfig.Builder() - .maxOutputTokens(configuration.maxTokens) + .maxOutputTokens(maxTokens) .temperature(configuration.temperature.toDouble()).build() ) .build() diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt index 0ed0a2e0..8cc95637 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt @@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT import ee.carlrobert.codegpt.completions.ConversationType import ee.carlrobert.codegpt.completions.llama.LlamaModel @@ -14,7 +14,8 @@ import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest class LlamaRequestFactory : BaseRequestFactory() { - override fun createChatRequest(callParameters: CallParameters): LlamaCompletionRequest { + override fun createChatRequest(params: ChatCompletionRequestParameters): LlamaCompletionRequest { + val (callParameters) = params val promptTemplate = getPromptTemplate() val systemPrompt = if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS) @@ -33,6 +34,7 @@ class LlamaRequestFactory : BaseRequestFactory() { override fun createBasicCompletionRequest( systemPrompt: String, userPrompt: String, + maxTokens: Int, stream: Boolean ): LlamaCompletionRequest { val promptTemplate = getPromptTemplate() diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt index 58942c06..0dd2aa01 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt @@ -3,6 +3,7 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory import ee.carlrobert.codegpt.completions.CallParameters +import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT import ee.carlrobert.codegpt.completions.ConversationType import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings @@ -18,7 +19,8 @@ import java.util.* class OllamaRequestFactory : BaseRequestFactory() { - override fun createChatRequest(callParameters: CallParameters): OllamaChatCompletionRequest { + override fun createChatRequest(params: ChatCompletionRequestParameters): OllamaChatCompletionRequest { + val (callParameters) = params val configuration = service().state val settings = service().state return OllamaChatCompletionRequest.Builder( @@ -38,6 +40,7 @@ class OllamaRequestFactory : BaseRequestFactory() { override fun createBasicCompletionRequest( systemPrompt: String, userPrompt: String, + maxTokens: Int, stream: Boolean ): OllamaChatCompletionRequest { return OllamaChatCompletionRequest.Builder( diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt index f235f179..56598fad 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt @@ -2,13 +2,10 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.EncodingManager -import ee.carlrobert.codegpt.completions.CallParameters -import ee.carlrobert.codegpt.completions.CompletionRequestFactory +import ee.carlrobert.codegpt.completions.* import ee.carlrobert.codegpt.completions.CompletionRequestUtil.EDIT_CODE_SYSTEM_PROMPT import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT import ee.carlrobert.codegpt.completions.CompletionRequestUtil.GENERATE_METHOD_NAMES_SYSTEM_PROMPT -import ee.carlrobert.codegpt.completions.ConversationType -import ee.carlrobert.codegpt.completions.TotalUsageExceededException import ee.carlrobert.codegpt.conversations.ConversationsState import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings.Companion.getState @@ -17,62 +14,93 @@ import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings import ee.carlrobert.codegpt.util.file.FileUtil.getImageMediaType import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel import ee.carlrobert.llm.client.openai.completion.request.* -import ee.carlrobert.llm.completion.CompletionRequest import java.io.IOException import java.nio.file.Files import java.nio.file.Path class OpenAIRequestFactory : CompletionRequestFactory { - override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { + override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest { + val (callParameters) = params val model = service().state.model val configuration = service().state val requestBuilder: OpenAIChatCompletionRequest.Builder = OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters)) .setModel(model) - .setMaxTokens(configuration.maxTokens) + if ("o1-mini" == model || "o1-preview" == model) { + requestBuilder + .setMaxCompletionTokens(configuration.maxTokens) + .setStream(false) + .setMaxTokens(null) + .setTemperature(null) + .setPresencePenalty(null) + .setFrequencyPenalty(null) + } else { + requestBuilder .setStream(true) + .setMaxTokens(configuration.maxTokens) .setTemperature(configuration.temperature.toDouble()) + } return requestBuilder.build() } - override fun createEditCodeRequest(input: String): OpenAIChatCompletionRequest { - return buildEditCodeRequest(input, service().state.model) + override fun createEditCodeRequest(params: EditCodeRequestParameters): OpenAIChatCompletionRequest { + val model = service().state.model + if (model == "o1-mini" || model == "o1-preview") { + return buildBasicO1Request(model, params.prompt, EDIT_CODE_SYSTEM_PROMPT) + } + return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, params.prompt, model, true) } - override fun createCommitMessageRequest( - systemPrompt: String, - gitDiff: String - ): CompletionRequest { - return createBasicCompletionRequest( - systemPrompt, - gitDiff, - service().state.model, - true - ) + override fun createCommitMessageRequest(params: CommitMessageRequestParameters): OpenAIChatCompletionRequest { + val model = service().state.model + val (gitDiff, systemPrompt) = params + if (model == "o1-mini" || model == "o1-preview") { + return buildBasicO1Request(model, gitDiff, systemPrompt) + } + return createBasicCompletionRequest(systemPrompt, gitDiff, model, true) } - override fun createLookupRequest(prompt: String): CompletionRequest { - return createBasicCompletionRequest( - GENERATE_METHOD_NAMES_SYSTEM_PROMPT, - prompt, - service().state.model - ) + override fun createLookupRequest(params: LookupRequestCallParameters): OpenAIChatCompletionRequest { + val model = service().state.model + val (prompt) = params + if (model == "o1-mini" || model == "o1-preview") { + return buildBasicO1Request(model, prompt, GENERATE_METHOD_NAMES_SYSTEM_PROMPT) + } + return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt, model) } companion object { - fun buildEditCodeRequest( - input: String, - model: String? = null + fun buildBasicO1Request( + model: String, + prompt: String, + systemPrompt: String = "", + maxCompletionTokens: Int = 4096 ): OpenAIChatCompletionRequest { - return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, input, model, true) + val messages = if (systemPrompt.isEmpty()) { + listOf(OpenAIChatCompletionStandardMessage("user", prompt)) + } else { + listOf( + OpenAIChatCompletionStandardMessage("user", systemPrompt), + OpenAIChatCompletionStandardMessage("user", prompt) + ) + } + return OpenAIChatCompletionRequest.Builder(messages) + .setModel(model) + .setMaxCompletionTokens(maxCompletionTokens) + .setStream(false) + .setTemperature(null) + .setFrequencyPenalty(null) + .setPresencePenalty(null) + .setMaxTokens(null) + .build() } fun buildOpenAIMessages( model: String?, callParameters: CallParameters ): List { - val messages = buildOpenAIMessages(callParameters) + val messages = buildOpenAIChatMessages(model, callParameters) if (model == null) { return messages @@ -104,21 +132,24 @@ class OpenAIRequestFactory : CompletionRequestFactory { ) } - private fun buildOpenAIMessages( + private fun buildOpenAIChatMessages( + model: String?, callParameters: CallParameters ): MutableList { val message = callParameters.message val messages = mutableListOf() + val role = if ("o1-mini" == model || "o1-preview" == model) "user" else "system" + if (callParameters.conversationType == ConversationType.DEFAULT) { val sessionPersonaDetails = callParameters.message.personaDetails if (callParameters.message.personaDetails == null) { messages.add( - OpenAIChatCompletionStandardMessage("system", getSystemPrompt()) + OpenAIChatCompletionStandardMessage(role, getSystemPrompt()) ) } else { messages.add( OpenAIChatCompletionStandardMessage( - "system", + role, sessionPersonaDetails.instructions ) ) @@ -126,7 +157,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { } if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS) { messages.add( - OpenAIChatCompletionStandardMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT) + OpenAIChatCompletionStandardMessage(role, FIX_COMPILE_ERRORS_SYSTEM_PROMPT) ) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt index bcbc9310..024bbc04 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt @@ -14,49 +14,46 @@ object CodeGPTAvailableModels { fun getToolWindowModels(pricingPlan: PricingPlan?): List { return when (pricingPlan) { null, ANONYMOUS -> listOf( - CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), - CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL), - CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL), - CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL), + CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL), + CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE), + CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE), + CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE), + CodeGPTModel("DeepSeek Coder V2 - FREE", "deepseek-coder-v2", Icons.DeepSeek, ANONYMOUS), CodeGPTModel("GPT-4o mini - FREE", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS), - CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS) ) FREE -> listOf( - CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), - CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL), - CodeGPTModel("GPT-4o mini", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS), - CodeGPTModel("Llama 3 (70B)", "llama-3-70b", Icons.Meta, FREE), - CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.CodeGPTModel, FREE), - CodeGPTModel("Code Llama (70B)", "codellama:chat", Icons.Meta, FREE), + CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL), + CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE), + CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE), + CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE), + CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, ANONYMOUS), + CodeGPTModel("Qwen 2.5 (72B)", "qwen-2.5-72b", Icons.Qwen, FREE), + CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.Mistral, FREE), ) INDIVIDUAL -> listOf( - CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), + CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL), + CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE), + CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE), CodeGPTModel("Claude 3 Opus", "claude-3-opus", Icons.Anthropic, INDIVIDUAL), - CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL), - CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL), - CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL), - CodeGPTModel("DBRX", "dbrx", Icons.Databricks, INDIVIDUAL), + CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE), + CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, FREE), ) } } @JvmStatic val ALL_CHAT_MODELS: List = listOf( - CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), + CodeGPTModel("o1-mini", "o1-mini", Icons.OpenAI, INDIVIDUAL), + CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, FREE), CodeGPTModel("GPT-4o mini", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS), CodeGPTModel("Claude 3 Opus", "claude-3-opus", Icons.Anthropic, INDIVIDUAL), - CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL), - CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL), - CodeGPTModel("Llama 3 (70B)", "llama-3-70b", Icons.Meta, FREE), - CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL), - CodeGPTModel("DBRX", "dbrx", Icons.Databricks, INDIVIDUAL), - CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS), - CodeGPTModel("Code Llama (70B)", "codellama:chat", Icons.Meta, FREE), - CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.CodeGPTModel, FREE), - CodeGPTModel("DeepSeek Coder (33B)", "deepseek-coder-33b", Icons.CodeGPTModel, FREE), - CodeGPTModel("WizardLM-2 (8x22B)", "wizardlm-2-8x22b", Icons.CodeGPTModel, FREE) + CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, FREE), + CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, FREE), + CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, FREE), + CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.Mistral, FREE), + CodeGPTModel("Qwen 2.5 (72B)", "qwen-2.5-72b", Icons.Qwen, FREE), ) @JvmStatic @@ -65,7 +62,6 @@ object CodeGPTAvailableModels { CodeGPTModel("StarCoder (16B)", "starcoder-16b", Icons.CodeGPTModel, FREE), CodeGPTModel("StarCoder (7B) - FREE", "starcoder-7b", Icons.CodeGPTModel, FREE), CodeGPTModel("WizardCoder Python (34B)", "wizardcoder-python", Icons.CodeGPTModel, FREE), - CodeGPTModel("Phind Code LLaMA v2 (34B)", "phind-codellama", Icons.CodeGPTModel, FREE) ) @JvmStatic diff --git a/src/main/resources/icons/mistral.svg b/src/main/resources/icons/mistral.svg new file mode 100644 index 00000000..8b7bdd26 --- /dev/null +++ b/src/main/resources/icons/mistral.svg @@ -0,0 +1,32 @@ + + + Mistral AI + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/main/resources/icons/qwen.png b/src/main/resources/icons/qwen.png new file mode 100644 index 00000000..f4a9fb31 Binary files /dev/null and b/src/main/resources/icons/qwen.png differ diff --git a/src/test/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.java b/src/test/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.java deleted file mode 100644 index e69de29b..00000000 diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt index 5b6a8d94..bb2b6690 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt @@ -24,12 +24,14 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(secondMessage) val request = OpenAIRequestFactory().createChatRequest( - CallParameters( - conversation, - ConversationType.DEFAULT, - Message("TEST_CHAT_COMPLETION_PROMPT"), - null, - false + ChatCompletionRequestParameters( + CallParameters( + conversation, + ConversationType.DEFAULT, + Message("TEST_CHAT_COMPLETION_PROMPT"), + null, + false + ) ) ) @@ -55,12 +57,14 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(secondMessage) val request = OpenAIRequestFactory().createChatRequest( - CallParameters( - conversation, - ConversationType.DEFAULT, - Message("TEST_CHAT_COMPLETION_PROMPT"), - null, - false + ChatCompletionRequestParameters( + CallParameters( + conversation, + ConversationType.DEFAULT, + Message("TEST_CHAT_COMPLETION_PROMPT"), + null, + false + ) ) ) @@ -86,12 +90,14 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(secondMessage) val request = OpenAIRequestFactory().createChatRequest( - CallParameters( - conversation, - ConversationType.DEFAULT, - secondMessage, - null, - true + ChatCompletionRequestParameters( + CallParameters( + conversation, + ConversationType.DEFAULT, + secondMessage, + null, + true + ) ) ) @@ -118,12 +124,14 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.discardTokenLimits() val request = OpenAIRequestFactory().createChatRequest( - CallParameters( - conversation, - ConversationType.DEFAULT, - Message("TEST_CHAT_COMPLETION_PROMPT"), - null, - false + ChatCompletionRequestParameters( + CallParameters( + conversation, + ConversationType.DEFAULT, + Message("TEST_CHAT_COMPLETION_PROMPT"), + null, + false + ) ) ) @@ -146,12 +154,14 @@ class CompletionRequestProviderTest : IntegrationTest() { assertThrows(TotalUsageExceededException::class.java) { OpenAIRequestFactory().createChatRequest( - CallParameters( - conversation, - ConversationType.DEFAULT, - createDummyMessage(100), - null, - false + ChatCompletionRequestParameters( + CallParameters( + conversation, + ConversationType.DEFAULT, + createDummyMessage(100), + null, + false + ) ) ) } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt similarity index 92% rename from src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.kt rename to src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt index 4a9b20ff..87484628 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt @@ -14,14 +14,17 @@ import org.apache.http.HttpHeaders import org.assertj.core.api.Assertions.assertThat import testsupport.IntegrationTest -class DefaultCompletionRequestHandlerTest : IntegrationTest() { +class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() { fun testOpenAIChatCompletionCall() { useOpenAIService() service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" val message = Message("TEST_PROMPT") val conversation = ConversationService.getInstance().startConversation() - val requestHandler = CompletionRequestHandler(getRequestEventListener(message)) + val requestHandler = + ToolwindowChatCompletionRequestHandler( + getRequestEventListener(message) + ) expectOpenAI(StreamHttpExchange { request: RequestEntity -> assertThat(request.uri.path).isEqualTo("/v1/chat/completions") assertThat(request.method).isEqualTo("POST") @@ -77,7 +80,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() { jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) }) val message = Message("TEST_PROMPT") - val requestHandler = CompletionRequestHandler(getRequestEventListener(message)) + val requestHandler = + ToolwindowChatCompletionRequestHandler( + getRequestEventListener(message) + ) requestHandler.call(CallParameters(conversation, message)) @@ -91,7 +97,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() { val message = Message("TEST_PROMPT") val conversation = ConversationService.getInstance().startConversation() conversation.addMessage(Message("Ping", "Pong")) - val requestHandler = CompletionRequestHandler(getRequestEventListener(message)) + val requestHandler = + ToolwindowChatCompletionRequestHandler( + getRequestEventListener(message) + ) expectLlama(StreamHttpExchange { request: RequestEntity -> assertThat(request.uri.path).isEqualTo("/completion") assertThat(request.body) @@ -125,7 +134,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() { service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" val message = Message("TEST_PROMPT") val conversation = ConversationService.getInstance().startConversation() - val requestHandler = CompletionRequestHandler(getRequestEventListener(message)) + val requestHandler = + ToolwindowChatCompletionRequestHandler( + getRequestEventListener(message) + ) expectOllama(NdJsonStreamHttpExchange { request: RequestEntity -> assertThat(request.uri.path).isEqualTo("/api/chat") assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") @@ -171,7 +183,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() { service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" val message = Message("TEST_PROMPT") val conversation = ConversationService.getInstance().startConversation() - val requestHandler = CompletionRequestHandler(getRequestEventListener(message)) + val requestHandler = + ToolwindowChatCompletionRequestHandler( + getRequestEventListener(message) + ) expectGoogle(StreamHttpExchange { request: RequestEntity -> assertThat(request.uri.path).isEqualTo("/v1/models/gemini-pro:streamGenerateContent") assertThat(request.method).isEqualTo("POST") @@ -207,7 +222,10 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() { service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" val message = Message("TEST_PROMPT") val conversation = ConversationService.getInstance().startConversation() - val requestHandler = CompletionRequestHandler(getRequestEventListener(message)) + val requestHandler = + ToolwindowChatCompletionRequestHandler( + getRequestEventListener(message) + ) expectCodeGPT(StreamHttpExchange { request: RequestEntity -> assertThat(request.uri.path).isEqualTo("/v1/chat/completions") assertThat(request.method).isEqualTo("POST")