From 5ad9bcfaff64ce7e54bda0b855d91cee4ef78a02 Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Thu, 17 Oct 2024 02:24:57 +0300 Subject: [PATCH] refactor: improve chat completion call handling --- .../GenerateGitCommitMessageAction.java | 4 +- .../codegpt/completions/CallParameters.java | 93 ---- .../ChatCompletionEventListener.java | 4 +- .../completions/CompletionRequestService.java | 6 +- .../CompletionResponseEventListener.java | 2 +- .../completions/MethodNameLookupListener.java | 2 +- ...oolwindowChatCompletionRequestHandler.java | 8 +- .../conversations/ConversationService.java | 6 +- .../chat/ChatToolWindowTabPanel.java | 174 +++--- ...WindowCompletionResponseEventListener.java | 16 +- .../chat/ui/ChatMessageResponseBody.java | 34 +- .../toolwindow/chat/ui/UserMessagePanel.java | 4 + .../editor/EditCodeSubmissionHandler.kt | 4 +- .../completions/CompletionCallParameters.kt | 19 - .../completions/CompletionParameters.kt | 87 +++ .../completions/CompletionRequestFactory.kt | 16 +- .../factory/AzureRequestFactory.kt | 7 +- .../factory/ClaudeRequestFactory.kt | 19 +- .../factory/CodeGPTRequestFactory.kt | 17 +- .../factory/CustomOpenAIRequestFactory.kt | 11 +- .../factory/GoogleRequestFactory.kt | 29 +- .../factory/LlamaRequestFactory.kt | 11 +- .../factory/OllamaRequestFactory.kt | 24 +- .../factory/OpenAIRequestFactory.kt | 19 +- .../CompletionRequestProviderTest.kt | 49 +- ...lwindowChatCompletionRequestHandlerTest.kt | 507 +++++++++--------- .../chat/ChatToolWindowTabPanelTest.kt | 6 +- 27 files changed, 568 insertions(+), 610 deletions(-) delete mode 100644 src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java delete mode 100644 src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionCallParameters.kt create mode 100644 src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionParameters.kt diff --git a/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java b/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java index 6a18b7cd..091691a5 100644 --- a/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java +++ b/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java @@ -28,7 +28,7 @@ import com.intellij.openapi.vfs.VirtualFile; import ee.carlrobert.codegpt.CodeGPTBundle; import ee.carlrobert.codegpt.EncodingManager; import ee.carlrobert.codegpt.Icons; -import ee.carlrobert.codegpt.completions.CommitMessageRequestParameters; +import ee.carlrobert.codegpt.completions.CommitMessageCompletionParameters; import ee.carlrobert.codegpt.completions.CompletionRequestService; import ee.carlrobert.codegpt.settings.configuration.CommitMessageTemplate; import ee.carlrobert.codegpt.ui.OverlayUtil; @@ -96,7 +96,7 @@ public class GenerateGitCommitMessageAction extends AnAction { if (editor != null) { ((EditorEx) editor).setCaretVisible(false); CompletionRequestService.getInstance().getCommitMessageAsync( - new CommitMessageRequestParameters( + new CommitMessageCompletionParameters( gitDiff, project.getService(CommitMessageTemplate.class).getSystemPrompt()), getEventListener(project, editor.getDocument())); diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java b/src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java deleted file mode 100644 index 791c7404..00000000 --- a/src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java +++ /dev/null @@ -1,93 +0,0 @@ -package ee.carlrobert.codegpt.completions; - -import ee.carlrobert.codegpt.ReferencedFile; -import ee.carlrobert.codegpt.conversations.Conversation; -import ee.carlrobert.codegpt.conversations.message.Message; -import java.util.List; -import java.util.UUID; -import org.jetbrains.annotations.Nullable; - -public class CallParameters { - - private final UUID sessionId; - private final Conversation conversation; - private final ConversationType conversationType; - private final Message message; - private final boolean retry; - private final String highlightedText; - private String imageMediaType; - private byte[] imageData; - private List referencedFiles; - - public CallParameters(Conversation conversation, Message message) { - this(null, conversation, message); - } - - public CallParameters(UUID sessionId, Conversation conversation, Message message) { - this(sessionId, conversation, ConversationType.DEFAULT, message, null, false); - } - - // TODO: Builder - public CallParameters( - UUID sessionId, - Conversation conversation, - ConversationType conversationType, - Message message, - @Nullable String highlightedText, - boolean retry) { - this.sessionId = sessionId; - this.conversation = conversation; - this.conversationType = conversationType; - this.message = message; - this.highlightedText = highlightedText; - this.retry = retry; - } - - public UUID getSessionId() { - return sessionId; - } - - public Conversation getConversation() { - return conversation; - } - - public ConversationType getConversationType() { - return conversationType; - } - - public Message getMessage() { - return message; - } - - public boolean isRetry() { - return retry; - } - - public @Nullable String getImageMediaType() { - return imageMediaType; - } - - public void setImageMediaType(@Nullable String imageMediaType) { - this.imageMediaType = imageMediaType; - } - - public byte[] getImageData() { - return imageData; - } - - public void setImageData(byte[] imageData) { - this.imageData = imageData; - } - - public @Nullable String getHighlightedText() { - return highlightedText; - } - - public @Nullable List getReferencedFiles() { - return referencedFiles; - } - - public void setReferencedFiles(List referencedFiles) { - this.referencedFiles = referencedFiles; - } -} diff --git a/src/main/java/ee/carlrobert/codegpt/completions/ChatCompletionEventListener.java b/src/main/java/ee/carlrobert/codegpt/completions/ChatCompletionEventListener.java index 65725f87..f515b677 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/ChatCompletionEventListener.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/ChatCompletionEventListener.java @@ -10,12 +10,12 @@ import okhttp3.sse.EventSource; public class ChatCompletionEventListener implements CompletionEventListener { - private final CallParameters callParameters; + private final ChatCompletionParameters callParameters; private final CompletionResponseEventListener eventListener; private final StringBuilder messageBuilder = new StringBuilder(); public ChatCompletionEventListener( - CallParameters callParameters, + ChatCompletionParameters callParameters, CompletionResponseEventListener eventListener) { this.callParameters = callParameters; this.eventListener = eventListener; diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index 63bfe386..3be0329e 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -69,7 +69,7 @@ public final class CompletionRequestService { new OpenAIChatCompletionEventSourceListener(eventListener)); } - public String getLookupCompletion(LookupRequestCallParameters params) { + public String getLookupCompletion(LookupCompletionParameters params) { var request = CompletionRequestFactory .getFactory(GeneralSettings.getSelectedService()) .createLookupRequest(params); @@ -77,7 +77,7 @@ public final class CompletionRequestService { } public EventSource getCommitMessageAsync( - CommitMessageRequestParameters params, + CommitMessageCompletionParameters params, CompletionEventListener eventListener) { var request = CompletionRequestFactory .getFactory(GeneralSettings.getSelectedService()) @@ -86,7 +86,7 @@ public final class CompletionRequestService { } public EventSource getEditCodeCompletionAsync( - EditCodeRequestParameters params, + EditCodeCompletionParameters params, CompletionEventListener eventListener) { var request = CompletionRequestFactory .getFactory(GeneralSettings.getSelectedService()) diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionResponseEventListener.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionResponseEventListener.java index b0538b59..6144f868 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionResponseEventListener.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionResponseEventListener.java @@ -19,7 +19,7 @@ public interface CompletionResponseEventListener { default void handleCompleted(String fullMessage) { } - default void handleCompleted(String fullMessage, CallParameters callParameters) { + default void handleCompleted(String fullMessage, ChatCompletionParameters callParameters) { } default void handleCodeGPTEvent(CodeGPTEvent event) { diff --git a/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java b/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java index 309bb1c4..18200139 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/MethodNameLookupListener.java @@ -57,7 +57,7 @@ public class MethodNameLookupListener implements LookupManagerListener { String prompt) { try { var response = CompletionRequestService.getInstance() - .getLookupCompletion(new LookupRequestCallParameters(prompt)); + .getLookupCompletion(new LookupCompletionParameters(prompt)); if (!response.isEmpty()) { for (var value : response.split(",")) { application.invokeLater(() -> application.runReadAction(() -> { diff --git a/src/main/java/ee/carlrobert/codegpt/completions/ToolwindowChatCompletionRequestHandler.java b/src/main/java/ee/carlrobert/codegpt/completions/ToolwindowChatCompletionRequestHandler.java index d0c125bf..4505bbaa 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/ToolwindowChatCompletionRequestHandler.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/ToolwindowChatCompletionRequestHandler.java @@ -15,7 +15,7 @@ public class ToolwindowChatCompletionRequestHandler { this.completionResponseEventListener = completionResponseEventListener; } - public void call(CallParameters callParameters) { + public void call(ChatCompletionParameters callParameters) { try { eventSource = startCall(callParameters); } catch (TotalUsageExceededException e) { @@ -33,11 +33,11 @@ public class ToolwindowChatCompletionRequestHandler { } } - private EventSource startCall(CallParameters callParameters) { + private EventSource startCall(ChatCompletionParameters callParameters) { try { var request = CompletionRequestFactory .getFactory(GeneralSettings.getSelectedService()) - .createChatRequest(new ChatCompletionRequestParameters(callParameters)); + .createChatRequest(callParameters); return CompletionRequestService.getInstance().getChatCompletionAsync( request, new ChatCompletionEventListener(callParameters, completionResponseEventListener)); @@ -57,7 +57,7 @@ public class ToolwindowChatCompletionRequestHandler { completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex); } - private void sendInfo(CallParameters callParameters) { + private void sendInfo(ChatCompletionParameters callParameters) { TelemetryAction.COMPLETION.createActionMessage() .property("conversationId", callParameters.getConversation().getId().toString()) .property("model", callParameters.getConversation().getModel()) diff --git a/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java b/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java index 603c9776..ea4a2686 100644 --- a/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java +++ b/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java @@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.conversations; import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.Service; -import ee.carlrobert.codegpt.completions.CallParameters; +import ee.carlrobert.codegpt.completions.ChatCompletionParameters; import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.settings.GeneralSettings; import ee.carlrobert.codegpt.settings.service.ServiceType; @@ -63,11 +63,11 @@ public final class ConversationService { conversationsMapping.put(conversation.getClientCode(), conversations); } - public void saveMessage(String response, CallParameters callParameters) { + public void saveMessage(String response, ChatCompletionParameters callParameters) { var conversation = callParameters.getConversation(); var message = callParameters.getMessage(); var conversationMessages = conversation.getMessages(); - if (callParameters.isRetry() && !conversationMessages.isEmpty()) { + if (callParameters.getRetry() && !conversationMessages.isEmpty()) { var messageToBeSaved = conversationMessages.stream() .filter(item -> item.getId().equals(message.getId())) .findFirst().orElseThrow(); diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java index 4ef67875..eb4ffbea 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java @@ -2,12 +2,10 @@ package ee.carlrobert.codegpt.toolwindow.chat; import static ee.carlrobert.codegpt.ui.UIUtil.createScrollPaneWithSmartScroller; import static java.lang.String.format; -import static java.util.Collections.emptyList; import com.intellij.openapi.Disposable; import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.diagnostic.Logger; -import com.intellij.openapi.editor.Editor; import com.intellij.openapi.editor.SelectionModel; import com.intellij.openapi.editor.ex.EditorEx; import com.intellij.openapi.editor.impl.EditorImpl; @@ -17,7 +15,7 @@ import com.intellij.util.ui.JBUI; import ee.carlrobert.codegpt.CodeGPTKeys; import ee.carlrobert.codegpt.ReferencedFile; import ee.carlrobert.codegpt.actions.ActionType; -import ee.carlrobert.codegpt.completions.CallParameters; +import ee.carlrobert.codegpt.completions.ChatCompletionParameters; import ee.carlrobert.codegpt.completions.CompletionRequestService; import ee.carlrobert.codegpt.completions.ConversationType; import ee.carlrobert.codegpt.completions.ToolwindowChatCompletionRequestHandler; @@ -43,7 +41,6 @@ import java.io.File; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; -import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.UUID; @@ -121,8 +118,28 @@ public class ChatToolWindowTabPanel implements Disposable { totalTokensPanel.updateConversationTokens(conversation); } - public void sendMessage(Message message) { - sendMessage(message, ConversationType.DEFAULT); + public List getReferencedFiles() { + List referencedFiles = project.getUserData(CodeGPTKeys.SELECTED_FILES); + if (referencedFiles == null) { + return conversation.getMessages().stream() + .flatMap(prevMessage -> { + if (prevMessage.getReferencedFilePaths() != null) { + return prevMessage.getReferencedFilePaths().stream(); + } + return Stream.empty(); + }) + .map(filePath -> { + try { + return new ReferencedFile(new File(filePath)); + } catch (Exception e) { + return null; + } + }) + .filter(Objects::nonNull) + .toList(); + } + + return referencedFiles; } public void sendMessage(Message message, ConversationType conversationType) { @@ -134,79 +151,65 @@ public class ChatToolWindowTabPanel implements Disposable { ConversationType conversationType, @Nullable String highlightedText) { ApplicationManager.getApplication().invokeLater(() -> { - var referencedFiles = project.getUserData(CodeGPTKeys.SELECTED_FILES); - var chatToolWindowPanel = project.getService(ChatToolWindowContentManager.class) - .tryFindChatToolWindowPanel(); - if (referencedFiles != null && !referencedFiles.isEmpty()) { - var referencedFilePaths = referencedFiles.stream() + List referencedFiles = getReferencedFiles(); + if (!referencedFiles.isEmpty()) { + message.setReferencedFilePaths(referencedFiles.stream() .map(ReferencedFile::getFilePath) - .toList(); - message.setReferencedFilePaths(referencedFilePaths); + .toList()); message.setUserMessage(message.getPrompt()); - - chatToolWindowPanel.ifPresent(panel -> panel.clearNotifications(project)); - } else { - referencedFiles = conversation.getMessages().stream() - .flatMap(prevMessage -> { - if (prevMessage.getReferencedFilePaths() != null) { - return prevMessage.getReferencedFilePaths().stream(); - } - return Stream.empty(); - }) - .map(filePath -> { - try { - return new ReferencedFile(new File(filePath)); - } catch (Exception e) { - return null; - } - }) - .filter(Objects::nonNull) - .toList(); } + + String attachedImagePath = CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH.get(project); + if (attachedImagePath != null) { + message.setImageFilePath(attachedImagePath); + } + totalTokensPanel.updateConversationTokens(conversation); totalTokensPanel.updateReferencedFilesTokens(referencedFiles); - var userMessagePanel = new UserMessagePanel(project, message, this); - var attachedFilePath = CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH.get(project); - var callParameters = - getCallParameters(conversationType, message, highlightedText, attachedFilePath); - callParameters.setReferencedFiles(referencedFiles); - if (callParameters.getImageData() != null) { - message.setImageFilePath(attachedFilePath); - chatToolWindowPanel.ifPresent(panel -> panel.clearNotifications(project)); - userMessagePanel.displayImage(attachedFilePath); + if (attachedImagePath != null || !referencedFiles.isEmpty()) { + project.getService(ChatToolWindowContentManager.class) + .tryFindChatToolWindowPanel() + .ifPresent(panel -> panel.clearNotifications(project)); } + var callParameters = getCallParameters( + message, + conversationType, + referencedFiles, + highlightedText, + attachedImagePath); + var responsePanel = createResponsePanel(callParameters); var messagePanel = toolWindowScrollablePanel.addMessage(message.getId()); - messagePanel.add(userMessagePanel); - - var responsePanel = createResponsePanel(callParameters, conversationType); + messagePanel.add(new UserMessagePanel(project, message, this)); messagePanel.add(responsePanel); + call(callParameters, responsePanel); }); } - private CallParameters getCallParameters( - ConversationType conversationType, + private ChatCompletionParameters getCallParameters( Message message, + ConversationType conversationType, + List referencedFiles, @Nullable String highlightedText, - @Nullable String attachedFilePath) { - var callParameters = new CallParameters( - chatSession.getId(), - conversation, - conversationType, - message, - highlightedText, - false); - if (attachedFilePath != null && !attachedFilePath.isEmpty()) { + @Nullable String attachedImagePath) { + var builder = ChatCompletionParameters.builder(conversation, message) + .sessionId(chatSession.getId()) + .conversationType(conversationType) + .highlightedText(highlightedText) + .referencedFiles(referencedFiles); + + if (attachedImagePath != null && !attachedImagePath.isEmpty()) { try { - callParameters.setImageData(Files.readAllBytes(Path.of(attachedFilePath))); - callParameters.setImageMediaType(FileUtil.getImageMediaType(attachedFilePath)); + builder + .imageData(Files.readAllBytes(Path.of(attachedImagePath))) + .imageMediaType(FileUtil.getImageMediaType(attachedImagePath)); } catch (IOException e) { throw new RuntimeException(e); } } - return callParameters; + return builder.build(); } private boolean hasReferencedFilePaths(Message message) { @@ -219,15 +222,13 @@ public class ChatToolWindowTabPanel implements Disposable { it -> it.getReferencedFilePaths() != null && !it.getReferencedFilePaths().isEmpty()); } - private ResponsePanel createResponsePanel( - CallParameters callParameters, - ConversationType conversationType) { + private ResponsePanel createResponsePanel(ChatCompletionParameters callParameters) { var message = callParameters.getMessage(); var fileContextIncluded = hasReferencedFilePaths(message) || hasReferencedFilePaths(conversation); return new ResponsePanel() - .withReloadAction(() -> reloadMessage(message, conversation, conversationType)) + .withReloadAction(() -> reloadMessage(callParameters)) .withDeleteAction(() -> removeMessage(message.getId(), conversation)) .addContent( new ChatMessageResponseBody( @@ -241,31 +242,22 @@ public class ChatToolWindowTabPanel implements Disposable { this)); } - private void reloadMessage( - Message message, - Conversation conversation, - ConversationType conversationType) { + private void reloadMessage(ChatCompletionParameters prevParameters) { + var prevMessage = prevParameters.getMessage(); ResponsePanel responsePanel = null; try { - responsePanel = toolWindowScrollablePanel.getMessageResponsePanel(message.getId()); + responsePanel = toolWindowScrollablePanel.getMessageResponsePanel(prevMessage.getId()); ((ChatMessageResponseBody) responsePanel.getContent()).clear(); toolWindowScrollablePanel.update(); } catch (Exception e) { throw new RuntimeException("Could not delete the existing message component", e); } finally { - LOG.debug("Reloading message: " + message.getId()); + LOG.debug("Reloading message: " + prevMessage.getId()); if (responsePanel != null) { - message.setResponse(""); - conversationService.saveMessage(conversation, message); - call(new CallParameters( - chatSession.getId(), - conversation, - conversationType, - message, - null, - true), - responsePanel); + prevMessage.setResponse(""); + conversationService.saveMessage(conversation, prevMessage); + call(prevParameters.toBuilder().retry(true).build(), responsePanel); } totalTokensPanel.updateConversationTokens(conversation); @@ -292,7 +284,7 @@ public class ChatToolWindowTabPanel implements Disposable { totalTokensPanel.updateConversationTokens(conversation); } - private void call(CallParameters callParameters, ResponsePanel responsePanel) { + private void call(ChatCompletionParameters callParameters, ResponsePanel responsePanel) { var responseContainer = (ChatMessageResponseBody) responsePanel.getContent(); if (!CompletionRequestService.getInstance().isAllowed()) { @@ -316,25 +308,6 @@ public class ChatToolWindowTabPanel implements Disposable { requestHandler.call(callParameters); } - private String processEditorSelection(Editor editor, Message message) { - if (editor == null) { - return null; - } - - SelectionModel selectionModel = editor.getSelectionModel(); - String selectedText = selectionModel.getSelectedText(); - if (selectedText == null || selectedText.isEmpty()) { - return null; - } - - String fileExtension = FileUtil.getFileExtension( - ((EditorEx) editor).getVirtualFile().getName()); - message.setPrompt( - message.getPrompt() + String.format("%n```%s%n%s%n```", fileExtension, selectedText)); - selectionModel.removeSelection(); - return selectedText; - } - private Unit handleSubmit(String text, List appliedInlayActions) { var message = new Message(text); var editor = EditorUtil.getSelectedEditor(project); @@ -430,7 +403,10 @@ public class ChatToolWindowTabPanel implements Disposable { var messagePanel = toolWindowScrollablePanel.addMessage(message.getId()); messagePanel.add(userMessagePanel); messagePanel.add(new ResponsePanel() - .withReloadAction(() -> reloadMessage(message, conversation, ConversationType.DEFAULT)) + .withReloadAction(() -> reloadMessage( + ChatCompletionParameters.builder(conversation, message) + .conversationType(ConversationType.DEFAULT) + .build())) .withDeleteAction(() -> removeMessage(message.getId(), conversation)) .addContent(messageResponseBody)); }); diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java index 72e2f1e9..642c465a 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ToolWindowCompletionResponseEventListener.java @@ -5,7 +5,7 @@ import static com.intellij.openapi.ui.Messages.OK; import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.diagnostic.Logger; import ee.carlrobert.codegpt.EncodingManager; -import ee.carlrobert.codegpt.completions.CallParameters; +import ee.carlrobert.codegpt.completions.ChatCompletionParameters; import ee.carlrobert.codegpt.completions.CompletionResponseEventListener; import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.ConversationService; @@ -53,16 +53,10 @@ abstract class ToolWindowCompletionResponseEventListener implements @Override public void handleMessage(String partialMessage) { try { - responseContainer.update(partialMessage); messageBuilder.append(partialMessage); - - if (!completed) { - var ongoingTokens = encodingManager.countTokens(messageBuilder.toString()); - ApplicationManager.getApplication().invokeLater(() -> { - totalTokensPanel.update( - totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens); - }); - } + var ongoingTokens = encodingManager.countTokens(messageBuilder.toString()); + responseContainer.updateMessage(partialMessage); + totalTokensPanel.update(totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens); } catch (Exception e) { responseContainer.displayError("Something went wrong."); throw new RuntimeException("Error while updating the content", e); @@ -105,7 +99,7 @@ abstract class ToolWindowCompletionResponseEventListener implements } @Override - public void handleCompleted(String fullMessage, CallParameters callParameters) { + public void handleCompleted(String fullMessage, ChatCompletionParameters callParameters) { conversationService.saveMessage(fullMessage, callParameters); ApplicationManager.getApplication().invokeLater(() -> { diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java index 5aa39c0c..098356fc 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java @@ -8,6 +8,7 @@ import static javax.swing.event.HyperlinkEvent.EventType.ACTIVATED; import com.intellij.icons.AllIcons.General; import com.intellij.openapi.Disposable; import com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.diagnostic.Logger; import com.intellij.openapi.fileEditor.FileEditorManager; import com.intellij.openapi.options.ShowSettingsUtil; import com.intellij.openapi.project.Project; @@ -53,6 +54,8 @@ import org.jetbrains.annotations.Nullable; public class ChatMessageResponseBody extends JPanel { + private static final Logger LOG = Logger.getInstance(ChatMessageResponseBody.class); + private final Project project; private final Disposable parentDisposable; private final StreamParser streamParser; @@ -123,14 +126,17 @@ public class ChatMessageResponseBody extends JPanel { } public ChatMessageResponseBody withResponse(String response) { - for (var message : MarkdownUtil.splitCodeBlocks(response)) { - processResponse(message, message.startsWith("```"), false); + try { + for (var message : MarkdownUtil.splitCodeBlocks(response)) { + processResponse(message, message.startsWith("```"), false); + } + } catch (Exception e) { + LOG.error("Something went wrong while processing input", e); } - return this; } - public void update(String partialMessage) { + public void updateMessage(String partialMessage) { for (var item : streamParser.parse(partialMessage)) { processResponse(item.response(), CODE.equals(item.type()), true); } @@ -261,22 +267,24 @@ public class ChatMessageResponseBody extends JPanel { var codeBlock = ((FencedCodeBlock) child); var code = codeBlock.getContentChars().unescape(); if (!code.isEmpty()) { - if (currentlyProcessedEditorPanel == null) { - ApplicationManager.getApplication().invokeAndWait(() -> { + ApplicationManager.getApplication().invokeLater(() -> { + if (currentlyProcessedEditorPanel == null) { prepareProcessingCode(code, codeBlock.getInfo().unescape()); - }); - } - EditorUtil.updateEditorDocument(currentlyProcessedEditorPanel.getEditor(), code); + } + EditorUtil.updateEditorDocument(currentlyProcessedEditorPanel.getEditor(), code); + }); } } } private void processText(String markdownText, boolean caretVisible) { var html = convertMdToHtml(markdownText); - if (currentlyProcessedTextPane == null) { - prepareProcessingText(caretVisible); - } - currentlyProcessedTextPane.setText(html); + ApplicationManager.getApplication().invokeLater(() -> { + if (currentlyProcessedTextPane == null) { + prepareProcessingText(caretVisible); + } + currentlyProcessedTextPane.setText(html); + }); } private void prepareProcessingText(boolean caretVisible) { diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/UserMessagePanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/UserMessagePanel.java index ec9ccb52..c7a7775d 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/UserMessagePanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/UserMessagePanel.java @@ -43,6 +43,10 @@ public class UserMessagePanel extends JPanel { add(additionalContextPanel, BorderLayout.CENTER); } + if (message.getImageFilePath() != null && !message.getImageFilePath().isEmpty()) { + displayImage(message.getImageFilePath()); + } + var referencedFilePaths = message.getReferencedFilePaths(); if (referencedFilePaths != null && !referencedFilePaths.isEmpty()) { add(createResponseBody( diff --git a/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt b/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt index ec5c99e1..cf133abf 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt @@ -10,7 +10,7 @@ import com.intellij.util.ui.AsyncProcessIcon import com.intellij.openapi.util.text.StringUtil import com.jetbrains.rd.util.AtomicReference import ee.carlrobert.codegpt.completions.CompletionRequestService -import ee.carlrobert.codegpt.completions.EditCodeRequestParameters +import ee.carlrobert.codegpt.completions.EditCodeCompletionParameters import ee.carlrobert.codegpt.ui.ObservableProperties import javax.swing.JButton @@ -43,7 +43,7 @@ class EditCodeSubmissionHandler( runInEdt { editor.selectionModel.removeSelection() } service().getEditCodeCompletionAsync( - EditCodeRequestParameters(userPrompt, selectedText), + EditCodeCompletionParameters(userPrompt, selectedText), EditCodeCompletionListener( editor, selectionTextRange, diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionCallParameters.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionCallParameters.kt deleted file mode 100644 index 53736ad3..00000000 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionCallParameters.kt +++ /dev/null @@ -1,19 +0,0 @@ -package ee.carlrobert.codegpt.completions - -interface CompletionCallParameters - -data class ChatCompletionRequestParameters( - val callParameters: CallParameters -) : CompletionCallParameters - -data class CommitMessageRequestParameters( - val gitDiff: String, - val systemPrompt: String -) : CompletionCallParameters - -data class LookupRequestCallParameters(val prompt: String) : CompletionCallParameters - -data class EditCodeRequestParameters( - val prompt: String, - val selectedText: String -) : CompletionCallParameters \ No newline at end of file diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionParameters.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionParameters.kt new file mode 100644 index 00000000..8f43a232 --- /dev/null +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionParameters.kt @@ -0,0 +1,87 @@ +package ee.carlrobert.codegpt.completions + +import ee.carlrobert.codegpt.ReferencedFile +import ee.carlrobert.codegpt.conversations.Conversation +import ee.carlrobert.codegpt.conversations.message.Message +import java.util.* + +interface CompletionParameters + +class ChatCompletionParameters private constructor( + val conversation: Conversation, + val conversationType: ConversationType, + val message: Message, + var sessionId: UUID?, + var highlightedText: String?, + var retry: Boolean, + var imageMediaType: String?, + var imageData: ByteArray?, + var referencedFiles: List? +) : CompletionParameters { + + fun toBuilder(): Builder { + return Builder(conversation, message).apply { + sessionId(this@ChatCompletionParameters.sessionId) + conversationType(this@ChatCompletionParameters.conversationType) + highlightedText(this@ChatCompletionParameters.highlightedText) + retry(this@ChatCompletionParameters.retry) + imageMediaType(this@ChatCompletionParameters.imageMediaType) + imageData(this@ChatCompletionParameters.imageData) + referencedFiles(this@ChatCompletionParameters.referencedFiles) + } + } + + class Builder(private val conversation: Conversation, private val message: Message) { + private var sessionId: UUID? = null + private var conversationType: ConversationType = ConversationType.DEFAULT + private var highlightedText: String? = null + private var retry: Boolean = false + private var imageMediaType: String? = null + private var imageData: ByteArray? = null + private var referencedFiles: List? = null + + fun sessionId(sessionId: UUID?) = apply { this.sessionId = sessionId } + fun conversationType(conversationType: ConversationType) = + apply { this.conversationType = conversationType } + + fun highlightedText(highlightedText: String?) = + apply { this.highlightedText = highlightedText } + + fun retry(retry: Boolean) = apply { this.retry = retry } + fun imageMediaType(imageMediaType: String?) = apply { this.imageMediaType = imageMediaType } + fun imageData(imageData: ByteArray?) = apply { this.imageData = imageData } + fun referencedFiles(referencedFiles: List?) = + apply { this.referencedFiles = referencedFiles } + + fun build(): ChatCompletionParameters { + return ChatCompletionParameters( + conversation, + conversationType, + message, + sessionId, + highlightedText, + retry, + imageMediaType, + imageData, + referencedFiles + ) + } + } + + companion object { + @JvmStatic + fun builder(conversation: Conversation, message: Message) = Builder(conversation, message) + } +} + +data class CommitMessageCompletionParameters( + val gitDiff: String, + val systemPrompt: String +) : CompletionParameters + +data class LookupCompletionParameters(val prompt: String) : CompletionParameters + +data class EditCodeCompletionParameters( + val prompt: String, + val selectedText: String +) : CompletionParameters \ No newline at end of file diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt index b7b01286..ed0bf3cf 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt @@ -7,10 +7,10 @@ import ee.carlrobert.codegpt.settings.service.ServiceType import ee.carlrobert.llm.completion.CompletionRequest interface CompletionRequestFactory { - fun createChatRequest(params: ChatCompletionRequestParameters): CompletionRequest - fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest - fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest - fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest + fun createChatRequest(params: ChatCompletionParameters): CompletionRequest + fun createEditCodeRequest(params: EditCodeCompletionParameters): CompletionRequest + fun createCommitMessageRequest(params: CommitMessageCompletionParameters): CompletionRequest + fun createLookupRequest(params: LookupCompletionParameters): CompletionRequest companion object { @JvmStatic @@ -30,16 +30,16 @@ interface CompletionRequestFactory { } abstract class BaseRequestFactory : CompletionRequestFactory { - override fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest { + override fun createEditCodeRequest(params: EditCodeCompletionParameters): CompletionRequest { val prompt = "Code to modify:\n${params.selectedText}\n\nInstructions: ${params.prompt}" return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, prompt, 8192, true) } - override fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest { + override fun createCommitMessageRequest(params: CommitMessageCompletionParameters): CompletionRequest { return createBasicCompletionRequest(params.systemPrompt, params.gitDiff, 512, true) } - override fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest { + override fun createLookupRequest(params: LookupCompletionParameters): CompletionRequest { return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, params.prompt, 512) } @@ -50,7 +50,7 @@ abstract class BaseRequestFactory : CompletionRequestFactory { stream: Boolean = false ): CompletionRequest - protected fun getPromptWithFilesContext(callParameters: CallParameters): String { + protected fun getPromptWithFilesContext(callParameters: ChatCompletionParameters): String { return callParameters.referencedFiles?.let { if (it.isEmpty()) { callParameters.message.prompt diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt index 2f8c723f..f916240a 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt @@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters +import ee.carlrobert.codegpt.completions.ChatCompletionParameters import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest @@ -10,12 +10,11 @@ import ee.carlrobert.llm.completion.CompletionRequest class AzureRequestFactory : BaseRequestFactory() { - override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest { + override fun createChatRequest(params: ChatCompletionParameters): OpenAIChatCompletionRequest { val configuration = service().state - val (callParameters) = params val requestBuilder: OpenAIChatCompletionRequest.Builder = OpenAIChatCompletionRequest.Builder( - buildOpenAIMessages(null, callParameters, callParameters.referencedFiles) + buildOpenAIMessages(null, params, params.referencedFiles) ) .setMaxTokens(configuration.maxTokens) .setStream(true) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt index 78f5509b..df12917d 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt @@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters +import ee.carlrobert.codegpt.completions.ChatCompletionParameters import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.codegpt.settings.persona.PersonaSettings import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings @@ -11,15 +11,14 @@ import ee.carlrobert.llm.completion.CompletionRequest class ClaudeRequestFactory : BaseRequestFactory() { - override fun createChatRequest(params: ChatCompletionRequestParameters): ClaudeCompletionRequest { - val (callParameters) = params + override fun createChatRequest(params: ChatCompletionParameters): ClaudeCompletionRequest { return ClaudeCompletionRequest().apply { model = service().state.model maxTokens = service().state.maxTokens isStream = true system = PersonaSettings.getSystemPrompt() - messages = callParameters.conversation.messages + messages = params.conversation.messages .filter { it.response != null && it.response.isNotEmpty() } .flatMap { prevMessage -> sequenceOf( @@ -29,18 +28,15 @@ class ClaudeRequestFactory : BaseRequestFactory() { } when { - callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty() -> { + params.imageMediaType != null && params.imageData != null -> { messages.add( ClaudeCompletionDetailedMessage( "user", listOf( ClaudeMessageImageContent( - ClaudeBase64Source( - callParameters.imageMediaType, - callParameters.imageData - ) + ClaudeBase64Source(params.imageMediaType, params.imageData) ), - ClaudeMessageTextContent(callParameters.message.prompt) + ClaudeMessageTextContent(params.message.prompt) ) ) ) @@ -49,8 +45,7 @@ class ClaudeRequestFactory : BaseRequestFactory() { else -> { messages.add( ClaudeCompletionStandardMessage( - "user", - getPromptWithFilesContext(callParameters) + "user", getPromptWithFilesContext(params) ) ) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt index 44218d09..c3641fc0 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt @@ -4,7 +4,7 @@ import com.intellij.openapi.application.ApplicationInfo import com.intellij.openapi.components.service import ee.carlrobert.codegpt.CodeGPTPlugin import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters +import ee.carlrobert.codegpt.completions.ChatCompletionParameters import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings @@ -13,14 +13,13 @@ import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionSt class CodeGPTRequestFactory : BaseRequestFactory() { - override fun createChatRequest(params: ChatCompletionRequestParameters): ChatCompletionRequest { - val (callParameters) = params + override fun createChatRequest(params: ChatCompletionParameters): ChatCompletionRequest { val model = service().state.chatCompletionSettings.model val configuration = service().state val requestBuilder: ChatCompletionRequest.Builder = - ChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters)) + ChatCompletionRequest.Builder(buildOpenAIMessages(model, params)) .setModel(model) - .setSessionId(callParameters.sessionId) + .setSessionId(params.sessionId) .setMetadata( Metadata( CodeGPTPlugin.getVersion(), @@ -40,16 +39,16 @@ class CodeGPTRequestFactory : BaseRequestFactory() { .setTemperature(configuration.temperature.toDouble()) } - if (callParameters.message.isWebSearchIncluded) { + if (params.message.isWebSearchIncluded) { requestBuilder.setWebSearchIncluded(true) } - val documentationDetails = callParameters.message.documentationDetails + val documentationDetails = params.message.documentationDetails if (documentationDetails != null) { requestBuilder.setDocumentationDetails( DocumentationDetails(documentationDetails.name, documentationDetails.url) ) } - callParameters.referencedFiles?.let { + params.referencedFiles?.let { val fileContexts = it.map { file -> ContextFile(file.fileName, file.fileContent) } @@ -81,7 +80,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() { .build() } - fun buildBasicO1Request( + private fun buildBasicO1Request( model: String, prompt: String, systemPrompt: String = "", diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt index 8197d5a2..d47bd644 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt @@ -3,7 +3,7 @@ package ee.carlrobert.codegpt.completions.factory import com.fasterxml.jackson.databind.ObjectMapper import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters +import ee.carlrobert.codegpt.completions.ChatCompletionParameters import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState @@ -19,17 +19,12 @@ class CustomOpenAIRequest(val request: Request) : CompletionRequest class CustomOpenAIRequestFactory : BaseRequestFactory() { - override fun createChatRequest(params: ChatCompletionRequestParameters): CustomOpenAIRequest { - val (callParameters) = params + override fun createChatRequest(params: ChatCompletionParameters): CustomOpenAIRequest { val request = buildCustomOpenAIChatCompletionRequest( service() .state .chatCompletionSettings, - OpenAIRequestFactory.buildOpenAIMessages( - null, - callParameters, - callParameters.referencedFiles - ), + OpenAIRequestFactory.buildOpenAIMessages(null, params, params.referencedFiles), true, getCredential(CredentialKey.CUSTOM_SERVICE_API_KEY) ) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt index d95f2ade..5ab82e4f 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt @@ -20,10 +20,9 @@ import java.nio.file.Path class GoogleRequestFactory : BaseRequestFactory() { - override fun createChatRequest(params: ChatCompletionRequestParameters): GoogleCompletionRequest { - val (callParameters) = params + override fun createChatRequest(params: ChatCompletionParameters): GoogleCompletionRequest { val configuration = service().state - val messages = buildGoogleMessages(service().state.model, callParameters) + val messages = buildGoogleMessages(service().state.model, params) return GoogleCompletionRequest.Builder(messages) .generationConfig( GoogleGenerationConfig.Builder() @@ -57,9 +56,9 @@ class GoogleRequestFactory : BaseRequestFactory() { private fun buildGoogleMessages( model: String?, - callParameters: CallParameters + params: ChatCompletionParameters ): List { - val messages = buildGoogleMessages(callParameters) + val messages = buildGoogleMessages(params) if (model == null) { return messages @@ -81,7 +80,7 @@ class GoogleRequestFactory : BaseRequestFactory() { } else { tryReducingGoogleMessagesOrThrow( messages, - callParameters.conversation.isDiscardTokenLimit, + params.conversation.isDiscardTokenLimit, totalUsage, googleModel.maxTokens ) @@ -89,11 +88,11 @@ class GoogleRequestFactory : BaseRequestFactory() { } ?: messages } - private fun buildGoogleMessages(callParameters: CallParameters): List { - val message = callParameters.message + private fun buildGoogleMessages(params: ChatCompletionParameters): List { + val message = params.message val messages = mutableListOf() - when (callParameters.conversationType) { + when (params.conversationType) { ConversationType.DEFAULT -> { messages.add( GoogleCompletionContent( @@ -114,8 +113,8 @@ class GoogleRequestFactory : BaseRequestFactory() { else -> {} } - for (prevMessage in callParameters.conversation.messages) { - if (callParameters.isRetry && prevMessage.id == message.id) { + for (prevMessage in params.conversation.messages) { + if (params.retry && prevMessage.id == message.id) { break } @@ -143,15 +142,15 @@ class GoogleRequestFactory : BaseRequestFactory() { messages.add(GoogleCompletionContent("model", listOf(prevMessage.response))) } - if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) { + if (params.imageMediaType != null && params.imageData != null) { messages.add( GoogleCompletionContent( listOf( GoogleContentPart( null, GoogleContentPart.Blob( - callParameters.imageMediaType, - callParameters.imageData + params.imageMediaType, + params.imageData ) ), GoogleContentPart(message.prompt) @@ -162,7 +161,7 @@ class GoogleRequestFactory : BaseRequestFactory() { messages.add( GoogleCompletionContent( "user", - listOf(getPromptWithFilesContext(callParameters)) + listOf(getPromptWithFilesContext(params)) ) ) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt index 085f6f14..e5618ca3 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt @@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters +import ee.carlrobert.codegpt.completions.ChatCompletionParameters import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT import ee.carlrobert.codegpt.completions.ConversationType import ee.carlrobert.codegpt.completions.llama.LlamaModel @@ -14,18 +14,17 @@ import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest class LlamaRequestFactory : BaseRequestFactory() { - override fun createChatRequest(params: ChatCompletionRequestParameters): LlamaCompletionRequest { - val (callParameters) = params + override fun createChatRequest(params: ChatCompletionParameters): LlamaCompletionRequest { val promptTemplate = getPromptTemplate() val systemPrompt = - if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS) + if (params.conversationType == ConversationType.FIX_COMPILE_ERRORS) FIX_COMPILE_ERRORS_SYSTEM_PROMPT else getSystemPrompt() val prompt = promptTemplate.buildPrompt( systemPrompt, - getPromptWithFilesContext(callParameters), - callParameters.conversation.messages + getPromptWithFilesContext(params), + params.conversation.messages ) return buildLlamaRequest(prompt, promptTemplate.stopTokens, true) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt index 3ecebe3a..2914f595 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt @@ -2,8 +2,7 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory -import ee.carlrobert.codegpt.completions.CallParameters -import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters +import ee.carlrobert.codegpt.completions.ChatCompletionParameters import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT import ee.carlrobert.codegpt.completions.ConversationType import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings @@ -19,13 +18,12 @@ import java.util.* class OllamaRequestFactory : BaseRequestFactory() { - override fun createChatRequest(params: ChatCompletionRequestParameters): OllamaChatCompletionRequest { - val (callParameters) = params + override fun createChatRequest(params: ChatCompletionParameters): OllamaChatCompletionRequest { val configuration = service().state val settings = service().state return OllamaChatCompletionRequest.Builder( settings.model, - buildOllamaMessages(callParameters) + buildOllamaMessages(params) ) .setStream(true) .setOptions( @@ -54,11 +52,11 @@ class OllamaRequestFactory : BaseRequestFactory() { .build() } - private fun buildOllamaMessages(callParameters: CallParameters): List { - val message = callParameters.message + private fun buildOllamaMessages(params: ChatCompletionParameters): List { + val message = params.message val messages = mutableListOf() - when (callParameters.conversationType) { + when (params.conversationType) { ConversationType.DEFAULT -> messages.add( OllamaChatCompletionMessage("system", PersonaSettings.getSystemPrompt(), null) ) @@ -70,8 +68,8 @@ class OllamaRequestFactory : BaseRequestFactory() { else -> {} } - for (prevMessage in callParameters.conversation.messages) { - if (callParameters.isRetry && prevMessage.id == message.id) break + for (prevMessage in params.conversation.messages) { + if (params.retry && prevMessage.id == message.id) break prevMessage.imageFilePath?.takeIf { it.isNotEmpty() }?.let { imagePath -> try { @@ -91,7 +89,7 @@ class OllamaRequestFactory : BaseRequestFactory() { messages.add( OllamaChatCompletionMessage( "user", - getPromptWithFilesContext(callParameters), + getPromptWithFilesContext(params), null ) ) @@ -100,8 +98,8 @@ class OllamaRequestFactory : BaseRequestFactory() { messages.add(OllamaChatCompletionMessage("assistant", prevMessage.response, null)) } - if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) { - val imageBase64 = Base64.getEncoder().encodeToString(callParameters.imageData) + if (params.imageMediaType != null && params.imageData != null) { + val imageBase64 = Base64.getEncoder().encodeToString(params.imageData) messages.add(OllamaChatCompletionMessage("user", message.prompt, listOf(imageBase64))) } else { messages.add(OllamaChatCompletionMessage("user", message.prompt, null)) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt index 3035c683..364547d0 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt @@ -21,13 +21,12 @@ import java.nio.file.Path class OpenAIRequestFactory : CompletionRequestFactory { - override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest { - val (callParameters) = params + override fun createChatRequest(params: ChatCompletionParameters): OpenAIChatCompletionRequest { val model = service().state.model val configuration = service().state val requestBuilder: OpenAIChatCompletionRequest.Builder = OpenAIChatCompletionRequest.Builder( - buildOpenAIMessages(model, callParameters, callParameters.referencedFiles) + buildOpenAIMessages(model, params, params.referencedFiles) ) .setModel(model) if ("o1-mini" == model || "o1-preview" == model) { @@ -48,7 +47,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { return requestBuilder.build() } - override fun createEditCodeRequest(params: EditCodeRequestParameters): OpenAIChatCompletionRequest { + override fun createEditCodeRequest(params: EditCodeCompletionParameters): OpenAIChatCompletionRequest { val model = service().state.model val prompt = "Code to modify:\n${params.selectedText}\n\nInstructions: ${params.prompt}" if (model == "o1-mini" || model == "o1-preview") { @@ -57,7 +56,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, prompt, model, true) } - override fun createCommitMessageRequest(params: CommitMessageRequestParameters): OpenAIChatCompletionRequest { + override fun createCommitMessageRequest(params: CommitMessageCompletionParameters): OpenAIChatCompletionRequest { val model = service().state.model val (gitDiff, systemPrompt) = params if (model == "o1-mini" || model == "o1-preview") { @@ -66,7 +65,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { return createBasicCompletionRequest(systemPrompt, gitDiff, model, true) } - override fun createLookupRequest(params: LookupRequestCallParameters): OpenAIChatCompletionRequest { + override fun createLookupRequest(params: LookupCompletionParameters): OpenAIChatCompletionRequest { val model = service().state.model val (prompt) = params if (model == "o1-mini" || model == "o1-preview") { @@ -103,7 +102,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { fun buildOpenAIMessages( model: String?, - callParameters: CallParameters, + callParameters: ChatCompletionParameters, referencedFiles: List? = mutableListOf() ): List { val messages = buildOpenAIChatMessages(model, callParameters, referencedFiles) @@ -140,7 +139,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { private fun buildOpenAIChatMessages( model: String?, - callParameters: CallParameters, + callParameters: ChatCompletionParameters, referencedFiles: List? = mutableListOf() ): MutableList { val message = callParameters.message @@ -169,7 +168,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { } for (prevMessage in callParameters.conversation.messages) { - if (callParameters.isRetry && prevMessage.id == message.id) { + if (callParameters.retry && prevMessage.id == message.id) { break } val prevMessageImageFilePath = prevMessage.imageFilePath @@ -203,7 +202,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { ) } - if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) { + if (callParameters.imageMediaType != null && callParameters.imageData != null) { messages.add( OpenAIChatCompletionDetailedMessage( "user", diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt index bf0cab60..94b5ac95 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt @@ -22,12 +22,11 @@ class CompletionRequestProviderTest : IntegrationTest() { val secondMessage = createDummyMessage(250) conversation.addMessage(firstMessage) conversation.addMessage(secondMessage) + val callParameters = ChatCompletionParameters + .builder(conversation, Message("TEST_CHAT_COMPLETION_PROMPT")) + .build() - val request = OpenAIRequestFactory().createChatRequest( - ChatCompletionRequestParameters( - CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT")) - ) - ) + val request = OpenAIRequestFactory().createChatRequest(callParameters) assertThat(request.messages) .extracting("role", "content") @@ -49,12 +48,11 @@ class CompletionRequestProviderTest : IntegrationTest() { val secondMessage = createDummyMessage(250) conversation.addMessage(firstMessage) conversation.addMessage(secondMessage) + val callParameters = ChatCompletionParameters + .builder(conversation, Message("TEST_CHAT_COMPLETION_PROMPT")) + .build() - val request = OpenAIRequestFactory().createChatRequest( - ChatCompletionRequestParameters( - CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT")) - ) - ) + val request = OpenAIRequestFactory().createChatRequest(callParameters) assertThat(request.messages) .extracting("role", "content") @@ -76,19 +74,11 @@ class CompletionRequestProviderTest : IntegrationTest() { val secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250) conversation.addMessage(firstMessage) conversation.addMessage(secondMessage) + val callParameters = ChatCompletionParameters.builder(conversation, secondMessage) + .retry(true) + .build() - val request = OpenAIRequestFactory().createChatRequest( - ChatCompletionRequestParameters( - CallParameters( - null, - conversation, - ConversationType.DEFAULT, - secondMessage, - null, - true - ) - ) - ) + val request = OpenAIRequestFactory().createChatRequest(callParameters) assertThat(request.messages) .extracting("role", "content") @@ -111,12 +101,11 @@ class CompletionRequestProviderTest : IntegrationTest() { val remainingMessage = createDummyMessage(1000) conversation.addMessage(remainingMessage) conversation.discardTokenLimits() + val callParameters = ChatCompletionParameters + .builder(conversation, Message("TEST_CHAT_COMPLETION_PROMPT")) + .build() - val request = OpenAIRequestFactory().createChatRequest( - ChatCompletionRequestParameters( - CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT")) - ) - ) + val request = OpenAIRequestFactory().createChatRequest(callParameters) assertThat(request.messages) .extracting("role", "content") @@ -137,9 +126,9 @@ class CompletionRequestProviderTest : IntegrationTest() { assertThrows(TotalUsageExceededException::class.java) { OpenAIRequestFactory().createChatRequest( - ChatCompletionRequestParameters( - CallParameters(conversation, createDummyMessage(100)) - ) + ChatCompletionParameters + .builder(conversation, createDummyMessage(100)) + .build() ) } } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt index 87484628..3fc6bc64 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt @@ -16,246 +16,275 @@ import testsupport.IntegrationTest class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() { - fun testOpenAIChatCompletionCall() { - useOpenAIService() - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val message = Message("TEST_PROMPT") - val conversation = ConversationService.getInstance().startConversation() - val requestHandler = - ToolwindowChatCompletionRequestHandler( - getRequestEventListener(message) - ) - expectOpenAI(StreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo("/v1/chat/completions") - assertThat(request.method).isEqualTo("POST") - assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") - assertThat(request.body) - .extracting( - "model", - "messages") - .containsExactly( - "gpt-4", - listOf( - mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), - mapOf("role" to "user", "content" to "TEST_PROMPT"))) - listOf( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) - }) + fun testOpenAIChatCompletionCall() { + useOpenAIService() + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val message = Message("TEST_PROMPT") + val conversation = ConversationService.getInstance().startConversation() + expectOpenAI(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + assertThat(request.body) + .extracting( + "model", + "messages" + ) + .containsExactly( + "gpt-4", + listOf( + mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), + mapOf("role" to "user", "content" to "TEST_PROMPT") + ) + ) + listOf( + jsonMapResponse( + "choices", + jsonArray(jsonMap("delta", jsonMap("role", "assistant"))) + ), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))) + ) + }) + val requestHandler = + ToolwindowChatCompletionRequestHandler(getRequestEventListener(message)) - requestHandler.call(CallParameters(conversation, message)) + requestHandler.call(ChatCompletionParameters.builder(conversation, message).build()) - waitExpecting { "Hello!" == message.response } - } - - fun testAzureChatCompletionCall() { - useAzureService() - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val conversationService = ConversationService.getInstance() - val prevMessage = Message("TEST_PREV_PROMPT") - prevMessage.response = "TEST_PREV_RESPONSE" - val conversation = conversationService.startConversation() - conversation.addMessage(prevMessage) - conversationService.saveConversation(conversation) - expectAzure(StreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo( - "/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions") - assertThat(request.uri.query).isEqualTo("api-version=TEST_API_VERSION") - assertThat(request.headers["Api-key"]!![0]).isEqualTo("TEST_API_KEY") - assertThat(request.headers["X-llm-application-tag"]!![0]).isEqualTo("codegpt") - assertThat(request.body) - .extracting("messages") - .isEqualTo( - listOf( - mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), - mapOf("role" to "user", "content" to "TEST_PREV_PROMPT"), - mapOf("role" to "assistant", "content" to "TEST_PREV_RESPONSE"), - mapOf("role" to "user", "content" to "TEST_PROMPT"))) - listOf( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) - }) - val message = Message("TEST_PROMPT") - val requestHandler = - ToolwindowChatCompletionRequestHandler( - getRequestEventListener(message) - ) - - requestHandler.call(CallParameters(conversation, message)) - - waitExpecting { "Hello!" == message.response } - } - - fun testLlamaChatCompletionCall() { - useLlamaService() - service().state.maxTokens = 99 - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val message = Message("TEST_PROMPT") - val conversation = ConversationService.getInstance().startConversation() - conversation.addMessage(Message("Ping", "Pong")) - val requestHandler = - ToolwindowChatCompletionRequestHandler( - getRequestEventListener(message) - ) - expectLlama(StreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo("/completion") - assertThat(request.body) - .extracting( - "prompt", - "n_predict", - "stream") - .containsExactly( - LLAMA.buildPrompt( - "TEST_SYSTEM_PROMPT", - "TEST_PROMPT", - conversation.messages), - 99, - true) - listOf( - jsonMapResponse("content", "Hel"), - jsonMapResponse("content", "lo!"), - jsonMapResponse( - e("content", ""), - e("stop", true))) - }) - - requestHandler.call(CallParameters(conversation, message)) - - waitExpecting { "Hello!" == message.response } - } - - fun testOllamaChatCompletionCall() { - useOllamaService() - service().state.maxTokens = 99 - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val message = Message("TEST_PROMPT") - val conversation = ConversationService.getInstance().startConversation() - val requestHandler = - ToolwindowChatCompletionRequestHandler( - getRequestEventListener(message) - ) - expectOllama(NdJsonStreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo("/api/chat") - assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") - assertThat(request.body) - .extracting( - "model", - "messages", - "options.num_predict", - "stream" - ) - .containsExactly( - HuggingFaceModel.LLAMA_3_8B_Q6_K.code, - listOf( - mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), - mapOf("role" to "user", "content" to "TEST_PROMPT") - ), - 99, - true - ) - listOf( - jsonMapResponse( - e("message", jsonMap(e("content", "Hel"), e("role", "assistant"))), - e("done", false)), - jsonMapResponse( - e("message", jsonMap(e("content", "lo"), e("role", "assistant"))), - e("done", false)), - jsonMapResponse( - e("message", jsonMap(e("content", "!"), e("role", "assistant"))), - e("done", false)), - jsonMapResponse( - e("message", jsonMap(e("content", ""), e("role", "assistant"))), - e("done", true)) - ) - }) - - requestHandler.call(CallParameters(conversation, message)) - - waitExpecting { "Hello!" == message.response } - } - - fun testGoogleChatCompletionCall() { - useGoogleService() - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val message = Message("TEST_PROMPT") - val conversation = ConversationService.getInstance().startConversation() - val requestHandler = - ToolwindowChatCompletionRequestHandler( - getRequestEventListener(message) - ) - expectGoogle(StreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo("/v1/models/gemini-pro:streamGenerateContent") - assertThat(request.method).isEqualTo("POST") - assertThat(request.uri.query).isEqualTo("key=TEST_API_KEY&alt=sse") - assertThat(request.body) - .extracting("contents") - .isEqualTo( - listOf( - mapOf("parts" to listOf(mapOf("text" to "TEST_SYSTEM_PROMPT")), "role" to "user"), - mapOf("parts" to listOf(mapOf("text" to "Understood.")), "role" to "model"), - mapOf("parts" to listOf(mapOf("text" to "TEST_PROMPT")), "role" to "user"), - ) - ) - listOf( - jsonMapResponse( - "candidates", - jsonArray(jsonMap("content", jsonMap("parts", jsonArray(jsonMap("text", "Hello"))))) - ), - jsonMapResponse( - "candidates", - jsonArray(jsonMap("content", jsonMap("parts", jsonArray(jsonMap("text", "!"))))) - ) - ) - }) - - requestHandler.call(CallParameters(conversation, message)) - - waitExpecting { "Hello!" == message.response } - } - - fun testCodeGPTServiceChatCompletionCall() { - useCodeGPTService() - service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" - val message = Message("TEST_PROMPT") - val conversation = ConversationService.getInstance().startConversation() - val requestHandler = - ToolwindowChatCompletionRequestHandler( - getRequestEventListener(message) - ) - expectCodeGPT(StreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo("/v1/chat/completions") - assertThat(request.method).isEqualTo("POST") - assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") - assertThat(request.body) - .extracting( - "model", - "messages") - .containsExactly( - "TEST_MODEL", - listOf( - mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), - mapOf("role" to "user", "content" to "TEST_PROMPT"))) - listOf( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) - }) - - requestHandler.call(CallParameters(conversation, message)) - - waitExpecting { "Hello!" == message.response } - } - - private fun getRequestEventListener(message: Message): CompletionResponseEventListener { - return object : CompletionResponseEventListener { - override fun handleCompleted(fullMessage: String, callParameters: CallParameters) { - message.response = fullMessage - } + waitExpecting { "Hello!" == message.response } + } + + fun testAzureChatCompletionCall() { + useAzureService() + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val conversationService = ConversationService.getInstance() + val prevMessage = Message("TEST_PREV_PROMPT") + prevMessage.response = "TEST_PREV_RESPONSE" + val conversation = conversationService.startConversation() + conversation.addMessage(prevMessage) + conversationService.saveConversation(conversation) + expectAzure(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo( + "/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions" + ) + assertThat(request.uri.query).isEqualTo("api-version=TEST_API_VERSION") + assertThat(request.headers["Api-key"]!![0]).isEqualTo("TEST_API_KEY") + assertThat(request.headers["X-llm-application-tag"]!![0]).isEqualTo("codegpt") + assertThat(request.body) + .extracting("messages") + .isEqualTo( + listOf( + mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), + mapOf("role" to "user", "content" to "TEST_PREV_PROMPT"), + mapOf("role" to "assistant", "content" to "TEST_PREV_RESPONSE"), + mapOf("role" to "user", "content" to "TEST_PROMPT") + ) + ) + listOf( + jsonMapResponse( + "choices", + jsonArray(jsonMap("delta", jsonMap("role", "assistant"))) + ), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))) + ) + }) + val message = Message("TEST_PROMPT") + val requestHandler = + ToolwindowChatCompletionRequestHandler(getRequestEventListener(message)) + + requestHandler.call(ChatCompletionParameters.builder(conversation, message).build()) + + waitExpecting { "Hello!" == message.response } + } + + fun testLlamaChatCompletionCall() { + useLlamaService() + service().state.maxTokens = 99 + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val message = Message("TEST_PROMPT") + val conversation = ConversationService.getInstance().startConversation() + conversation.addMessage(Message("Ping", "Pong")) + expectLlama(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/completion") + assertThat(request.body) + .extracting( + "prompt", + "n_predict", + "stream" + ) + .containsExactly( + LLAMA.buildPrompt( + "TEST_SYSTEM_PROMPT", + "TEST_PROMPT", + conversation.messages + ), + 99, + true + ) + listOf( + jsonMapResponse("content", "Hel"), + jsonMapResponse("content", "lo!"), + jsonMapResponse( + e("content", ""), + e("stop", true) + ) + ) + }) + val requestHandler = + ToolwindowChatCompletionRequestHandler(getRequestEventListener(message)) + + requestHandler.call(ChatCompletionParameters.builder(conversation, message).build()) + + waitExpecting { "Hello!" == message.response } + } + + fun testOllamaChatCompletionCall() { + useOllamaService() + service().state.maxTokens = 99 + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val message = Message("TEST_PROMPT") + val conversation = ConversationService.getInstance().startConversation() + expectOllama(NdJsonStreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/api/chat") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + assertThat(request.body) + .extracting( + "model", + "messages", + "options.num_predict", + "stream" + ) + .containsExactly( + HuggingFaceModel.LLAMA_3_8B_Q6_K.code, + listOf( + mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), + mapOf("role" to "user", "content" to "TEST_PROMPT") + ), + 99, + true + ) + listOf( + jsonMapResponse( + e("message", jsonMap(e("content", "Hel"), e("role", "assistant"))), + e("done", false) + ), + jsonMapResponse( + e("message", jsonMap(e("content", "lo"), e("role", "assistant"))), + e("done", false) + ), + jsonMapResponse( + e("message", jsonMap(e("content", "!"), e("role", "assistant"))), + e("done", false) + ), + jsonMapResponse( + e("message", jsonMap(e("content", ""), e("role", "assistant"))), + e("done", true) + ) + ) + }) + val requestHandler = + ToolwindowChatCompletionRequestHandler(getRequestEventListener(message)) + + requestHandler.call(ChatCompletionParameters.builder(conversation, message).build()) + + waitExpecting { "Hello!" == message.response } + } + + fun testGoogleChatCompletionCall() { + useGoogleService() + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val message = Message("TEST_PROMPT") + val conversation = ConversationService.getInstance().startConversation() + expectGoogle(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/models/gemini-pro:streamGenerateContent") + assertThat(request.method).isEqualTo("POST") + assertThat(request.uri.query).isEqualTo("key=TEST_API_KEY&alt=sse") + assertThat(request.body) + .extracting("contents") + .isEqualTo( + listOf( + mapOf( + "parts" to listOf(mapOf("text" to "TEST_SYSTEM_PROMPT")), + "role" to "user" + ), + mapOf("parts" to listOf(mapOf("text" to "Understood.")), "role" to "model"), + mapOf("parts" to listOf(mapOf("text" to "TEST_PROMPT")), "role" to "user"), + ) + ) + listOf( + jsonMapResponse( + "candidates", + jsonArray( + jsonMap( + "content", + jsonMap("parts", jsonArray(jsonMap("text", "Hello"))) + ) + ) + ), + jsonMapResponse( + "candidates", + jsonArray(jsonMap("content", jsonMap("parts", jsonArray(jsonMap("text", "!"))))) + ) + ) + }) + val requestHandler = + ToolwindowChatCompletionRequestHandler(getRequestEventListener(message)) + + requestHandler.call(ChatCompletionParameters.builder(conversation, message).build()) + + waitExpecting { "Hello!" == message.response } + } + + fun testCodeGPTServiceChatCompletionCall() { + useCodeGPTService() + service().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT" + val message = Message("TEST_PROMPT") + val conversation = ConversationService.getInstance().startConversation() + expectCodeGPT(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + assertThat(request.body) + .extracting( + "model", + "messages" + ) + .containsExactly( + "TEST_MODEL", + listOf( + mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), + mapOf("role" to "user", "content" to "TEST_PROMPT") + ) + ) + listOf( + jsonMapResponse( + "choices", + jsonArray(jsonMap("delta", jsonMap("role", "assistant"))) + ), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))) + ) + }) + val requestHandler = + ToolwindowChatCompletionRequestHandler(getRequestEventListener(message)) + + requestHandler.call(ChatCompletionParameters.builder(conversation, message).build()) + + waitExpecting { "Hello!" == message.response } + } + + private fun getRequestEventListener(message: Message): CompletionResponseEventListener { + return object : CompletionResponseEventListener { + override fun handleCompleted( + fullMessage: String, + callParameters: ChatCompletionParameters + ) { + message.response = fullMessage + } + } } - } } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt index 829d9743..9dd617ef 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt @@ -59,7 +59,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { ) }) - panel.sendMessage(message) + panel.sendMessage(message, ConversationType.DEFAULT) waitExpecting { val messages = conversation.messages @@ -161,7 +161,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { ) }) - panel.sendMessage(message) + panel.sendMessage(message, ConversationType.DEFAULT) waitExpecting { val messages = conversation.messages @@ -250,7 +250,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() { ) }) - panel.sendMessage(message) + panel.sendMessage(message, ConversationType.DEFAULT) waitExpecting { val messages = conversation.messages