mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-20 01:02:02 +00:00
refactor: improve chat completion call handling
This commit is contained in:
parent
1b3b5687bc
commit
5ad9bcfaff
27 changed files with 568 additions and 610 deletions
|
|
@ -28,7 +28,7 @@ import com.intellij.openapi.vfs.VirtualFile;
|
|||
import ee.carlrobert.codegpt.CodeGPTBundle;
|
||||
import ee.carlrobert.codegpt.EncodingManager;
|
||||
import ee.carlrobert.codegpt.Icons;
|
||||
import ee.carlrobert.codegpt.completions.CommitMessageRequestParameters;
|
||||
import ee.carlrobert.codegpt.completions.CommitMessageCompletionParameters;
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestService;
|
||||
import ee.carlrobert.codegpt.settings.configuration.CommitMessageTemplate;
|
||||
import ee.carlrobert.codegpt.ui.OverlayUtil;
|
||||
|
|
@ -96,7 +96,7 @@ public class GenerateGitCommitMessageAction extends AnAction {
|
|||
if (editor != null) {
|
||||
((EditorEx) editor).setCaretVisible(false);
|
||||
CompletionRequestService.getInstance().getCommitMessageAsync(
|
||||
new CommitMessageRequestParameters(
|
||||
new CommitMessageCompletionParameters(
|
||||
gitDiff,
|
||||
project.getService(CommitMessageTemplate.class).getSystemPrompt()),
|
||||
getEventListener(project, editor.getDocument()));
|
||||
|
|
|
|||
|
|
@ -1,93 +0,0 @@
|
|||
package ee.carlrobert.codegpt.completions;
|
||||
|
||||
import ee.carlrobert.codegpt.ReferencedFile;
|
||||
import ee.carlrobert.codegpt.conversations.Conversation;
|
||||
import ee.carlrobert.codegpt.conversations.message.Message;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class CallParameters {
|
||||
|
||||
private final UUID sessionId;
|
||||
private final Conversation conversation;
|
||||
private final ConversationType conversationType;
|
||||
private final Message message;
|
||||
private final boolean retry;
|
||||
private final String highlightedText;
|
||||
private String imageMediaType;
|
||||
private byte[] imageData;
|
||||
private List<ReferencedFile> referencedFiles;
|
||||
|
||||
public CallParameters(Conversation conversation, Message message) {
|
||||
this(null, conversation, message);
|
||||
}
|
||||
|
||||
public CallParameters(UUID sessionId, Conversation conversation, Message message) {
|
||||
this(sessionId, conversation, ConversationType.DEFAULT, message, null, false);
|
||||
}
|
||||
|
||||
// TODO: Builder
|
||||
public CallParameters(
|
||||
UUID sessionId,
|
||||
Conversation conversation,
|
||||
ConversationType conversationType,
|
||||
Message message,
|
||||
@Nullable String highlightedText,
|
||||
boolean retry) {
|
||||
this.sessionId = sessionId;
|
||||
this.conversation = conversation;
|
||||
this.conversationType = conversationType;
|
||||
this.message = message;
|
||||
this.highlightedText = highlightedText;
|
||||
this.retry = retry;
|
||||
}
|
||||
|
||||
public UUID getSessionId() {
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
public Conversation getConversation() {
|
||||
return conversation;
|
||||
}
|
||||
|
||||
public ConversationType getConversationType() {
|
||||
return conversationType;
|
||||
}
|
||||
|
||||
public Message getMessage() {
|
||||
return message;
|
||||
}
|
||||
|
||||
public boolean isRetry() {
|
||||
return retry;
|
||||
}
|
||||
|
||||
public @Nullable String getImageMediaType() {
|
||||
return imageMediaType;
|
||||
}
|
||||
|
||||
public void setImageMediaType(@Nullable String imageMediaType) {
|
||||
this.imageMediaType = imageMediaType;
|
||||
}
|
||||
|
||||
public byte[] getImageData() {
|
||||
return imageData;
|
||||
}
|
||||
|
||||
public void setImageData(byte[] imageData) {
|
||||
this.imageData = imageData;
|
||||
}
|
||||
|
||||
public @Nullable String getHighlightedText() {
|
||||
return highlightedText;
|
||||
}
|
||||
|
||||
public @Nullable List<ReferencedFile> getReferencedFiles() {
|
||||
return referencedFiles;
|
||||
}
|
||||
|
||||
public void setReferencedFiles(List<ReferencedFile> referencedFiles) {
|
||||
this.referencedFiles = referencedFiles;
|
||||
}
|
||||
}
|
||||
|
|
@ -10,12 +10,12 @@ import okhttp3.sse.EventSource;
|
|||
|
||||
public class ChatCompletionEventListener implements CompletionEventListener<String> {
|
||||
|
||||
private final CallParameters callParameters;
|
||||
private final ChatCompletionParameters callParameters;
|
||||
private final CompletionResponseEventListener eventListener;
|
||||
private final StringBuilder messageBuilder = new StringBuilder();
|
||||
|
||||
public ChatCompletionEventListener(
|
||||
CallParameters callParameters,
|
||||
ChatCompletionParameters callParameters,
|
||||
CompletionResponseEventListener eventListener) {
|
||||
this.callParameters = callParameters;
|
||||
this.eventListener = eventListener;
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ public final class CompletionRequestService {
|
|||
new OpenAIChatCompletionEventSourceListener(eventListener));
|
||||
}
|
||||
|
||||
public String getLookupCompletion(LookupRequestCallParameters params) {
|
||||
public String getLookupCompletion(LookupCompletionParameters params) {
|
||||
var request = CompletionRequestFactory
|
||||
.getFactory(GeneralSettings.getSelectedService())
|
||||
.createLookupRequest(params);
|
||||
|
|
@ -77,7 +77,7 @@ public final class CompletionRequestService {
|
|||
}
|
||||
|
||||
public EventSource getCommitMessageAsync(
|
||||
CommitMessageRequestParameters params,
|
||||
CommitMessageCompletionParameters params,
|
||||
CompletionEventListener<String> eventListener) {
|
||||
var request = CompletionRequestFactory
|
||||
.getFactory(GeneralSettings.getSelectedService())
|
||||
|
|
@ -86,7 +86,7 @@ public final class CompletionRequestService {
|
|||
}
|
||||
|
||||
public EventSource getEditCodeCompletionAsync(
|
||||
EditCodeRequestParameters params,
|
||||
EditCodeCompletionParameters params,
|
||||
CompletionEventListener<String> eventListener) {
|
||||
var request = CompletionRequestFactory
|
||||
.getFactory(GeneralSettings.getSelectedService())
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ public interface CompletionResponseEventListener {
|
|||
default void handleCompleted(String fullMessage) {
|
||||
}
|
||||
|
||||
default void handleCompleted(String fullMessage, CallParameters callParameters) {
|
||||
default void handleCompleted(String fullMessage, ChatCompletionParameters callParameters) {
|
||||
}
|
||||
|
||||
default void handleCodeGPTEvent(CodeGPTEvent event) {
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ public class MethodNameLookupListener implements LookupManagerListener {
|
|||
String prompt) {
|
||||
try {
|
||||
var response = CompletionRequestService.getInstance()
|
||||
.getLookupCompletion(new LookupRequestCallParameters(prompt));
|
||||
.getLookupCompletion(new LookupCompletionParameters(prompt));
|
||||
if (!response.isEmpty()) {
|
||||
for (var value : response.split(",")) {
|
||||
application.invokeLater(() -> application.runReadAction(() -> {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ public class ToolwindowChatCompletionRequestHandler {
|
|||
this.completionResponseEventListener = completionResponseEventListener;
|
||||
}
|
||||
|
||||
public void call(CallParameters callParameters) {
|
||||
public void call(ChatCompletionParameters callParameters) {
|
||||
try {
|
||||
eventSource = startCall(callParameters);
|
||||
} catch (TotalUsageExceededException e) {
|
||||
|
|
@ -33,11 +33,11 @@ public class ToolwindowChatCompletionRequestHandler {
|
|||
}
|
||||
}
|
||||
|
||||
private EventSource startCall(CallParameters callParameters) {
|
||||
private EventSource startCall(ChatCompletionParameters callParameters) {
|
||||
try {
|
||||
var request = CompletionRequestFactory
|
||||
.getFactory(GeneralSettings.getSelectedService())
|
||||
.createChatRequest(new ChatCompletionRequestParameters(callParameters));
|
||||
.createChatRequest(callParameters);
|
||||
return CompletionRequestService.getInstance().getChatCompletionAsync(
|
||||
request,
|
||||
new ChatCompletionEventListener(callParameters, completionResponseEventListener));
|
||||
|
|
@ -57,7 +57,7 @@ public class ToolwindowChatCompletionRequestHandler {
|
|||
completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex);
|
||||
}
|
||||
|
||||
private void sendInfo(CallParameters callParameters) {
|
||||
private void sendInfo(ChatCompletionParameters callParameters) {
|
||||
TelemetryAction.COMPLETION.createActionMessage()
|
||||
.property("conversationId", callParameters.getConversation().getId().toString())
|
||||
.property("model", callParameters.getConversation().getModel())
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.conversations;
|
|||
|
||||
import com.intellij.openapi.application.ApplicationManager;
|
||||
import com.intellij.openapi.components.Service;
|
||||
import ee.carlrobert.codegpt.completions.CallParameters;
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionParameters;
|
||||
import ee.carlrobert.codegpt.conversations.message.Message;
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType;
|
||||
|
|
@ -63,11 +63,11 @@ public final class ConversationService {
|
|||
conversationsMapping.put(conversation.getClientCode(), conversations);
|
||||
}
|
||||
|
||||
public void saveMessage(String response, CallParameters callParameters) {
|
||||
public void saveMessage(String response, ChatCompletionParameters callParameters) {
|
||||
var conversation = callParameters.getConversation();
|
||||
var message = callParameters.getMessage();
|
||||
var conversationMessages = conversation.getMessages();
|
||||
if (callParameters.isRetry() && !conversationMessages.isEmpty()) {
|
||||
if (callParameters.getRetry() && !conversationMessages.isEmpty()) {
|
||||
var messageToBeSaved = conversationMessages.stream()
|
||||
.filter(item -> item.getId().equals(message.getId()))
|
||||
.findFirst().orElseThrow();
|
||||
|
|
|
|||
|
|
@ -2,12 +2,10 @@ package ee.carlrobert.codegpt.toolwindow.chat;
|
|||
|
||||
import static ee.carlrobert.codegpt.ui.UIUtil.createScrollPaneWithSmartScroller;
|
||||
import static java.lang.String.format;
|
||||
import static java.util.Collections.emptyList;
|
||||
|
||||
import com.intellij.openapi.Disposable;
|
||||
import com.intellij.openapi.application.ApplicationManager;
|
||||
import com.intellij.openapi.diagnostic.Logger;
|
||||
import com.intellij.openapi.editor.Editor;
|
||||
import com.intellij.openapi.editor.SelectionModel;
|
||||
import com.intellij.openapi.editor.ex.EditorEx;
|
||||
import com.intellij.openapi.editor.impl.EditorImpl;
|
||||
|
|
@ -17,7 +15,7 @@ import com.intellij.util.ui.JBUI;
|
|||
import ee.carlrobert.codegpt.CodeGPTKeys;
|
||||
import ee.carlrobert.codegpt.ReferencedFile;
|
||||
import ee.carlrobert.codegpt.actions.ActionType;
|
||||
import ee.carlrobert.codegpt.completions.CallParameters;
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionParameters;
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestService;
|
||||
import ee.carlrobert.codegpt.completions.ConversationType;
|
||||
import ee.carlrobert.codegpt.completions.ToolwindowChatCompletionRequestHandler;
|
||||
|
|
@ -43,7 +41,6 @@ import java.io.File;
|
|||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.UUID;
|
||||
|
|
@ -121,8 +118,28 @@ public class ChatToolWindowTabPanel implements Disposable {
|
|||
totalTokensPanel.updateConversationTokens(conversation);
|
||||
}
|
||||
|
||||
public void sendMessage(Message message) {
|
||||
sendMessage(message, ConversationType.DEFAULT);
|
||||
public List<ReferencedFile> getReferencedFiles() {
|
||||
List<ReferencedFile> referencedFiles = project.getUserData(CodeGPTKeys.SELECTED_FILES);
|
||||
if (referencedFiles == null) {
|
||||
return conversation.getMessages().stream()
|
||||
.flatMap(prevMessage -> {
|
||||
if (prevMessage.getReferencedFilePaths() != null) {
|
||||
return prevMessage.getReferencedFilePaths().stream();
|
||||
}
|
||||
return Stream.empty();
|
||||
})
|
||||
.map(filePath -> {
|
||||
try {
|
||||
return new ReferencedFile(new File(filePath));
|
||||
} catch (Exception e) {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.filter(Objects::nonNull)
|
||||
.toList();
|
||||
}
|
||||
|
||||
return referencedFiles;
|
||||
}
|
||||
|
||||
public void sendMessage(Message message, ConversationType conversationType) {
|
||||
|
|
@ -134,79 +151,65 @@ public class ChatToolWindowTabPanel implements Disposable {
|
|||
ConversationType conversationType,
|
||||
@Nullable String highlightedText) {
|
||||
ApplicationManager.getApplication().invokeLater(() -> {
|
||||
var referencedFiles = project.getUserData(CodeGPTKeys.SELECTED_FILES);
|
||||
var chatToolWindowPanel = project.getService(ChatToolWindowContentManager.class)
|
||||
.tryFindChatToolWindowPanel();
|
||||
if (referencedFiles != null && !referencedFiles.isEmpty()) {
|
||||
var referencedFilePaths = referencedFiles.stream()
|
||||
List<ReferencedFile> referencedFiles = getReferencedFiles();
|
||||
if (!referencedFiles.isEmpty()) {
|
||||
message.setReferencedFilePaths(referencedFiles.stream()
|
||||
.map(ReferencedFile::getFilePath)
|
||||
.toList();
|
||||
message.setReferencedFilePaths(referencedFilePaths);
|
||||
.toList());
|
||||
message.setUserMessage(message.getPrompt());
|
||||
|
||||
chatToolWindowPanel.ifPresent(panel -> panel.clearNotifications(project));
|
||||
} else {
|
||||
referencedFiles = conversation.getMessages().stream()
|
||||
.flatMap(prevMessage -> {
|
||||
if (prevMessage.getReferencedFilePaths() != null) {
|
||||
return prevMessage.getReferencedFilePaths().stream();
|
||||
}
|
||||
return Stream.empty();
|
||||
})
|
||||
.map(filePath -> {
|
||||
try {
|
||||
return new ReferencedFile(new File(filePath));
|
||||
} catch (Exception e) {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.filter(Objects::nonNull)
|
||||
.toList();
|
||||
}
|
||||
|
||||
String attachedImagePath = CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH.get(project);
|
||||
if (attachedImagePath != null) {
|
||||
message.setImageFilePath(attachedImagePath);
|
||||
}
|
||||
|
||||
totalTokensPanel.updateConversationTokens(conversation);
|
||||
totalTokensPanel.updateReferencedFilesTokens(referencedFiles);
|
||||
|
||||
var userMessagePanel = new UserMessagePanel(project, message, this);
|
||||
var attachedFilePath = CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH.get(project);
|
||||
var callParameters =
|
||||
getCallParameters(conversationType, message, highlightedText, attachedFilePath);
|
||||
callParameters.setReferencedFiles(referencedFiles);
|
||||
if (callParameters.getImageData() != null) {
|
||||
message.setImageFilePath(attachedFilePath);
|
||||
chatToolWindowPanel.ifPresent(panel -> panel.clearNotifications(project));
|
||||
userMessagePanel.displayImage(attachedFilePath);
|
||||
if (attachedImagePath != null || !referencedFiles.isEmpty()) {
|
||||
project.getService(ChatToolWindowContentManager.class)
|
||||
.tryFindChatToolWindowPanel()
|
||||
.ifPresent(panel -> panel.clearNotifications(project));
|
||||
}
|
||||
|
||||
var callParameters = getCallParameters(
|
||||
message,
|
||||
conversationType,
|
||||
referencedFiles,
|
||||
highlightedText,
|
||||
attachedImagePath);
|
||||
var responsePanel = createResponsePanel(callParameters);
|
||||
var messagePanel = toolWindowScrollablePanel.addMessage(message.getId());
|
||||
messagePanel.add(userMessagePanel);
|
||||
|
||||
var responsePanel = createResponsePanel(callParameters, conversationType);
|
||||
messagePanel.add(new UserMessagePanel(project, message, this));
|
||||
messagePanel.add(responsePanel);
|
||||
|
||||
call(callParameters, responsePanel);
|
||||
});
|
||||
}
|
||||
|
||||
private CallParameters getCallParameters(
|
||||
ConversationType conversationType,
|
||||
private ChatCompletionParameters getCallParameters(
|
||||
Message message,
|
||||
ConversationType conversationType,
|
||||
List<ReferencedFile> referencedFiles,
|
||||
@Nullable String highlightedText,
|
||||
@Nullable String attachedFilePath) {
|
||||
var callParameters = new CallParameters(
|
||||
chatSession.getId(),
|
||||
conversation,
|
||||
conversationType,
|
||||
message,
|
||||
highlightedText,
|
||||
false);
|
||||
if (attachedFilePath != null && !attachedFilePath.isEmpty()) {
|
||||
@Nullable String attachedImagePath) {
|
||||
var builder = ChatCompletionParameters.builder(conversation, message)
|
||||
.sessionId(chatSession.getId())
|
||||
.conversationType(conversationType)
|
||||
.highlightedText(highlightedText)
|
||||
.referencedFiles(referencedFiles);
|
||||
|
||||
if (attachedImagePath != null && !attachedImagePath.isEmpty()) {
|
||||
try {
|
||||
callParameters.setImageData(Files.readAllBytes(Path.of(attachedFilePath)));
|
||||
callParameters.setImageMediaType(FileUtil.getImageMediaType(attachedFilePath));
|
||||
builder
|
||||
.imageData(Files.readAllBytes(Path.of(attachedImagePath)))
|
||||
.imageMediaType(FileUtil.getImageMediaType(attachedImagePath));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
return callParameters;
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
private boolean hasReferencedFilePaths(Message message) {
|
||||
|
|
@ -219,15 +222,13 @@ public class ChatToolWindowTabPanel implements Disposable {
|
|||
it -> it.getReferencedFilePaths() != null && !it.getReferencedFilePaths().isEmpty());
|
||||
}
|
||||
|
||||
private ResponsePanel createResponsePanel(
|
||||
CallParameters callParameters,
|
||||
ConversationType conversationType) {
|
||||
private ResponsePanel createResponsePanel(ChatCompletionParameters callParameters) {
|
||||
var message = callParameters.getMessage();
|
||||
var fileContextIncluded =
|
||||
hasReferencedFilePaths(message) || hasReferencedFilePaths(conversation);
|
||||
|
||||
return new ResponsePanel()
|
||||
.withReloadAction(() -> reloadMessage(message, conversation, conversationType))
|
||||
.withReloadAction(() -> reloadMessage(callParameters))
|
||||
.withDeleteAction(() -> removeMessage(message.getId(), conversation))
|
||||
.addContent(
|
||||
new ChatMessageResponseBody(
|
||||
|
|
@ -241,31 +242,22 @@ public class ChatToolWindowTabPanel implements Disposable {
|
|||
this));
|
||||
}
|
||||
|
||||
private void reloadMessage(
|
||||
Message message,
|
||||
Conversation conversation,
|
||||
ConversationType conversationType) {
|
||||
private void reloadMessage(ChatCompletionParameters prevParameters) {
|
||||
var prevMessage = prevParameters.getMessage();
|
||||
ResponsePanel responsePanel = null;
|
||||
try {
|
||||
responsePanel = toolWindowScrollablePanel.getMessageResponsePanel(message.getId());
|
||||
responsePanel = toolWindowScrollablePanel.getMessageResponsePanel(prevMessage.getId());
|
||||
((ChatMessageResponseBody) responsePanel.getContent()).clear();
|
||||
toolWindowScrollablePanel.update();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Could not delete the existing message component", e);
|
||||
} finally {
|
||||
LOG.debug("Reloading message: " + message.getId());
|
||||
LOG.debug("Reloading message: " + prevMessage.getId());
|
||||
|
||||
if (responsePanel != null) {
|
||||
message.setResponse("");
|
||||
conversationService.saveMessage(conversation, message);
|
||||
call(new CallParameters(
|
||||
chatSession.getId(),
|
||||
conversation,
|
||||
conversationType,
|
||||
message,
|
||||
null,
|
||||
true),
|
||||
responsePanel);
|
||||
prevMessage.setResponse("");
|
||||
conversationService.saveMessage(conversation, prevMessage);
|
||||
call(prevParameters.toBuilder().retry(true).build(), responsePanel);
|
||||
}
|
||||
|
||||
totalTokensPanel.updateConversationTokens(conversation);
|
||||
|
|
@ -292,7 +284,7 @@ public class ChatToolWindowTabPanel implements Disposable {
|
|||
totalTokensPanel.updateConversationTokens(conversation);
|
||||
}
|
||||
|
||||
private void call(CallParameters callParameters, ResponsePanel responsePanel) {
|
||||
private void call(ChatCompletionParameters callParameters, ResponsePanel responsePanel) {
|
||||
var responseContainer = (ChatMessageResponseBody) responsePanel.getContent();
|
||||
|
||||
if (!CompletionRequestService.getInstance().isAllowed()) {
|
||||
|
|
@ -316,25 +308,6 @@ public class ChatToolWindowTabPanel implements Disposable {
|
|||
requestHandler.call(callParameters);
|
||||
}
|
||||
|
||||
private String processEditorSelection(Editor editor, Message message) {
|
||||
if (editor == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
SelectionModel selectionModel = editor.getSelectionModel();
|
||||
String selectedText = selectionModel.getSelectedText();
|
||||
if (selectedText == null || selectedText.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
String fileExtension = FileUtil.getFileExtension(
|
||||
((EditorEx) editor).getVirtualFile().getName());
|
||||
message.setPrompt(
|
||||
message.getPrompt() + String.format("%n```%s%n%s%n```", fileExtension, selectedText));
|
||||
selectionModel.removeSelection();
|
||||
return selectedText;
|
||||
}
|
||||
|
||||
private Unit handleSubmit(String text, List<? extends AppliedActionInlay> appliedInlayActions) {
|
||||
var message = new Message(text);
|
||||
var editor = EditorUtil.getSelectedEditor(project);
|
||||
|
|
@ -430,7 +403,10 @@ public class ChatToolWindowTabPanel implements Disposable {
|
|||
var messagePanel = toolWindowScrollablePanel.addMessage(message.getId());
|
||||
messagePanel.add(userMessagePanel);
|
||||
messagePanel.add(new ResponsePanel()
|
||||
.withReloadAction(() -> reloadMessage(message, conversation, ConversationType.DEFAULT))
|
||||
.withReloadAction(() -> reloadMessage(
|
||||
ChatCompletionParameters.builder(conversation, message)
|
||||
.conversationType(ConversationType.DEFAULT)
|
||||
.build()))
|
||||
.withDeleteAction(() -> removeMessage(message.getId(), conversation))
|
||||
.addContent(messageResponseBody));
|
||||
});
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import static com.intellij.openapi.ui.Messages.OK;
|
|||
import com.intellij.openapi.application.ApplicationManager;
|
||||
import com.intellij.openapi.diagnostic.Logger;
|
||||
import ee.carlrobert.codegpt.EncodingManager;
|
||||
import ee.carlrobert.codegpt.completions.CallParameters;
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionParameters;
|
||||
import ee.carlrobert.codegpt.completions.CompletionResponseEventListener;
|
||||
import ee.carlrobert.codegpt.conversations.Conversation;
|
||||
import ee.carlrobert.codegpt.conversations.ConversationService;
|
||||
|
|
@ -53,16 +53,10 @@ abstract class ToolWindowCompletionResponseEventListener implements
|
|||
@Override
|
||||
public void handleMessage(String partialMessage) {
|
||||
try {
|
||||
responseContainer.update(partialMessage);
|
||||
messageBuilder.append(partialMessage);
|
||||
|
||||
if (!completed) {
|
||||
var ongoingTokens = encodingManager.countTokens(messageBuilder.toString());
|
||||
ApplicationManager.getApplication().invokeLater(() -> {
|
||||
totalTokensPanel.update(
|
||||
totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens);
|
||||
});
|
||||
}
|
||||
var ongoingTokens = encodingManager.countTokens(messageBuilder.toString());
|
||||
responseContainer.updateMessage(partialMessage);
|
||||
totalTokensPanel.update(totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens);
|
||||
} catch (Exception e) {
|
||||
responseContainer.displayError("Something went wrong.");
|
||||
throw new RuntimeException("Error while updating the content", e);
|
||||
|
|
@ -105,7 +99,7 @@ abstract class ToolWindowCompletionResponseEventListener implements
|
|||
}
|
||||
|
||||
@Override
|
||||
public void handleCompleted(String fullMessage, CallParameters callParameters) {
|
||||
public void handleCompleted(String fullMessage, ChatCompletionParameters callParameters) {
|
||||
conversationService.saveMessage(fullMessage, callParameters);
|
||||
|
||||
ApplicationManager.getApplication().invokeLater(() -> {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import static javax.swing.event.HyperlinkEvent.EventType.ACTIVATED;
|
|||
import com.intellij.icons.AllIcons.General;
|
||||
import com.intellij.openapi.Disposable;
|
||||
import com.intellij.openapi.application.ApplicationManager;
|
||||
import com.intellij.openapi.diagnostic.Logger;
|
||||
import com.intellij.openapi.fileEditor.FileEditorManager;
|
||||
import com.intellij.openapi.options.ShowSettingsUtil;
|
||||
import com.intellij.openapi.project.Project;
|
||||
|
|
@ -53,6 +54,8 @@ import org.jetbrains.annotations.Nullable;
|
|||
|
||||
public class ChatMessageResponseBody extends JPanel {
|
||||
|
||||
private static final Logger LOG = Logger.getInstance(ChatMessageResponseBody.class);
|
||||
|
||||
private final Project project;
|
||||
private final Disposable parentDisposable;
|
||||
private final StreamParser streamParser;
|
||||
|
|
@ -123,14 +126,17 @@ public class ChatMessageResponseBody extends JPanel {
|
|||
}
|
||||
|
||||
public ChatMessageResponseBody withResponse(String response) {
|
||||
for (var message : MarkdownUtil.splitCodeBlocks(response)) {
|
||||
processResponse(message, message.startsWith("```"), false);
|
||||
try {
|
||||
for (var message : MarkdownUtil.splitCodeBlocks(response)) {
|
||||
processResponse(message, message.startsWith("```"), false);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
LOG.error("Something went wrong while processing input", e);
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
public void update(String partialMessage) {
|
||||
public void updateMessage(String partialMessage) {
|
||||
for (var item : streamParser.parse(partialMessage)) {
|
||||
processResponse(item.response(), CODE.equals(item.type()), true);
|
||||
}
|
||||
|
|
@ -261,22 +267,24 @@ public class ChatMessageResponseBody extends JPanel {
|
|||
var codeBlock = ((FencedCodeBlock) child);
|
||||
var code = codeBlock.getContentChars().unescape();
|
||||
if (!code.isEmpty()) {
|
||||
if (currentlyProcessedEditorPanel == null) {
|
||||
ApplicationManager.getApplication().invokeAndWait(() -> {
|
||||
ApplicationManager.getApplication().invokeLater(() -> {
|
||||
if (currentlyProcessedEditorPanel == null) {
|
||||
prepareProcessingCode(code, codeBlock.getInfo().unescape());
|
||||
});
|
||||
}
|
||||
EditorUtil.updateEditorDocument(currentlyProcessedEditorPanel.getEditor(), code);
|
||||
}
|
||||
EditorUtil.updateEditorDocument(currentlyProcessedEditorPanel.getEditor(), code);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void processText(String markdownText, boolean caretVisible) {
|
||||
var html = convertMdToHtml(markdownText);
|
||||
if (currentlyProcessedTextPane == null) {
|
||||
prepareProcessingText(caretVisible);
|
||||
}
|
||||
currentlyProcessedTextPane.setText(html);
|
||||
ApplicationManager.getApplication().invokeLater(() -> {
|
||||
if (currentlyProcessedTextPane == null) {
|
||||
prepareProcessingText(caretVisible);
|
||||
}
|
||||
currentlyProcessedTextPane.setText(html);
|
||||
});
|
||||
}
|
||||
|
||||
private void prepareProcessingText(boolean caretVisible) {
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@ public class UserMessagePanel extends JPanel {
|
|||
add(additionalContextPanel, BorderLayout.CENTER);
|
||||
}
|
||||
|
||||
if (message.getImageFilePath() != null && !message.getImageFilePath().isEmpty()) {
|
||||
displayImage(message.getImageFilePath());
|
||||
}
|
||||
|
||||
var referencedFilePaths = message.getReferencedFilePaths();
|
||||
if (referencedFilePaths != null && !referencedFilePaths.isEmpty()) {
|
||||
add(createResponseBody(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import com.intellij.util.ui.AsyncProcessIcon
|
|||
import com.intellij.openapi.util.text.StringUtil
|
||||
import com.jetbrains.rd.util.AtomicReference
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestService
|
||||
import ee.carlrobert.codegpt.completions.EditCodeRequestParameters
|
||||
import ee.carlrobert.codegpt.completions.EditCodeCompletionParameters
|
||||
import ee.carlrobert.codegpt.ui.ObservableProperties
|
||||
import javax.swing.JButton
|
||||
|
||||
|
|
@ -43,7 +43,7 @@ class EditCodeSubmissionHandler(
|
|||
runInEdt { editor.selectionModel.removeSelection() }
|
||||
|
||||
service<CompletionRequestService>().getEditCodeCompletionAsync(
|
||||
EditCodeRequestParameters(userPrompt, selectedText),
|
||||
EditCodeCompletionParameters(userPrompt, selectedText),
|
||||
EditCodeCompletionListener(
|
||||
editor,
|
||||
selectionTextRange,
|
||||
|
|
|
|||
|
|
@ -1,19 +0,0 @@
|
|||
package ee.carlrobert.codegpt.completions
|
||||
|
||||
interface CompletionCallParameters
|
||||
|
||||
data class ChatCompletionRequestParameters(
|
||||
val callParameters: CallParameters
|
||||
) : CompletionCallParameters
|
||||
|
||||
data class CommitMessageRequestParameters(
|
||||
val gitDiff: String,
|
||||
val systemPrompt: String
|
||||
) : CompletionCallParameters
|
||||
|
||||
data class LookupRequestCallParameters(val prompt: String) : CompletionCallParameters
|
||||
|
||||
data class EditCodeRequestParameters(
|
||||
val prompt: String,
|
||||
val selectedText: String
|
||||
) : CompletionCallParameters
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
package ee.carlrobert.codegpt.completions
|
||||
|
||||
import ee.carlrobert.codegpt.ReferencedFile
|
||||
import ee.carlrobert.codegpt.conversations.Conversation
|
||||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
import java.util.*
|
||||
|
||||
interface CompletionParameters
|
||||
|
||||
class ChatCompletionParameters private constructor(
|
||||
val conversation: Conversation,
|
||||
val conversationType: ConversationType,
|
||||
val message: Message,
|
||||
var sessionId: UUID?,
|
||||
var highlightedText: String?,
|
||||
var retry: Boolean,
|
||||
var imageMediaType: String?,
|
||||
var imageData: ByteArray?,
|
||||
var referencedFiles: List<ReferencedFile>?
|
||||
) : CompletionParameters {
|
||||
|
||||
fun toBuilder(): Builder {
|
||||
return Builder(conversation, message).apply {
|
||||
sessionId(this@ChatCompletionParameters.sessionId)
|
||||
conversationType(this@ChatCompletionParameters.conversationType)
|
||||
highlightedText(this@ChatCompletionParameters.highlightedText)
|
||||
retry(this@ChatCompletionParameters.retry)
|
||||
imageMediaType(this@ChatCompletionParameters.imageMediaType)
|
||||
imageData(this@ChatCompletionParameters.imageData)
|
||||
referencedFiles(this@ChatCompletionParameters.referencedFiles)
|
||||
}
|
||||
}
|
||||
|
||||
class Builder(private val conversation: Conversation, private val message: Message) {
|
||||
private var sessionId: UUID? = null
|
||||
private var conversationType: ConversationType = ConversationType.DEFAULT
|
||||
private var highlightedText: String? = null
|
||||
private var retry: Boolean = false
|
||||
private var imageMediaType: String? = null
|
||||
private var imageData: ByteArray? = null
|
||||
private var referencedFiles: List<ReferencedFile>? = null
|
||||
|
||||
fun sessionId(sessionId: UUID?) = apply { this.sessionId = sessionId }
|
||||
fun conversationType(conversationType: ConversationType) =
|
||||
apply { this.conversationType = conversationType }
|
||||
|
||||
fun highlightedText(highlightedText: String?) =
|
||||
apply { this.highlightedText = highlightedText }
|
||||
|
||||
fun retry(retry: Boolean) = apply { this.retry = retry }
|
||||
fun imageMediaType(imageMediaType: String?) = apply { this.imageMediaType = imageMediaType }
|
||||
fun imageData(imageData: ByteArray?) = apply { this.imageData = imageData }
|
||||
fun referencedFiles(referencedFiles: List<ReferencedFile>?) =
|
||||
apply { this.referencedFiles = referencedFiles }
|
||||
|
||||
fun build(): ChatCompletionParameters {
|
||||
return ChatCompletionParameters(
|
||||
conversation,
|
||||
conversationType,
|
||||
message,
|
||||
sessionId,
|
||||
highlightedText,
|
||||
retry,
|
||||
imageMediaType,
|
||||
imageData,
|
||||
referencedFiles
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
@JvmStatic
|
||||
fun builder(conversation: Conversation, message: Message) = Builder(conversation, message)
|
||||
}
|
||||
}
|
||||
|
||||
data class CommitMessageCompletionParameters(
|
||||
val gitDiff: String,
|
||||
val systemPrompt: String
|
||||
) : CompletionParameters
|
||||
|
||||
data class LookupCompletionParameters(val prompt: String) : CompletionParameters
|
||||
|
||||
data class EditCodeCompletionParameters(
|
||||
val prompt: String,
|
||||
val selectedText: String
|
||||
) : CompletionParameters
|
||||
|
|
@ -7,10 +7,10 @@ import ee.carlrobert.codegpt.settings.service.ServiceType
|
|||
import ee.carlrobert.llm.completion.CompletionRequest
|
||||
|
||||
interface CompletionRequestFactory {
|
||||
fun createChatRequest(params: ChatCompletionRequestParameters): CompletionRequest
|
||||
fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest
|
||||
fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest
|
||||
fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest
|
||||
fun createChatRequest(params: ChatCompletionParameters): CompletionRequest
|
||||
fun createEditCodeRequest(params: EditCodeCompletionParameters): CompletionRequest
|
||||
fun createCommitMessageRequest(params: CommitMessageCompletionParameters): CompletionRequest
|
||||
fun createLookupRequest(params: LookupCompletionParameters): CompletionRequest
|
||||
|
||||
companion object {
|
||||
@JvmStatic
|
||||
|
|
@ -30,16 +30,16 @@ interface CompletionRequestFactory {
|
|||
}
|
||||
|
||||
abstract class BaseRequestFactory : CompletionRequestFactory {
|
||||
override fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest {
|
||||
override fun createEditCodeRequest(params: EditCodeCompletionParameters): CompletionRequest {
|
||||
val prompt = "Code to modify:\n${params.selectedText}\n\nInstructions: ${params.prompt}"
|
||||
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, prompt, 8192, true)
|
||||
}
|
||||
|
||||
override fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest {
|
||||
override fun createCommitMessageRequest(params: CommitMessageCompletionParameters): CompletionRequest {
|
||||
return createBasicCompletionRequest(params.systemPrompt, params.gitDiff, 512, true)
|
||||
}
|
||||
|
||||
override fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest {
|
||||
override fun createLookupRequest(params: LookupCompletionParameters): CompletionRequest {
|
||||
return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, params.prompt, 512)
|
||||
}
|
||||
|
||||
|
|
@ -50,7 +50,7 @@ abstract class BaseRequestFactory : CompletionRequestFactory {
|
|||
stream: Boolean = false
|
||||
): CompletionRequest
|
||||
|
||||
protected fun getPromptWithFilesContext(callParameters: CallParameters): String {
|
||||
protected fun getPromptWithFilesContext(callParameters: ChatCompletionParameters): String {
|
||||
return callParameters.referencedFiles?.let {
|
||||
if (it.isEmpty()) {
|
||||
callParameters.message.prompt
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
|
||||
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest
|
||||
|
|
@ -10,12 +10,11 @@ import ee.carlrobert.llm.completion.CompletionRequest
|
|||
|
||||
class AzureRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest {
|
||||
override fun createChatRequest(params: ChatCompletionParameters): OpenAIChatCompletionRequest {
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val (callParameters) = params
|
||||
val requestBuilder: OpenAIChatCompletionRequest.Builder =
|
||||
OpenAIChatCompletionRequest.Builder(
|
||||
buildOpenAIMessages(null, callParameters, callParameters.referencedFiles)
|
||||
buildOpenAIMessages(null, params, params.referencedFiles)
|
||||
)
|
||||
.setMaxTokens(configuration.maxTokens)
|
||||
.setStream(true)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.codegpt.settings.persona.PersonaSettings
|
||||
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings
|
||||
|
|
@ -11,15 +11,14 @@ import ee.carlrobert.llm.completion.CompletionRequest
|
|||
|
||||
class ClaudeRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): ClaudeCompletionRequest {
|
||||
val (callParameters) = params
|
||||
override fun createChatRequest(params: ChatCompletionParameters): ClaudeCompletionRequest {
|
||||
return ClaudeCompletionRequest().apply {
|
||||
model = service<AnthropicSettings>().state.model
|
||||
maxTokens = service<ConfigurationSettings>().state.maxTokens
|
||||
isStream = true
|
||||
system = PersonaSettings.getSystemPrompt()
|
||||
|
||||
messages = callParameters.conversation.messages
|
||||
messages = params.conversation.messages
|
||||
.filter { it.response != null && it.response.isNotEmpty() }
|
||||
.flatMap { prevMessage ->
|
||||
sequenceOf(
|
||||
|
|
@ -29,18 +28,15 @@ class ClaudeRequestFactory : BaseRequestFactory() {
|
|||
}
|
||||
|
||||
when {
|
||||
callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty() -> {
|
||||
params.imageMediaType != null && params.imageData != null -> {
|
||||
messages.add(
|
||||
ClaudeCompletionDetailedMessage(
|
||||
"user",
|
||||
listOf(
|
||||
ClaudeMessageImageContent(
|
||||
ClaudeBase64Source(
|
||||
callParameters.imageMediaType,
|
||||
callParameters.imageData
|
||||
)
|
||||
ClaudeBase64Source(params.imageMediaType, params.imageData)
|
||||
),
|
||||
ClaudeMessageTextContent(callParameters.message.prompt)
|
||||
ClaudeMessageTextContent(params.message.prompt)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -49,8 +45,7 @@ class ClaudeRequestFactory : BaseRequestFactory() {
|
|||
else -> {
|
||||
messages.add(
|
||||
ClaudeCompletionStandardMessage(
|
||||
"user",
|
||||
getPromptWithFilesContext(callParameters)
|
||||
"user", getPromptWithFilesContext(params)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import com.intellij.openapi.application.ApplicationInfo
|
|||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.CodeGPTPlugin
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
|
||||
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings
|
||||
|
|
@ -13,14 +13,13 @@ import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionSt
|
|||
|
||||
class CodeGPTRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): ChatCompletionRequest {
|
||||
val (callParameters) = params
|
||||
override fun createChatRequest(params: ChatCompletionParameters): ChatCompletionRequest {
|
||||
val model = service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val requestBuilder: ChatCompletionRequest.Builder =
|
||||
ChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters))
|
||||
ChatCompletionRequest.Builder(buildOpenAIMessages(model, params))
|
||||
.setModel(model)
|
||||
.setSessionId(callParameters.sessionId)
|
||||
.setSessionId(params.sessionId)
|
||||
.setMetadata(
|
||||
Metadata(
|
||||
CodeGPTPlugin.getVersion(),
|
||||
|
|
@ -40,16 +39,16 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
|
|||
.setTemperature(configuration.temperature.toDouble())
|
||||
}
|
||||
|
||||
if (callParameters.message.isWebSearchIncluded) {
|
||||
if (params.message.isWebSearchIncluded) {
|
||||
requestBuilder.setWebSearchIncluded(true)
|
||||
}
|
||||
val documentationDetails = callParameters.message.documentationDetails
|
||||
val documentationDetails = params.message.documentationDetails
|
||||
if (documentationDetails != null) {
|
||||
requestBuilder.setDocumentationDetails(
|
||||
DocumentationDetails(documentationDetails.name, documentationDetails.url)
|
||||
)
|
||||
}
|
||||
callParameters.referencedFiles?.let {
|
||||
params.referencedFiles?.let {
|
||||
val fileContexts = it.map { file ->
|
||||
ContextFile(file.fileName, file.fileContent)
|
||||
}
|
||||
|
|
@ -81,7 +80,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
|
|||
.build()
|
||||
}
|
||||
|
||||
fun buildBasicO1Request(
|
||||
private fun buildBasicO1Request(
|
||||
model: String,
|
||||
prompt: String,
|
||||
systemPrompt: String = "",
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
import com.fasterxml.jackson.databind.ObjectMapper
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
|
||||
|
|
@ -19,17 +19,12 @@ class CustomOpenAIRequest(val request: Request) : CompletionRequest
|
|||
|
||||
class CustomOpenAIRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): CustomOpenAIRequest {
|
||||
val (callParameters) = params
|
||||
override fun createChatRequest(params: ChatCompletionParameters): CustomOpenAIRequest {
|
||||
val request = buildCustomOpenAIChatCompletionRequest(
|
||||
service<CustomServiceSettings>()
|
||||
.state
|
||||
.chatCompletionSettings,
|
||||
OpenAIRequestFactory.buildOpenAIMessages(
|
||||
null,
|
||||
callParameters,
|
||||
callParameters.referencedFiles
|
||||
),
|
||||
OpenAIRequestFactory.buildOpenAIMessages(null, params, params.referencedFiles),
|
||||
true,
|
||||
getCredential(CredentialKey.CUSTOM_SERVICE_API_KEY)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,10 +20,9 @@ import java.nio.file.Path
|
|||
|
||||
class GoogleRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): GoogleCompletionRequest {
|
||||
val (callParameters) = params
|
||||
override fun createChatRequest(params: ChatCompletionParameters): GoogleCompletionRequest {
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val messages = buildGoogleMessages(service<GoogleSettings>().state.model, callParameters)
|
||||
val messages = buildGoogleMessages(service<GoogleSettings>().state.model, params)
|
||||
return GoogleCompletionRequest.Builder(messages)
|
||||
.generationConfig(
|
||||
GoogleGenerationConfig.Builder()
|
||||
|
|
@ -57,9 +56,9 @@ class GoogleRequestFactory : BaseRequestFactory() {
|
|||
|
||||
private fun buildGoogleMessages(
|
||||
model: String?,
|
||||
callParameters: CallParameters
|
||||
params: ChatCompletionParameters
|
||||
): List<GoogleCompletionContent> {
|
||||
val messages = buildGoogleMessages(callParameters)
|
||||
val messages = buildGoogleMessages(params)
|
||||
|
||||
if (model == null) {
|
||||
return messages
|
||||
|
|
@ -81,7 +80,7 @@ class GoogleRequestFactory : BaseRequestFactory() {
|
|||
} else {
|
||||
tryReducingGoogleMessagesOrThrow(
|
||||
messages,
|
||||
callParameters.conversation.isDiscardTokenLimit,
|
||||
params.conversation.isDiscardTokenLimit,
|
||||
totalUsage,
|
||||
googleModel.maxTokens
|
||||
)
|
||||
|
|
@ -89,11 +88,11 @@ class GoogleRequestFactory : BaseRequestFactory() {
|
|||
} ?: messages
|
||||
}
|
||||
|
||||
private fun buildGoogleMessages(callParameters: CallParameters): List<GoogleCompletionContent> {
|
||||
val message = callParameters.message
|
||||
private fun buildGoogleMessages(params: ChatCompletionParameters): List<GoogleCompletionContent> {
|
||||
val message = params.message
|
||||
val messages = mutableListOf<GoogleCompletionContent>()
|
||||
|
||||
when (callParameters.conversationType) {
|
||||
when (params.conversationType) {
|
||||
ConversationType.DEFAULT -> {
|
||||
messages.add(
|
||||
GoogleCompletionContent(
|
||||
|
|
@ -114,8 +113,8 @@ class GoogleRequestFactory : BaseRequestFactory() {
|
|||
else -> {}
|
||||
}
|
||||
|
||||
for (prevMessage in callParameters.conversation.messages) {
|
||||
if (callParameters.isRetry && prevMessage.id == message.id) {
|
||||
for (prevMessage in params.conversation.messages) {
|
||||
if (params.retry && prevMessage.id == message.id) {
|
||||
break
|
||||
}
|
||||
|
||||
|
|
@ -143,15 +142,15 @@ class GoogleRequestFactory : BaseRequestFactory() {
|
|||
messages.add(GoogleCompletionContent("model", listOf(prevMessage.response)))
|
||||
}
|
||||
|
||||
if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) {
|
||||
if (params.imageMediaType != null && params.imageData != null) {
|
||||
messages.add(
|
||||
GoogleCompletionContent(
|
||||
listOf(
|
||||
GoogleContentPart(
|
||||
null,
|
||||
GoogleContentPart.Blob(
|
||||
callParameters.imageMediaType,
|
||||
callParameters.imageData
|
||||
params.imageMediaType,
|
||||
params.imageData
|
||||
)
|
||||
),
|
||||
GoogleContentPart(message.prompt)
|
||||
|
|
@ -162,7 +161,7 @@ class GoogleRequestFactory : BaseRequestFactory() {
|
|||
messages.add(
|
||||
GoogleCompletionContent(
|
||||
"user",
|
||||
listOf(getPromptWithFilesContext(callParameters))
|
||||
listOf(getPromptWithFilesContext(params))
|
||||
)
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.ConversationType
|
||||
import ee.carlrobert.codegpt.completions.llama.LlamaModel
|
||||
|
|
@ -14,18 +14,17 @@ import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
|
|||
|
||||
class LlamaRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): LlamaCompletionRequest {
|
||||
val (callParameters) = params
|
||||
override fun createChatRequest(params: ChatCompletionParameters): LlamaCompletionRequest {
|
||||
val promptTemplate = getPromptTemplate()
|
||||
val systemPrompt =
|
||||
if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS)
|
||||
if (params.conversationType == ConversationType.FIX_COMPILE_ERRORS)
|
||||
FIX_COMPILE_ERRORS_SYSTEM_PROMPT
|
||||
else
|
||||
getSystemPrompt()
|
||||
val prompt = promptTemplate.buildPrompt(
|
||||
systemPrompt,
|
||||
getPromptWithFilesContext(callParameters),
|
||||
callParameters.conversation.messages
|
||||
getPromptWithFilesContext(params),
|
||||
params.conversation.messages
|
||||
)
|
||||
|
||||
return buildLlamaRequest(prompt, promptTemplate.stopTokens, true)
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.CallParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.ConversationType
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
|
|
@ -19,13 +18,12 @@ import java.util.*
|
|||
|
||||
class OllamaRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): OllamaChatCompletionRequest {
|
||||
val (callParameters) = params
|
||||
override fun createChatRequest(params: ChatCompletionParameters): OllamaChatCompletionRequest {
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val settings = service<OllamaSettings>().state
|
||||
return OllamaChatCompletionRequest.Builder(
|
||||
settings.model,
|
||||
buildOllamaMessages(callParameters)
|
||||
buildOllamaMessages(params)
|
||||
)
|
||||
.setStream(true)
|
||||
.setOptions(
|
||||
|
|
@ -54,11 +52,11 @@ class OllamaRequestFactory : BaseRequestFactory() {
|
|||
.build()
|
||||
}
|
||||
|
||||
private fun buildOllamaMessages(callParameters: CallParameters): List<OllamaChatCompletionMessage> {
|
||||
val message = callParameters.message
|
||||
private fun buildOllamaMessages(params: ChatCompletionParameters): List<OllamaChatCompletionMessage> {
|
||||
val message = params.message
|
||||
val messages = mutableListOf<OllamaChatCompletionMessage>()
|
||||
|
||||
when (callParameters.conversationType) {
|
||||
when (params.conversationType) {
|
||||
ConversationType.DEFAULT -> messages.add(
|
||||
OllamaChatCompletionMessage("system", PersonaSettings.getSystemPrompt(), null)
|
||||
)
|
||||
|
|
@ -70,8 +68,8 @@ class OllamaRequestFactory : BaseRequestFactory() {
|
|||
else -> {}
|
||||
}
|
||||
|
||||
for (prevMessage in callParameters.conversation.messages) {
|
||||
if (callParameters.isRetry && prevMessage.id == message.id) break
|
||||
for (prevMessage in params.conversation.messages) {
|
||||
if (params.retry && prevMessage.id == message.id) break
|
||||
|
||||
prevMessage.imageFilePath?.takeIf { it.isNotEmpty() }?.let { imagePath ->
|
||||
try {
|
||||
|
|
@ -91,7 +89,7 @@ class OllamaRequestFactory : BaseRequestFactory() {
|
|||
messages.add(
|
||||
OllamaChatCompletionMessage(
|
||||
"user",
|
||||
getPromptWithFilesContext(callParameters),
|
||||
getPromptWithFilesContext(params),
|
||||
null
|
||||
)
|
||||
)
|
||||
|
|
@ -100,8 +98,8 @@ class OllamaRequestFactory : BaseRequestFactory() {
|
|||
messages.add(OllamaChatCompletionMessage("assistant", prevMessage.response, null))
|
||||
}
|
||||
|
||||
if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) {
|
||||
val imageBase64 = Base64.getEncoder().encodeToString(callParameters.imageData)
|
||||
if (params.imageMediaType != null && params.imageData != null) {
|
||||
val imageBase64 = Base64.getEncoder().encodeToString(params.imageData)
|
||||
messages.add(OllamaChatCompletionMessage("user", message.prompt, listOf(imageBase64)))
|
||||
} else {
|
||||
messages.add(OllamaChatCompletionMessage("user", message.prompt, null))
|
||||
|
|
|
|||
|
|
@ -21,13 +21,12 @@ import java.nio.file.Path
|
|||
|
||||
class OpenAIRequestFactory : CompletionRequestFactory {
|
||||
|
||||
override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest {
|
||||
val (callParameters) = params
|
||||
override fun createChatRequest(params: ChatCompletionParameters): OpenAIChatCompletionRequest {
|
||||
val model = service<OpenAISettings>().state.model
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val requestBuilder: OpenAIChatCompletionRequest.Builder =
|
||||
OpenAIChatCompletionRequest.Builder(
|
||||
buildOpenAIMessages(model, callParameters, callParameters.referencedFiles)
|
||||
buildOpenAIMessages(model, params, params.referencedFiles)
|
||||
)
|
||||
.setModel(model)
|
||||
if ("o1-mini" == model || "o1-preview" == model) {
|
||||
|
|
@ -48,7 +47,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
return requestBuilder.build()
|
||||
}
|
||||
|
||||
override fun createEditCodeRequest(params: EditCodeRequestParameters): OpenAIChatCompletionRequest {
|
||||
override fun createEditCodeRequest(params: EditCodeCompletionParameters): OpenAIChatCompletionRequest {
|
||||
val model = service<OpenAISettings>().state.model
|
||||
val prompt = "Code to modify:\n${params.selectedText}\n\nInstructions: ${params.prompt}"
|
||||
if (model == "o1-mini" || model == "o1-preview") {
|
||||
|
|
@ -57,7 +56,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, prompt, model, true)
|
||||
}
|
||||
|
||||
override fun createCommitMessageRequest(params: CommitMessageRequestParameters): OpenAIChatCompletionRequest {
|
||||
override fun createCommitMessageRequest(params: CommitMessageCompletionParameters): OpenAIChatCompletionRequest {
|
||||
val model = service<OpenAISettings>().state.model
|
||||
val (gitDiff, systemPrompt) = params
|
||||
if (model == "o1-mini" || model == "o1-preview") {
|
||||
|
|
@ -66,7 +65,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
return createBasicCompletionRequest(systemPrompt, gitDiff, model, true)
|
||||
}
|
||||
|
||||
override fun createLookupRequest(params: LookupRequestCallParameters): OpenAIChatCompletionRequest {
|
||||
override fun createLookupRequest(params: LookupCompletionParameters): OpenAIChatCompletionRequest {
|
||||
val model = service<OpenAISettings>().state.model
|
||||
val (prompt) = params
|
||||
if (model == "o1-mini" || model == "o1-preview") {
|
||||
|
|
@ -103,7 +102,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
|
||||
fun buildOpenAIMessages(
|
||||
model: String?,
|
||||
callParameters: CallParameters,
|
||||
callParameters: ChatCompletionParameters,
|
||||
referencedFiles: List<ReferencedFile>? = mutableListOf()
|
||||
): List<OpenAIChatCompletionMessage> {
|
||||
val messages = buildOpenAIChatMessages(model, callParameters, referencedFiles)
|
||||
|
|
@ -140,7 +139,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
|
||||
private fun buildOpenAIChatMessages(
|
||||
model: String?,
|
||||
callParameters: CallParameters,
|
||||
callParameters: ChatCompletionParameters,
|
||||
referencedFiles: List<ReferencedFile>? = mutableListOf()
|
||||
): MutableList<OpenAIChatCompletionMessage> {
|
||||
val message = callParameters.message
|
||||
|
|
@ -169,7 +168,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
}
|
||||
|
||||
for (prevMessage in callParameters.conversation.messages) {
|
||||
if (callParameters.isRetry && prevMessage.id == message.id) {
|
||||
if (callParameters.retry && prevMessage.id == message.id) {
|
||||
break
|
||||
}
|
||||
val prevMessageImageFilePath = prevMessage.imageFilePath
|
||||
|
|
@ -203,7 +202,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
)
|
||||
}
|
||||
|
||||
if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) {
|
||||
if (callParameters.imageMediaType != null && callParameters.imageData != null) {
|
||||
messages.add(
|
||||
OpenAIChatCompletionDetailedMessage(
|
||||
"user",
|
||||
|
|
|
|||
|
|
@ -22,12 +22,11 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
val secondMessage = createDummyMessage(250)
|
||||
conversation.addMessage(firstMessage)
|
||||
conversation.addMessage(secondMessage)
|
||||
val callParameters = ChatCompletionParameters
|
||||
.builder(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
|
||||
.build()
|
||||
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
ChatCompletionRequestParameters(
|
||||
CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
|
||||
)
|
||||
)
|
||||
val request = OpenAIRequestFactory().createChatRequest(callParameters)
|
||||
|
||||
assertThat(request.messages)
|
||||
.extracting("role", "content")
|
||||
|
|
@ -49,12 +48,11 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
val secondMessage = createDummyMessage(250)
|
||||
conversation.addMessage(firstMessage)
|
||||
conversation.addMessage(secondMessage)
|
||||
val callParameters = ChatCompletionParameters
|
||||
.builder(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
|
||||
.build()
|
||||
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
ChatCompletionRequestParameters(
|
||||
CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
|
||||
)
|
||||
)
|
||||
val request = OpenAIRequestFactory().createChatRequest(callParameters)
|
||||
|
||||
assertThat(request.messages)
|
||||
.extracting("role", "content")
|
||||
|
|
@ -76,19 +74,11 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
val secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250)
|
||||
conversation.addMessage(firstMessage)
|
||||
conversation.addMessage(secondMessage)
|
||||
val callParameters = ChatCompletionParameters.builder(conversation, secondMessage)
|
||||
.retry(true)
|
||||
.build()
|
||||
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
ChatCompletionRequestParameters(
|
||||
CallParameters(
|
||||
null,
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
secondMessage,
|
||||
null,
|
||||
true
|
||||
)
|
||||
)
|
||||
)
|
||||
val request = OpenAIRequestFactory().createChatRequest(callParameters)
|
||||
|
||||
assertThat(request.messages)
|
||||
.extracting("role", "content")
|
||||
|
|
@ -111,12 +101,11 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
val remainingMessage = createDummyMessage(1000)
|
||||
conversation.addMessage(remainingMessage)
|
||||
conversation.discardTokenLimits()
|
||||
val callParameters = ChatCompletionParameters
|
||||
.builder(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
|
||||
.build()
|
||||
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
ChatCompletionRequestParameters(
|
||||
CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
|
||||
)
|
||||
)
|
||||
val request = OpenAIRequestFactory().createChatRequest(callParameters)
|
||||
|
||||
assertThat(request.messages)
|
||||
.extracting("role", "content")
|
||||
|
|
@ -137,9 +126,9 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
|
||||
assertThrows(TotalUsageExceededException::class.java) {
|
||||
OpenAIRequestFactory().createChatRequest(
|
||||
ChatCompletionRequestParameters(
|
||||
CallParameters(conversation, createDummyMessage(100))
|
||||
)
|
||||
ChatCompletionParameters
|
||||
.builder(conversation, createDummyMessage(100))
|
||||
.build()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,246 +16,275 @@ import testsupport.IntegrationTest
|
|||
|
||||
class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
|
||||
|
||||
fun testOpenAIChatCompletionCall() {
|
||||
useOpenAIService()
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"model",
|
||||
"messages")
|
||||
.containsExactly(
|
||||
"gpt-4",
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
|
||||
mapOf("role" to "user", "content" to "TEST_PROMPT")))
|
||||
listOf(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
|
||||
})
|
||||
fun testOpenAIChatCompletionCall() {
|
||||
useOpenAIService()
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"model",
|
||||
"messages"
|
||||
)
|
||||
.containsExactly(
|
||||
"gpt-4",
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
|
||||
mapOf("role" to "user", "content" to "TEST_PROMPT")
|
||||
)
|
||||
)
|
||||
listOf(
|
||||
jsonMapResponse(
|
||||
"choices",
|
||||
jsonArray(jsonMap("delta", jsonMap("role", "assistant")))
|
||||
),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))
|
||||
)
|
||||
})
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
|
||||
|
||||
requestHandler.call(CallParameters(conversation, message))
|
||||
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testAzureChatCompletionCall() {
|
||||
useAzureService()
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val conversationService = ConversationService.getInstance()
|
||||
val prevMessage = Message("TEST_PREV_PROMPT")
|
||||
prevMessage.response = "TEST_PREV_RESPONSE"
|
||||
val conversation = conversationService.startConversation()
|
||||
conversation.addMessage(prevMessage)
|
||||
conversationService.saveConversation(conversation)
|
||||
expectAzure(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo(
|
||||
"/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions")
|
||||
assertThat(request.uri.query).isEqualTo("api-version=TEST_API_VERSION")
|
||||
assertThat(request.headers["Api-key"]!![0]).isEqualTo("TEST_API_KEY")
|
||||
assertThat(request.headers["X-llm-application-tag"]!![0]).isEqualTo("codegpt")
|
||||
assertThat(request.body)
|
||||
.extracting("messages")
|
||||
.isEqualTo(
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
|
||||
mapOf("role" to "user", "content" to "TEST_PREV_PROMPT"),
|
||||
mapOf("role" to "assistant", "content" to "TEST_PREV_RESPONSE"),
|
||||
mapOf("role" to "user", "content" to "TEST_PROMPT")))
|
||||
listOf(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
|
||||
})
|
||||
val message = Message("TEST_PROMPT")
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
|
||||
requestHandler.call(CallParameters(conversation, message))
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testLlamaChatCompletionCall() {
|
||||
useLlamaService()
|
||||
service<ConfigurationSettings>().state.maxTokens = 99
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
conversation.addMessage(Message("Ping", "Pong"))
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
expectLlama(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/completion")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"prompt",
|
||||
"n_predict",
|
||||
"stream")
|
||||
.containsExactly(
|
||||
LLAMA.buildPrompt(
|
||||
"TEST_SYSTEM_PROMPT",
|
||||
"TEST_PROMPT",
|
||||
conversation.messages),
|
||||
99,
|
||||
true)
|
||||
listOf<String?>(
|
||||
jsonMapResponse("content", "Hel"),
|
||||
jsonMapResponse("content", "lo!"),
|
||||
jsonMapResponse(
|
||||
e("content", ""),
|
||||
e("stop", true)))
|
||||
})
|
||||
|
||||
requestHandler.call(CallParameters(conversation, message))
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testOllamaChatCompletionCall() {
|
||||
useOllamaService()
|
||||
service<ConfigurationSettings>().state.maxTokens = 99
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
expectOllama(NdJsonStreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/api/chat")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"model",
|
||||
"messages",
|
||||
"options.num_predict",
|
||||
"stream"
|
||||
)
|
||||
.containsExactly(
|
||||
HuggingFaceModel.LLAMA_3_8B_Q6_K.code,
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
|
||||
mapOf("role" to "user", "content" to "TEST_PROMPT")
|
||||
),
|
||||
99,
|
||||
true
|
||||
)
|
||||
listOf(
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", "Hel"), e("role", "assistant"))),
|
||||
e("done", false)),
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", "lo"), e("role", "assistant"))),
|
||||
e("done", false)),
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", "!"), e("role", "assistant"))),
|
||||
e("done", false)),
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", ""), e("role", "assistant"))),
|
||||
e("done", true))
|
||||
)
|
||||
})
|
||||
|
||||
requestHandler.call(CallParameters(conversation, message))
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testGoogleChatCompletionCall() {
|
||||
useGoogleService()
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
expectGoogle(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/models/gemini-pro:streamGenerateContent")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.uri.query).isEqualTo("key=TEST_API_KEY&alt=sse")
|
||||
assertThat(request.body)
|
||||
.extracting("contents")
|
||||
.isEqualTo(
|
||||
listOf(
|
||||
mapOf("parts" to listOf(mapOf("text" to "TEST_SYSTEM_PROMPT")), "role" to "user"),
|
||||
mapOf("parts" to listOf(mapOf("text" to "Understood.")), "role" to "model"),
|
||||
mapOf("parts" to listOf(mapOf("text" to "TEST_PROMPT")), "role" to "user"),
|
||||
)
|
||||
)
|
||||
listOf(
|
||||
jsonMapResponse(
|
||||
"candidates",
|
||||
jsonArray(jsonMap("content", jsonMap("parts", jsonArray(jsonMap("text", "Hello")))))
|
||||
),
|
||||
jsonMapResponse(
|
||||
"candidates",
|
||||
jsonArray(jsonMap("content", jsonMap("parts", jsonArray(jsonMap("text", "!")))))
|
||||
)
|
||||
)
|
||||
})
|
||||
|
||||
requestHandler.call(CallParameters(conversation, message))
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testCodeGPTServiceChatCompletionCall() {
|
||||
useCodeGPTService()
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(
|
||||
getRequestEventListener(message)
|
||||
)
|
||||
expectCodeGPT(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"model",
|
||||
"messages")
|
||||
.containsExactly(
|
||||
"TEST_MODEL",
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
|
||||
mapOf("role" to "user", "content" to "TEST_PROMPT")))
|
||||
listOf(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
|
||||
})
|
||||
|
||||
requestHandler.call(CallParameters(conversation, message))
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
private fun getRequestEventListener(message: Message): CompletionResponseEventListener {
|
||||
return object : CompletionResponseEventListener {
|
||||
override fun handleCompleted(fullMessage: String, callParameters: CallParameters) {
|
||||
message.response = fullMessage
|
||||
}
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testAzureChatCompletionCall() {
|
||||
useAzureService()
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val conversationService = ConversationService.getInstance()
|
||||
val prevMessage = Message("TEST_PREV_PROMPT")
|
||||
prevMessage.response = "TEST_PREV_RESPONSE"
|
||||
val conversation = conversationService.startConversation()
|
||||
conversation.addMessage(prevMessage)
|
||||
conversationService.saveConversation(conversation)
|
||||
expectAzure(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo(
|
||||
"/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions"
|
||||
)
|
||||
assertThat(request.uri.query).isEqualTo("api-version=TEST_API_VERSION")
|
||||
assertThat(request.headers["Api-key"]!![0]).isEqualTo("TEST_API_KEY")
|
||||
assertThat(request.headers["X-llm-application-tag"]!![0]).isEqualTo("codegpt")
|
||||
assertThat(request.body)
|
||||
.extracting("messages")
|
||||
.isEqualTo(
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
|
||||
mapOf("role" to "user", "content" to "TEST_PREV_PROMPT"),
|
||||
mapOf("role" to "assistant", "content" to "TEST_PREV_RESPONSE"),
|
||||
mapOf("role" to "user", "content" to "TEST_PROMPT")
|
||||
)
|
||||
)
|
||||
listOf(
|
||||
jsonMapResponse(
|
||||
"choices",
|
||||
jsonArray(jsonMap("delta", jsonMap("role", "assistant")))
|
||||
),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))
|
||||
)
|
||||
})
|
||||
val message = Message("TEST_PROMPT")
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
|
||||
|
||||
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testLlamaChatCompletionCall() {
|
||||
useLlamaService()
|
||||
service<ConfigurationSettings>().state.maxTokens = 99
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
conversation.addMessage(Message("Ping", "Pong"))
|
||||
expectLlama(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/completion")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"prompt",
|
||||
"n_predict",
|
||||
"stream"
|
||||
)
|
||||
.containsExactly(
|
||||
LLAMA.buildPrompt(
|
||||
"TEST_SYSTEM_PROMPT",
|
||||
"TEST_PROMPT",
|
||||
conversation.messages
|
||||
),
|
||||
99,
|
||||
true
|
||||
)
|
||||
listOf<String?>(
|
||||
jsonMapResponse("content", "Hel"),
|
||||
jsonMapResponse("content", "lo!"),
|
||||
jsonMapResponse(
|
||||
e("content", ""),
|
||||
e("stop", true)
|
||||
)
|
||||
)
|
||||
})
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
|
||||
|
||||
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testOllamaChatCompletionCall() {
|
||||
useOllamaService()
|
||||
service<ConfigurationSettings>().state.maxTokens = 99
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
expectOllama(NdJsonStreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/api/chat")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"model",
|
||||
"messages",
|
||||
"options.num_predict",
|
||||
"stream"
|
||||
)
|
||||
.containsExactly(
|
||||
HuggingFaceModel.LLAMA_3_8B_Q6_K.code,
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
|
||||
mapOf("role" to "user", "content" to "TEST_PROMPT")
|
||||
),
|
||||
99,
|
||||
true
|
||||
)
|
||||
listOf(
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", "Hel"), e("role", "assistant"))),
|
||||
e("done", false)
|
||||
),
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", "lo"), e("role", "assistant"))),
|
||||
e("done", false)
|
||||
),
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", "!"), e("role", "assistant"))),
|
||||
e("done", false)
|
||||
),
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", ""), e("role", "assistant"))),
|
||||
e("done", true)
|
||||
)
|
||||
)
|
||||
})
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
|
||||
|
||||
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testGoogleChatCompletionCall() {
|
||||
useGoogleService()
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
expectGoogle(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/models/gemini-pro:streamGenerateContent")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.uri.query).isEqualTo("key=TEST_API_KEY&alt=sse")
|
||||
assertThat(request.body)
|
||||
.extracting("contents")
|
||||
.isEqualTo(
|
||||
listOf(
|
||||
mapOf(
|
||||
"parts" to listOf(mapOf("text" to "TEST_SYSTEM_PROMPT")),
|
||||
"role" to "user"
|
||||
),
|
||||
mapOf("parts" to listOf(mapOf("text" to "Understood.")), "role" to "model"),
|
||||
mapOf("parts" to listOf(mapOf("text" to "TEST_PROMPT")), "role" to "user"),
|
||||
)
|
||||
)
|
||||
listOf(
|
||||
jsonMapResponse(
|
||||
"candidates",
|
||||
jsonArray(
|
||||
jsonMap(
|
||||
"content",
|
||||
jsonMap("parts", jsonArray(jsonMap("text", "Hello")))
|
||||
)
|
||||
)
|
||||
),
|
||||
jsonMapResponse(
|
||||
"candidates",
|
||||
jsonArray(jsonMap("content", jsonMap("parts", jsonArray(jsonMap("text", "!")))))
|
||||
)
|
||||
)
|
||||
})
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
|
||||
|
||||
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testCodeGPTServiceChatCompletionCall() {
|
||||
useCodeGPTService()
|
||||
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
expectCodeGPT(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"model",
|
||||
"messages"
|
||||
)
|
||||
.containsExactly(
|
||||
"TEST_MODEL",
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
|
||||
mapOf("role" to "user", "content" to "TEST_PROMPT")
|
||||
)
|
||||
)
|
||||
listOf(
|
||||
jsonMapResponse(
|
||||
"choices",
|
||||
jsonArray(jsonMap("delta", jsonMap("role", "assistant")))
|
||||
),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))
|
||||
)
|
||||
})
|
||||
val requestHandler =
|
||||
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
|
||||
|
||||
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
private fun getRequestEventListener(message: Message): CompletionResponseEventListener {
|
||||
return object : CompletionResponseEventListener {
|
||||
override fun handleCompleted(
|
||||
fullMessage: String,
|
||||
callParameters: ChatCompletionParameters
|
||||
) {
|
||||
message.response = fullMessage
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
|
|||
)
|
||||
})
|
||||
|
||||
panel.sendMessage(message)
|
||||
panel.sendMessage(message, ConversationType.DEFAULT)
|
||||
|
||||
waitExpecting {
|
||||
val messages = conversation.messages
|
||||
|
|
@ -161,7 +161,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
|
|||
)
|
||||
})
|
||||
|
||||
panel.sendMessage(message)
|
||||
panel.sendMessage(message, ConversationType.DEFAULT)
|
||||
|
||||
waitExpecting {
|
||||
val messages = conversation.messages
|
||||
|
|
@ -250,7 +250,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
|
|||
)
|
||||
})
|
||||
|
||||
panel.sendMessage(message)
|
||||
panel.sendMessage(message, ConversationType.DEFAULT)
|
||||
|
||||
waitExpecting {
|
||||
val messages = conversation.messages
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue